
Generative vs. Discriminative Models in Machine Learning

Updated: Feb 19

The generative AI revolution is here, and much has been said about it. From large language models, like ChatGPT [1] or LLaMA [2], for text generation, to text-to-image models, like Stable Diffusion [3] or Midjourney [4], there's no denying that these technologies have taken hold of our daily lives.

But what is the common denominator behind these technologies? How are they related to classical generative machine learning models? And what is the difference between a generative model and a discriminative model? In this article we will go back to basics and explain the ideas behind all these technologies.

Classification and Types of Machine Learning Models


Before delving into what generative and discriminative models are, we have to understand the idea behind classification as a machine learning task: in machine learning, classification is the problem of predicting a single discrete variable y given a vector of features x = (x1, x2, …, xn).

The problem of classification can be solved with two approaches: 

  • We model the conditional probability distribution P(Y|X) of the target variable Y given an observed variable X; this is called a discriminative model;

  • We model the joint probability distribution P(X,Y) of the observable variable X and the target variable Y; this is called a generative model.

In particular, when dealing with the problem of classification we aim to obtain the value of the target variable Y given an observation x from X. To do this, we can model the probability of such a target value as P(Y|X=x). In the following sections, we will see how these two types of models are useful in obtaining this conditional probability.
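To make the distinction concrete, here is a minimal sketch of the two approaches applied to the same classification problem. The use of scikit-learn and a synthetic dataset is our own choice for illustration: logistic regression models P(Y|X) directly, while Gaussian naive Bayes models P(Y) and P(X|Y).

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

# Toy dataset: feature vectors x and discrete labels y.
X, y = make_classification(n_samples=200, n_features=4, random_state=0)

# Discriminative approach: logistic regression models P(Y|X) directly.
discriminative = LogisticRegression().fit(X, y)

# Generative approach: Gaussian naive Bayes models P(Y) and P(X|Y),
# i.e. the joint distribution P(X, Y).
generative = GaussianNB().fit(X, y)

# Both can answer the classification question P(Y | X = x).
print(discriminative.predict_proba(X[:1]))
print(generative.predict_proba(X[:1]))
```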

Discriminative Models


Discriminative models study how to obtain a mapping between an observed value x of a variable of observations X and a target value y of a variable of classes Y. These models are sometimes referred to as conditional models, since they technically model the conditional distribution of the variable Y given the observations in X. However, this definition is a little tricky, as it requires the final result to always be a probability distribution, which is not the case for all discriminative models.

The mapping learnt by these models usually takes the form of a function between X and Y whose parameters are obtained through some optimisation process (e.g. gradient descent). Conditional models (e.g. logistic regression) learn the conditional probability distribution P(Y|X); other methods learn that mapping in different ways (e.g. support vector machines or the perceptron algorithm learn the position of a hyperplane that divides the data).
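As a rough illustration of that optimisation process, the following sketch fits the parameters of a logistic regression model with plain gradient descent on the log-loss; the toy data and learning rate are assumptions made only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))                 # observed variable X
y = (X[:, 0] + X[:, 1] > 0).astype(float)     # target variable Y (toy labels)

w, b, lr = np.zeros(2), 0.0, 0.1
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))    # model of P(Y = 1 | X)
    w -= lr * (X.T @ (p - y)) / len(y)        # gradient of the log-loss w.r.t. w
    b -= lr * np.mean(p - y)                  # gradient of the log-loss w.r.t. b

print(w, b)  # parameters defining the learnt decision boundary
```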

In any case, a discriminative model learns to discriminate the value of a target variable Y given an observation of the observable variable X. This is why it is often said that these models learn a boundary (hard or soft) between the classes, and why they are often represented with the following image:


Discriminative models' decision boundary. CC-BY-SA Jordi Esteve Sorribas.

Examples of discriminative models are: logistic regression, conditional random fields, decision trees, the perceptron and the multi-layer perceptron (i.e. the traditional feed-forward neural network), and support vector machines.

Generative Models


For the classification problem we need to calculate the conditional probability of the target variable Y given the value of the observed variable X, i.e. P(Y|X). We can benefit from Bayes' Theorem [5] to do it:


P(Y|X) = P(X|Y) P(Y) / P(X)

In classification, since the values of the features in X are given (we are classifying them, after all), the denominator P(X) is constant. As a result, to calculate the value of P(Y|X) we only need the numerator: P(Y)P(X|Y). If we take a closer look at the definition of conditional probability, we know that P(X,Y) = P(Y)P(X|Y), which is exactly what a generative classifier models.
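As a tiny numeric illustration (the probabilities below are made up for the example), the posterior can be computed from the prior P(Y) and the class-conditional likelihood P(X|Y) alone, with the denominator acting only as a normalising constant:

```python
prior = {"spam": 0.4, "ham": 0.6}           # P(Y)
likelihood = {"spam": 0.02, "ham": 0.001}   # P(X = x | Y) for the observed x

# Numerator of Bayes' Theorem: P(Y) * P(X|Y) = P(X, Y)
joint = {c: prior[c] * likelihood[c] for c in prior}

# The denominator P(X) is just the sum over classes, a constant for a given x.
evidence = sum(joint.values())
posterior = {c: v / evidence for c, v in joint.items()}
print(posterior)  # {'spam': ~0.93, 'ham': ~0.07}
```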

Another way to see this is that generative models can generate random instances (or outcomes, in probability-theory terms) of an observation (x, y), since they model the probability of that event happening. Because of the equivalence above, we can say generative models generate classes y of the target variable Y with probability P(Y), and features x of the observed variable X given the label Y with P(X|Y); in other words, we can generate data that is distributed similarly to the dataset we used to train the model. This is why generative models are often represented with the following image:


Generative models. CC-BY-SA Jordi Esteve Sorribas.
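For instance, here is a minimal sampling sketch using one-dimensional Gaussian class-conditionals with made-up parameters: first draw a class y from P(Y), then draw features x from P(X|Y = y).

```python
import numpy as np

rng = np.random.default_rng(0)
p_y = [0.3, 0.7]                                   # P(Y) for classes 0 and 1
class_params = {0: (-2.0, 1.0), 1: (2.0, 1.0)}     # mean and std of P(X|Y)

samples = []
for _ in range(5):
    y = int(rng.choice([0, 1], p=p_y))             # sample the class
    mu, sigma = class_params[y]
    x = rng.normal(mu, sigma)                      # sample the features given the class
    samples.append((round(float(x), 2), y))
print(samples)                                     # new (x, y) pairs from the model
```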

Examples of generative models are: naive Bayes, Bayesian networks, hidden Markov models, variational autoencoders, generative adversarial networks, diffusion models (e.g. Stable Diffusion) and autoregressive models (e.g. GPT).

Connection Between Generative and Discriminative


Conditional discriminative models, i.e. those that model the conditional distribution P(Y|X), could in theory model the joint distribution P(X,Y) by applying the same Bayes' Theorem we used above for generative classification; however, this requires them to also supply the marginal distribution P(X). This is often not needed for tasks such as classification (or regression, for that matter), and it is usually difficult to model, since the features can be highly dependent on one another.

Discriminative models make fewer assumptions about the data or its distribution. Their principal advantage is that they are better suited to include rich, overlapping features. They usually make assumptions about the distribution of the target variable Y, or about how Y depends on the observed variable X, but they can remain agnostic about the form of P(X), which can be quite complex to model because of the inter-dependencies between the features in X.

In contrast, generative models need to either make assumptions about the form of the features in X, or try to model the interdependence of these features. The latter can become very difficult, so the most common examples of these models usually make some (rather strong) assumption about the data.

A classic example of the difference between generative and discriminative models is the naive Bayes classifier, which makes the strong assumption that all the features are independent of each other and depend only on the class, and its discriminative counterpart, logistic regression. Under certain conditions, these two models are two different ways of representing the same distribution. When the amount of data is small, naive Bayes is more robust to the presence of outliers: the model itself acts as a regularizer and does not generate values out of distribution, which prevents overfitting. Logistic regression, on the other hand, can outperform naive Bayes by capturing dependencies that exist in the data but that the naive Bayes classifier assumes away; however, when the amount of data is small, it can pick up on spurious patterns caused by outliers, and thus perform worse than the more robust naive Bayes classifier in those scenarios.
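A rough comparison sketch of the two models in the small-data regime might look like the following; the synthetic data and the artificially small training set are our own assumptions, and the outcome depends entirely on the data, so it is only illustrative.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# Keep only 20 training examples to mimic the small-data scenario.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=20, random_state=0
)

for model in (GaussianNB(), LogisticRegression(max_iter=1000)):
    score = model.fit(X_train, y_train).score(X_test, y_test)
    print(type(model).__name__, score)
```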

Use Cases and Limitations


Discriminative Models

  • Since discriminative models make weaker assumptions about the data, they usually require more data to avoid overfitting. They are more susceptible to outliers in the dataset, which can be problematic if the amount of data is small.

  • Discriminative models usually perform better in classification tasks than their generative counterparts, because during learning they can identify dependencies that either differ from the assumptions made by generative models or would require more complex assumptions to be captured correctly.

  • Discriminative models tend to be simpler and cheaper to train, precisely because of the weaker assumptions on the data.

  • Discriminative models rarely find uses outside of supervised learning, mainly classification.

  • These models can't deal with missing data: they require all the features of X to be present in order to function correctly and produce results.

  • Discriminative models that aren't conditional models, like support vector machines, can't express classification in terms of probabilities. There are some tricks to do so, but that's not the intended purpose of such models. This might be troublesome if we need to optimise metrics that require a probability score, e.g. for modelling click-through rate (CTR), metrics like AUC-ROC or AUC-PR are sometimes needed.

Generative Models

  • Generative models require some assumptions about the data distribution. As a consequence, if those assumptions are wrong, they usually perform worse than their discriminative counterparts.

  • These models are quite resistant to outliers. Since they have an internal representation of the data, it's harder for them to fall out of distribution, acting like regularizers when classifying.

  • Because of the last point, when the amount of data is small and the assumptions are good enough, these models usually have better performance than the discriminative ones.

  • Generative models can be applied to a wider range of tasks. They can generate data points and also be used for classification. Moreover, they are good at tasks such as pattern detection in unsupervised learning and denoising data (something particularly crucial for diffusion models).

  • These models are good at dealing with missing data. Since they can generate precisely those features that are not available in the dataset, they can achieve something that discriminative models simply can't (see the sketch after this list).

  • Complex generative models (i.e. those that make weaker assumptions about the data) require large computational power (e.g. large language models).
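To illustrate the missing-data point from the list above, here is a minimal hand-written Gaussian naive Bayes sketch (the class priors, means and variances are made up): the factors P(x_i|Y) of unobserved features are simply dropped from the product, and the missing feature can then be filled in with its most likely value under P(X|Y) for the predicted class.

```python
import numpy as np

priors = {"A": 0.5, "B": 0.5}                                     # P(Y)
means = {"A": np.array([0.0, 5.0]), "B": np.array([3.0, 1.0])}    # mean of P(X|Y)
stds = {"A": np.array([1.0, 2.0]), "B": np.array([1.0, 2.0])}     # std of P(X|Y)

x = np.array([2.5, np.nan])                          # second feature is missing
obs = ~np.isnan(x)

log_post = {}
for c in priors:
    # Gaussian log-likelihood using only the observed features.
    log_lik = (-0.5 * ((x[obs] - means[c][obs]) / stds[c][obs]) ** 2
               - np.log(stds[c][obs] * np.sqrt(2 * np.pi)))
    log_post[c] = np.log(priors[c]) + log_lik.sum()

best = max(log_post, key=log_post.get)               # classify despite the gap
x_filled = np.where(obs, x, means[best])             # "generate" the missing feature
print(best, x_filled)
```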

And how does Generative AI play a part in this?


Generative AI (or Gen AI) models are a type of generative model: they are trained by learning the distribution of the data. The main difference between Gen AI models and generative classifiers is that, in the case of Gen AI, the observable variable X and the target variable Y come from the same data. These models have become so powerful because they don't require manually labelled data for training: their training data plays the role of both the observed variable and the target variable.

In the case of autoregressive large language models (e.g. GPT), the observed variable is the sequence of words in a text, and the target variable is the next word in that sequence. Thus, the only data needed to model P(Y|X) is the same data needed to model P(X), which makes the assumptions about the data easier to define.
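A toy sketch of that autoregressive idea is shown below; a tiny hand-made corpus and a simple bigram model stand in for a real large language model, purely for illustration.

```python
import random
from collections import Counter, defaultdict

corpus = "the cat sat on the mat the cat ate the fish".split()

# Estimate P(next word | previous word) by counting bigrams.
counts = defaultdict(Counter)
for prev, nxt in zip(corpus, corpus[1:]):
    counts[prev][nxt] += 1

random.seed(0)
word, text = "the", ["the"]
for _ in range(6):
    if not counts[word]:                                   # no observed continuation
        break
    candidates, freqs = zip(*counts[word].items())
    word = random.choices(candidates, weights=freqs)[0]    # sample the next word
    text.append(word)
print(" ".join(text))
```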

In the case of generative text-to-image models there are two main parts: the text encoder and the image decoder. The generative part is done by the image decoder. Here there is a well-defined target, the produced image itself, but the starting point is noise. The idea behind Stable Diffusion is to use a decoder to remove noise that was originally added to the image. Thus, although technically the observed variable X is noise, that noise was originally generated from the target variable Y by adding noise to it.
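A very rough sketch of that idea on a one-dimensional "image" follows; the data and the signal level alpha_bar are made up, and the true noise stands in for the output of a hypothetical trained noise-prediction network, just to show the arithmetic of noising and denoising one step.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = np.array([0.2, 0.8, 0.5, 0.1])      # the "clean image" (target variable Y)
alpha_bar = 0.6                          # fraction of signal that survives this step

noise = rng.normal(size=x0.shape)
# Forward process: the observed variable is the noised image.
x_t = np.sqrt(alpha_bar) * x0 + np.sqrt(1 - alpha_bar) * noise

# In a real model, eps_pred would come from a learnt network; here we cheat
# and use the true noise to show that the reverse step recovers the image.
eps_pred = noise
x0_hat = (x_t - np.sqrt(1 - alpha_bar) * eps_pred) / np.sqrt(alpha_bar)
print(np.allclose(x0_hat, x0))           # True
```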

So does this mean that any Generative Model is capable of Generative AI?


Well, technically yes, but actually no. Let's take the example of naive Bayes, the simplest of the generative classification algorithms and also the one that makes the strongest assumption (the independence of the features).

Suppose we have a dataset consisting of emails whose labels are spam or not spam; this is a classic problem that can be solved with a naive Bayes classifier. We have the two labels of the problem and we use the words as our features (e.g. whether a word appears or not). After learning the probability of the labels and of the words given a label, we can technically generate a spam and a non-spam message. Of course, since there is no notion of sequence in a naive Bayes classifier, even though the generated text could technically be a correct "non-spam text" (as defined by the distribution learnt by naive Bayes), it almost certainly won't make any sense: the most likely scenario is that the words of the generated text appear in a completely random order and mean nothing to a human, yet, for the purpose of data generation, this is still valid text under the learnt distribution. This is where the strong independence assumption of the naive Bayes classifier fails. This generative model might be good at learning the joint distribution of what constitutes spam mail, but it has no idea of the correct dependencies between the words in a sequence (i.e. the correct word order). For that, you would need models with better assumptions, such as autoregressive language models like GPT.
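As a toy sketch of that failure mode (a Bernoulli naive Bayes model with made-up probabilities), "generating" an email amounts to sampling a label and then including each vocabulary word independently, with no notion of order:

```python
import random

p_spam = 0.5                                               # P(Y = spam)
p_word = {                                                 # P(word present | Y)
    "spam": {"free": 0.9, "winner": 0.8, "meeting": 0.1, "report": 0.05},
    "ham": {"free": 0.1, "winner": 0.05, "meeting": 0.7, "report": 0.6},
}

random.seed(0)
label = "spam" if random.random() < p_spam else "ham"      # sample the class y ~ P(Y)
words = [w for w, p in p_word[label].items() if random.random() < p]
random.shuffle(words)                                      # word order carries no meaning
print(label, ":", " ".join(words))
```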

So, what models should we use?


There isn't a single answer to this question. Even though, after reading this article, it may look like generative models have a clear advantage over discriminative models, especially because of their flexibility, the truth is that each type of model has its own limitations and its own set of constraints to be used correctly.

While discriminative models are more limited in the types of tasks they can perform, they are usually the best alternative performance-wise and are cheaper to train. However, under some conditions, such as missing data, or outliers combined with small amounts of data, generative models are usually a better choice.

When dealing with tasks other than classification, generative models provide a much wider spectrum of possibilities. For unsupervised machine learning tasks such as clustering, generative models are the natural choice, since the task itself calls for learning the data distribution in and of itself. Finally, for Generative AI, although, as we discussed, it is technically possible with any generative model, autoregressive language models and diffusion models are the right tools to use.


References

[1] OpenAI. (2022, November 30). Introducing ChatGPT. OpenAI Blog. https://www.openai.com/blog/chatgpt

[5] Stuart, A.; Ord, K. (1994), Kendall's Advanced Theory of Statistics: Volume I – Distribution Theory, Edward Arnold, §8.7

