The generative AI revolution is in full swing, and much has been said about it. From large language models like ChatGPT [1] or LLaMA [2] for text generation to text-to-image models like Stable Diffusion [3] or Midjourney [4], there's no denying that these technologies have impacted daily life.
But what is the common denominator behind these technologies? How are they related to classical generative machine learning models? And what is the difference between a generative model and a discriminative model? In this article, we will go back to basics and explain the ideas behind all these technologies.
Classification and Types of Machine Learning Models
Before delving into what generative and discriminative models are, we have to understand the idea behind classification as a machine learning task: in machine learning, classification is the problem of predicting a single discrete variable y given a vector of features x = (x1, x2, …, xn).
The problem of classification can be solved with two approaches:
We model the conditional probability distribution P(Y|X) of the target variable Y given an observed variable X; this is called a discriminative model;
We model the joint probability distribution P(X, Y) of the observable variable X and the target variable Y; this is called a generative model.
In particular, when dealing with the problem of classification, we aim to obtain the value of the target variable Y given an observation from X. To do this, we can model the probability of such a target value as P(Y|X=x). In the following sections, we will see how these two types of models are helpful in obtaining this conditional probability.
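To make this concrete, here is a minimal sketch (in Python, with made-up numbers) of the two routes to P(Y|X=x): conditioning a modelled joint distribution P(X, Y), as a generative model would, versus modelling the conditional distribution P(Y|X) directly, as a discriminative model would.

```python
import numpy as np

# Hypothetical joint distribution P(X, Y): rows index the observed value x,
# columns the class y. Numbers are invented purely for illustration.
joint = np.array([
    [0.30, 0.10],   # P(X=0, Y=0), P(X=0, Y=1)
    [0.15, 0.45],   # P(X=1, Y=0), P(X=1, Y=1)
])

x = 1  # the observation we want to classify

# Generative route: model P(X, Y), then condition on the observed x.
p_y_given_x_generative = joint[x] / joint[x].sum()

# Discriminative route: model P(Y | X) directly, one distribution per value of x.
conditional = joint / joint.sum(axis=1, keepdims=True)
p_y_given_x_discriminative = conditional[x]

print(p_y_given_x_generative)       # [0.25 0.75]
print(p_y_given_x_discriminative)   # [0.25 0.75] -- same answer, different modelling route
```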
Discriminative Models
Discriminative models study how to obtain a mapping between an observed value x from a variable of observations X and a target value y from a variable of classes Y. These models are sometimes referred to as conditional models since they are technically modelling the conditional distribution of the variable Y given the observations in X. However, this definition is a little tricky as it requires the final result always to be a probability distribution, which is not the case for all discriminative models.
The mapping learnt by the models is usually in the form of a function between X and Y, whose parameters are obtained through some optimization process (e.g. gradient descent). Conditional models (e.g. logistic regression) learn the probabilistic conditional distribution P(Y|X); other methods use different approaches to learn the mapping (e.g. support vector machines or the perceptron algorithm learn the position of a hyperplane that divides the data).
In any case, a discriminative model learns to discriminate the value of a target variable Y given an observation from the observable variable X. This is why it is often said that these models learn a boundary (hard or soft) between the classes, and why they are typically pictured as a decision boundary separating two groups of points.
Examples of discriminative models are logistic regression, conditional random fields, decision trees, the perceptron and the multi-layer perceptron (i.e. the traditional feed-forward neural network), and support vector machines.
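As a small illustration of the difference between conditional and purely boundary-based discriminative models, the sketch below fits a logistic regression and a linear SVM on synthetic scikit-learn data; the dataset and parameters are arbitrary, chosen only to show the two kinds of output.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

# Synthetic data, used only to have something to fit.
X, y = make_classification(n_samples=200, n_features=4, random_state=0)

log_reg = LogisticRegression().fit(X, y)       # conditional model: learns P(Y|X)
svm = LinearSVC(max_iter=10_000).fit(X, y)     # boundary-based model: learns a hyperplane

print(log_reg.predict_proba(X[:1]))   # a probability distribution over the classes
print(svm.decision_function(X[:1]))   # a signed distance to the hyperplane, not a probability
```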
Generative Models
For the classification problem, we need to calculate the conditional probability of the target variable Y given the value of the observed variable X, i.e. P(Y|X). We can benefit from Bayes' Theorem [5] to do it:

P(Y|X) = P(Y)P(X|Y) / P(X)
In classification, since the values of the features in variable X are given (we are classifying them, after all), the denominator is constant. As a result, to calculate the value of P(Y|X) we only need the numerator: P(Y)P(X|Y). If we take a closer look at the definition of conditional probability, we know that P(X, Y) = P(Y)P(X|Y), which is exactly what a generative classifier models.
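The sketch below builds this kind of generative classifier by hand on a toy dataset with a single binary feature: it estimates P(Y) and P(X|Y) from counts and scores a new observation with the numerator P(Y)P(X|Y). The numbers are made up for illustration.

```python
import numpy as np

# Hypothetical dataset: a single binary feature x and a binary label y.
x = np.array([0, 0, 1, 1, 1, 0, 1, 1])
y = np.array([0, 0, 0, 1, 1, 0, 1, 1])

classes = [0, 1]
prior = np.array([np.mean(y == c) for c in classes])                 # P(Y)
likelihood = np.array([np.mean(x[y == c] == 1) for c in classes])    # P(X=1 | Y)

x_new = 1
numerator = prior * np.where(x_new == 1, likelihood, 1 - likelihood)  # P(Y) P(X=x_new | Y)
posterior = numerator / numerator.sum()                               # normalise to get P(Y | X=x_new)
print(posterior)   # [0.2 0.8] for this toy data
```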
Another way to see this is that generative models have a way of generating random instances (or outcomes, in probability theory) of an observation (x, y), since they model the probability of that event happening. Because of the equivalence above, we can say that generative models generate classes y of the target variable Y with probability P(Y) and features x of the observed variable X given the label with P(X|Y); in other words, we can generate data that is distributed similarly to the dataset we used to train the model. This is why generative models are typically pictured as learning the full distribution of each class rather than just the boundary between them.
Examples of generative models are naive Bayes, Bayesian networks, hidden Markov models, variational autoencoders, generative adversarial networks, diffusion models (e.g., Stable Diffusion), and autoregressive models (e.g., GPT).
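As a quick illustration of the "generating instances" idea, the sketch below fits a Gaussian naive Bayes classifier with scikit-learn and then samples a synthetic (x, y) pair from the learned P(Y) and P(X|Y); it assumes a reasonably recent scikit-learn version where the learned variances are exposed as `var_`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB

# Synthetic 2-class data, used only to have a distribution to learn.
X, y = make_blobs(n_samples=300, centers=2, random_state=0)
nb = GaussianNB().fit(X, y)

rng = np.random.default_rng(0)
sampled_y = rng.choice(len(nb.classes_), p=nb.class_prior_)                 # y ~ P(Y)
sampled_x = rng.normal(nb.theta_[sampled_y], np.sqrt(nb.var_[sampled_y]))   # x ~ P(X | Y=y)
print(sampled_y, sampled_x)
```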
Connection Between Generative and Discriminative
Conditional discriminative models, i.e. those that model the conditional distribution P(Y|X), could in theory model the joint distribution P(X, Y) by applying the same Bayes' Theorem we used to turn generative models into classifiers, but this requires them to supply the marginal distribution P(X). That marginal is often not needed for tasks such as classification (or regression, for that matter), and it is usually difficult to model since the features can be highly dependent on one another.
Discriminative models make fewer assumptions about the data and its distribution. Their principal advantage is that they are better suited to incorporating rich, overlapping features. They usually still make assumptions about the distribution of the target variable Y, or about how Y depends on the observed variable X, but they can remain agnostic about the form of P(X), which can be quite complex to model because of the interdependencies between the features in X.
In contrast, generative models must either make assumptions about the form of the features in X or try to model the interdependence of these features. The latter can become very difficult, so most common examples of these models usually make some (rather strong) assumption about the data.
A classic example of the generative-discriminative difference is the naive Bayes classifier, which makes the strong assumption that all features are independent of each other given the class, and its discriminative counterpart, logistic regression. Under certain conditions, these two models are two different ways of representing the same distribution. When the amount of data is small, naive Bayes is more robust to outliers: the model itself acts as a regularizer and doesn't generate anything out of distribution, which helps prevent overfitting. Logistic regression, on the other hand, can outperform naive Bayes by capturing dependencies between features that naive Bayes assumes away but that actually exist in the data. However, when the amount of data is small, it can also pick up spurious patterns caused by outliers and end up performing worse than the more robust naive Bayes classifier in those scenarios.
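A rough way to see this trade-off in practice is to train both models on a deliberately small training set, as in the sketch below; the data is synthetic and the outcome will vary with the random seed and with how well the independence assumption happens to fit.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Synthetic data; only 30 examples are kept for training to mimic the small-data regime.
X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=30, random_state=0)

for model in (GaussianNB(), LogisticRegression(max_iter=1000)):
    score = model.fit(X_train, y_train).score(X_test, y_test)
    print(type(model).__name__, round(score, 3))
```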
Use Cases and Limitations
Discriminative Models
Since discriminative models make fewer assumptions about the data, they usually require more data to avoid overfitting. They are also more susceptible to outliers in the dataset, which can be problematic if the amount of data is small.
Discriminative models usually perform better in classification tasks than their generative counterparts because they can identify, through the learning process, dependencies that are either different from the assumptions made by generative models or that require more complex assumptions to be correctly captured.
Discriminative models tend to be simpler and cheaper to train because they make weaker assumptions about the data.
Discriminative models rarely find uses outside of supervised learning, mainly classification.
These models can't deal with missing data: they require all the features in X to be present in order to produce correct results.
Those discriminative models that aren't conditional, like support vector machines, can't express classification in terms of probability. There are some tricks to do so, but that's not the intended purpose of such models. This can be troublesome if we need to optimise a metric that requires a probability score; when modelling click-through rate (CTR), for example, metrics like ROC AUC or PR AUC are sometimes needed.
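One such trick is Platt scaling, which scikit-learn applies on top of an SVM when `probability=True` is set; the sketch below uses it to compute ROC AUC on synthetic data, purely as an illustration of the workaround.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Synthetic binary classification data, for illustration only.
X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

svm = SVC(probability=True).fit(X_train, y_train)   # adds a calibration step on top of the SVM
scores = svm.predict_proba(X_test)[:, 1]            # calibrated scores, not a native SVM output
print(roc_auc_score(y_test, scores))
```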
Generative Models
Generative models require some assumptions about the data distribution. If the assumptions are wrong, they usually perform worse than their discriminative counterparts.
These models are quite resistant to outliers. Since they have an internal representation of the data, it's harder for them to fall out of distribution, acting like regularizers when classifying.
Because of the last point, when the amount of data is small and the assumptions are good enough, these models usually have better performance than the discriminative ones.
Generative models offer a wider range of tasks. They can generate data points and also be used for classification. Moreover, they are also good at tasks such as pattern detection in unsupervised learning and denoising data (particularly crucial for diffusion models).
These models are good at dealing with missing data. Since they can generate precisely those features that are not available in the dataset, they can achieve something that discriminative models simply can't (see the sketch after this list).
Complex generative models (i.e. those that have weaker assumptions about the data) require large computational power (e.g. large language models).
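To illustrate the missing-data point from the list above, here is a minimal sketch that fits a multivariate Gaussian (a very simple generative model) to synthetic data and imputes a missing feature with its conditional expectation given the observed one.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic 2-feature data drawn from a correlated Gaussian, for illustration.
data = rng.multivariate_normal([0.0, 5.0], [[1.0, 0.8], [0.8, 1.0]], size=1000)

mu = data.mean(axis=0)
cov = np.cov(data, rowvar=False)

# New observation where feature 1 is missing and feature 0 was observed as 1.2.
x_obs = 1.2
# Conditional expectation of the missing feature: E[x1 | x0] = mu1 + cov10 / cov00 * (x0 - mu0)
x_miss = mu[1] + cov[1, 0] / cov[0, 0] * (x_obs - mu[0])
print(x_miss)   # roughly 5 + 0.8 * 1.2 for this toy distribution
```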
And how does Generative AI play a part in this?
Generative AI (or Gen AI) models are a type of generative model. They are trained by learning the distribution of the data. The main difference between Gen AI models and generative classifiers is that, in Gen AI, the observable variable X and the target variable Y come from the same data. These models have become so powerful because they don't require manually labelled data to be trained: the same data plays the role of both the observed variable and the target variable.
For autoregressive large language models (e.g., GPT), the observed variables are the sequence of words in a text, and the target variable is the next word in that sequence. Thus, the only data needed to model P(Y|X) is the same data needed to model P(X), which makes the assumptions about the data easier to define.
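A toy way to see this is a bigram model, where the "observed" variable is just the previous word and the "target" is the next one; the sketch below estimates P(next word | previous word) from a tiny made-up corpus. Real LLMs condition on much longer contexts and use neural networks, but the self-supervised setup is the same.

```python
from collections import Counter, defaultdict

# A tiny made-up corpus; both the inputs and the targets come from the same text.
corpus = "the cat sat on the mat . the dog sat on the rug .".split()

counts = defaultdict(Counter)
for prev_word, next_word in zip(corpus, corpus[1:]):
    counts[prev_word][next_word] += 1          # count which word follows which

def next_word_distribution(prev_word):
    """P(next word | previous word), estimated from the corpus counts."""
    total = sum(counts[prev_word].values())
    return {word: c / total for word, c in counts[prev_word].items()}

print(next_word_distribution("the"))   # {'cat': 0.25, 'mat': 0.25, 'dog': 0.25, 'rug': 0.25}
```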
In the case of generative text-to-image models, there are two main parts: the text encoder and the image decoder. The image decoder does the generative part. Here there is a defined target, the produced image itself, but the starting point is noise. The idea behind Stable Diffusion is to use a decoder to remove noise that was initially added to the image. Thus, although technically the observed variable X is noise, that noise was originally generated from the target variable Y when noise was added to it.
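The sketch below shows the forward (noising) half of this process on a toy "image": the clean signal is progressively mixed with Gaussian noise according to a schedule, and a diffusion model is then trained to reverse it. The schedule values here are illustrative, not the ones used by Stable Diffusion.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(0, 1, size=(8, 8))          # a toy "image"

betas = np.linspace(1e-4, 0.02, 1000)        # an illustrative noise schedule
alphas_cumprod = np.cumprod(1.0 - betas)

t = 500                                      # a timestep in the middle of the process
noise = rng.standard_normal(x0.shape)
# Forward diffusion step: mix the clean signal with noise according to the schedule.
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1 - alphas_cumprod[t]) * noise

print(x_t.std())   # the signal is now heavily mixed with noise; the decoder learns to undo this
```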
So does this mean that any Generative Model is capable of Generative AI?
Well, technically, yes, but actually, no. Let's take the example of naive Bayes, the simplest of the generative classifier algorithms and also the one that makes the strongest assumption (the independence of the features).
Suppose we have a dataset of emails labelled as spam or not spam, a classic problem that can be solved with a naive Bayes classifier. We have the two labels of the problem, and we use the words as our features (e.g., whether a word appears or not). After learning the probability of each label and of the words given a label, we can technically generate a spam and a non-spam message. Of course, since there is no notion of sequence in a naive Bayes classifier, even though the generated text could technically be a correct "non-spam text" (as defined by the distribution learned by naive Bayes), it almost certainly won't make any sense: the words of the generated text will come out in essentially random order and mean nothing to a human reader, even though, for the purpose of data generation, it is still valid text under the learned distribution. This is where the strong independence assumption of naive Bayes fails. The model may be good at learning the joint distribution of what constitutes spam mail, but it has no idea of the correct dependencies between the words in a sequence (i.e., what the correct word order is). For the latter, you would need models with better assumptions, such as autoregressive language models like GPT.
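The sketch below makes this concrete: it samples a label from a made-up P(Y) and then draws words independently from a made-up P(word|Y), producing a bag of spam-like words in no particular order; all probabilities are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

p_class = {"spam": 0.4, "ham": 0.6}                                  # P(Y), made up
p_word_given_class = {                                               # P(word | Y), made up
    "spam": {"free": 0.3, "winner": 0.3, "money": 0.3, "meeting": 0.1},
    "ham":  {"free": 0.1, "winner": 0.05, "money": 0.15, "meeting": 0.7},
}

label = rng.choice(list(p_class), p=list(p_class.values()))          # y ~ P(Y)
vocab = list(p_word_given_class[label])
probs = list(p_word_given_class[label].values())
words = rng.choice(vocab, size=8, p=probs)                           # words drawn independently

print(label, " ".join(words))   # e.g. a jumble like "money free winner money ..." with no word order
```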
So, what models should we use?
There isn't a single answer to this question. Even though reading this article makes it look like generative models have a clear advantage over discriminative models, especially because of their flexibility, the truth is that each model has its own limitations and set of constraints to be used correctly.
While discriminative models are more limited in the types of tasks they can perform, they are usually the best alternative performance-wise and are cheaper to train. However, given some conditions, such as missing data or outliers and small amounts of data, generative models are usually a better choice.
When dealing with tasks other than classification, generative models provide a much wider spectrum of possibilities. For unsupervised machine learning tasks such as clustering, generative models are essentially the only option, since the task itself requires learning the data distribution. Finally, for Generative AI, although, as we discussed, it's technically possible with any generative model, autoregressive language models and diffusion models are the right tools for the job.
References
[1] OpenAI. (2022, November 30). Introducing ChatGPT. OpenAI Blog. https://www.openai.com/blog/chatgpt
[2] Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.-A., Lacroix, T., Rozière, B., Goyal, N., Hambro, E., Azhar, F., Rodriguez, A., Joulin, A., Grave, E. & Lample, G. (2023). LLaMA: Open and Efficient Foundation Language Models. arXiv:2302.13971.
[3] Rombach, R., Blattmann, A., Lorenz, D., Esser, P. & Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10684–10695.
[4] @midjourney (July 13, 2022). "We're officially moving to open-beta! Join now at https://discord.gg/midjourney. Please read our directions carefully or check out our detailed how-to guides here: https://midjourney.gitbook.io/docs. Most importantly, have fun!" (Tweet). Retrieved August 31, 2022 – via Twitter.
[5] Stuart, A.; Ord, K. (1994), Kendall's Advanced Theory of Statistics: Volume I – Distribution Theory, Edward Arnold, §8.7