Minimax Game for Training Generative Adversarial Networks
Introduction
Generative adversarial networks are one of the most important neural network families for generating realistic data. However, when we look at actual implementations of generative adversarial networks, it is sometimes hard to correlate the code with the theory of the minimax training scheme. I previously had some experience with generative adversarial networks and worked on a couple of projects and implementations. However, when I looked back at my source code recently, I found that I could not correlate the source code with the theory either.
In this blog post, I would like to discuss the mathematical motivations for the minimax game used to train generative adversarial networks. This content should be sufficient for understanding the general theory of generative adversarial networks, and helpful for understanding the source code of generative adversarial network implementations.
Minimax Game for Training Generative Adversarial Networks
In Goodfellow’s original paper on generative adversarial networks, the idea of the entire paper can be summarized as the following minimax game for training a generator model and a discriminator model simultaneously.
$$
\min_{G} \max_{D} V(D,G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})} \big[\log D(\mathbf{x}) \big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[ \log \big( 1 - D\big(G(\mathbf{z})\big)\big) \big]
$$
where $D$ and $G$ are the discriminator model $D(\mathbf{x}; \theta_d)$ and the generator model $G(\mathbf{z}; \theta_g)$, $p_{\text{data}}$ is the distribution of the data, and $p_{\mathbf{z}}$ is the distribution of the noise.
In this context, $D(\mathbf{x}; \theta_d)$ is a binary classification model which produces the probability that $\mathbf{x}$ came from $p_{\text{data}}$ rather than from the generator, whereas $G(\mathbf{z}; \theta_g)$ consumes some noise input $\mathbf{z}$ and produces a piece of generated data.
Note that this optimization target shares some similarities with the log loss, sometimes called the binary cross entropy loss, used for binary classification.
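To make this connection concrete, here is a minimal sketch, assuming PyTorch, showing that the empirical estimate of $V(D, G)$ is exactly the negative of the binary cross entropy loss with targets of $1.0$ for real data and $0.0$ for generated data. The `d_real` and `d_fake` values are hypothetical discriminator outputs, not from the original paper.

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator outputs D(x) for real samples and D(G(z)) for
# generated samples.
d_real = torch.tensor([0.9, 0.8, 0.7])
d_fake = torch.tensor([0.2, 0.3, 0.1])

# Empirical estimate of V(D, G): E[log D(x)] + E[log(1 - D(G(z)))].
v = torch.log(d_real).mean() + torch.log(1.0 - d_fake).mean()

# Binary cross entropy loss with target 1.0 for real data and target 0.0
# for generated data.
bce = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) \
    + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))

# Maximizing V over D is the same as minimizing the binary cross entropy
# loss over D: v == -bce.
print(v.item(), -bce.item())
```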
In practice, during training, the optimization of the generator and the discriminator is usually done iteratively, alternating between the two models.
To train the discriminator, for each iteration, we have some real data $\mathbf{x}_0, \mathbf{x}_1, \cdots, \mathbf{x}_{n-1} \sim p_{\text{data}}(\mathbf{x})$ and some generated data $G(\mathbf{z}_0), G(\mathbf{z}_1), \cdots, G(\mathbf{z}_{n-1})$ where $\mathbf{z}_0, \mathbf{z}_1, \cdots, \mathbf{z}_{n-1} \sim p_{\mathbf{z}}$. We optimize the discriminator using the following equation.
$$
\max_{D} V(D; G) = \frac{1}{n} \sum_{i=0}^{n-1} \log D(\mathbf{x}_i) + \frac{1}{n} \sum_{i=0}^{n-1} \log \big( 1 - D\big(G(\mathbf{z}_i)\big)\big)
$$
This could also be optimized using the binary cross entropy loss available in most deep learning frameworks, by setting the target probabilities for $\mathbf{x}_0, \mathbf{x}_1, \cdots, \mathbf{x}_{n-1}$ to $1.0$ and the target probabilities for $G(\mathbf{z}_0), G(\mathbf{z}_1), \cdots, G(\mathbf{z}_{n-1})$ to $0.0$, where a target of $1.0$ means the data is real with probability $1.0$ and a target of $0.0$ means the data is real with probability $0.0$. This is what we will see in the code of most implementations. There are also label smoothing tricks that set the target of $1.0$ to a slightly smaller value, but we are not going to discuss them here. At a high level, we want the discriminator to correctly predict all the real data as real and all the generated data as not real.
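For instance, one discriminator update might look like the following sketch, assuming PyTorch. The `generator`, `discriminator`, `d_optimizer`, `real_data`, and `noise_dim` names are hypothetical placeholders rather than any particular implementation.

```python
import torch
import torch.nn as nn

def train_discriminator_step(generator: nn.Module, discriminator: nn.Module,
                             d_optimizer: torch.optim.Optimizer,
                             real_data: torch.Tensor, noise_dim: int) -> float:
    # Assumes the discriminator ends with a sigmoid and outputs probabilities.
    criterion = nn.BCELoss()
    batch_size = real_data.shape[0]
    d_optimizer.zero_grad()
    # Real data: target probability 1.0 (replace with, e.g., 0.9 for the
    # label smoothing trick mentioned above).
    real_pred = discriminator(real_data)
    real_loss = criterion(real_pred, torch.ones_like(real_pred))
    # Generated data: target probability 0.0. The generator output is
    # detached so that this step only updates the discriminator parameters.
    z = torch.randn(batch_size, noise_dim)
    fake_pred = discriminator(generator(z).detach())
    fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
    # Minimizing this binary cross entropy loss maximizes V(D; G) over D.
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
```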
To train the generator, for each iteration, we have some generated data $G(\mathbf{z}_0), G(\mathbf{z}_1), \cdots, G(\mathbf{z}_{n-1})$ where $\mathbf{z}_0, \mathbf{z}_1, \cdots, \mathbf{z}_{n-1} \sim p_{\mathbf{z}}$. We optimize the generator using the following equation.
$$
\min_{G} V(G; D) = \frac{1}{n} \sum_{i=0}^{n-1} \log \big( 1 - D\big(G(\mathbf{z}_i)\big)\big)
$$
This could also be optimized using the binary cross entropy loss available in most deep learning frameworks, by setting the target probabilities for $G(\mathbf{z}_0), G(\mathbf{z}_1), \cdots, G(\mathbf{z}_{n-1})$ to $1.0$. This is what we will see in the code of most implementations. Strictly speaking, the binary cross entropy loss with targets of $1.0$ minimizes $-\frac{1}{n} \sum_{i=0}^{n-1} \log D\big(G(\mathbf{z}_i)\big)$ rather than the expression above, but the two objectives are “equivalent” in the sense discussed in the next section. At a high level, we want the generator to fool the discriminator such that the discriminator predicts all the generated data as real data.
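The corresponding generator update might look like the following sketch, again assuming PyTorch and the same hypothetical names as in the discriminator sketch above.

```python
import torch
import torch.nn as nn

def train_generator_step(generator: nn.Module, discriminator: nn.Module,
                         g_optimizer: torch.optim.Optimizer,
                         batch_size: int, noise_dim: int) -> float:
    criterion = nn.BCELoss()
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, noise_dim)
    # Target probability 1.0 for the generated data: the generator improves
    # when the discriminator predicts its outputs as real. Note that this
    # minimizes -log D(G(z)), the "equivalent" form of the objective above.
    fake_pred = discriminator(generator(z))
    g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
```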
Mathematical Motivations
Let us now look at the mathematical motivations for such a minimax game optimization approach for training the generator and the discriminator. Remember that our ultimate goal is to train a perfect generator model which produces generated data as if the data were from $p_{\text{data}}$. The question now becomes how to train such a perfect generator. Here the discriminator comes into play. If there were a perfect discriminator which could tell whether the given data is from $p_{\text{data}}$ or not, the generator could always send the generated data to the perfect discriminator and collect the feedback to improve itself until it becomes perfect and fools the discriminator. But such a perfect discriminator usually does not exist.
The idea is to develop the generator and the discriminator together. The generator and the discriminator both use useful information from each other to improve their own model quality.
Ideally, we would like to have both a perfect discriminator and a perfect generator: a discriminator that could identify any data that is not from $p_{\text{data}}$, and a generator that could fool the discriminator in such a way that the discriminator classifies the generated data $G(\mathbf{z})$ as data from $p_{\text{data}}$.
To train a perfect discriminator $D$ given a generator $G$, we would like to maximize the probability of the discriminator predicting the data $\mathbf{x}$ coming from $p_{\text{data}}$, and minimize the probability of the discriminator predicting the generated data $G(\mathbf{z})$ coming from $p_{\text{data}}$.
Mathematically, it could be expressed as
$$
\max_{D} V(D;G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})} \big[\log D(\mathbf{x}) \big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[ \log \big( 1 - D\big(G(\mathbf{z})\big)\big) \big]
$$
Similarly, to train a perfect generator $G$ given a discriminator $D$, we would like to maximize the probability of the discriminator predicting the generated data $G(\mathbf{z})$ coming from $p_{\text{data}}$.
Mathematically, it could be expressed as
$$
\max_{G} V(G;D) = \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[ \log D\big( G(\mathbf{z})\big) \big]
$$
Note that the above expression is “equivalent” to
$$
\min_{G} V(G;D) = \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[ \log \big( 1 - D\big(G(\mathbf{z})\big)\big) \big]
$$
Merging these two optimization targets naturally leads to the minimax optimization target we have shown above for training generative adversarial networks.
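The “equivalence” deserves a caveat: both objectives drive $D(G(\mathbf{z}))$ toward $1$ and share the same optimum, but their gradients behave differently. As noted in Goodfellow’s original paper, early in training, when the discriminator can reject generated data with high confidence, $\log \big(1 - D(G(\mathbf{z}))\big)$ saturates, whereas maximizing $\log D(G(\mathbf{z}))$ still provides strong gradients. The following small numerical sketch, assuming PyTorch, illustrates this.

```python
import torch

# Compare the gradients of the two generator objectives with respect to
# D(G(z)). Both are improved as D(G(z)) -> 1, but when the discriminator
# easily rejects generated data (D(G(z)) near 0), the "saturating" loss
# log(1 - D(G(z))) has a small gradient while the "non-saturating" loss
# -log D(G(z)) does not.
for value in [0.01, 0.5, 0.99]:
    d_gz = torch.tensor(value, requires_grad=True)
    # Gradient of the saturating generator loss log(1 - D(G(z))).
    grad_sat, = torch.autograd.grad(torch.log(1.0 - d_gz), d_gz)
    # Gradient of the non-saturating generator loss -log D(G(z)).
    grad_nonsat, = torch.autograd.grad(-torch.log(d_gz), d_gz)
    print(f"D(G(z)) = {value:.2f}: "
          f"grad saturating = {grad_sat.item():+.2f}, "
          f"grad non-saturating = {grad_nonsat.item():+.2f}")
```

For $D(G(\mathbf{z})) = 0.01$, the gradient of the saturating loss is about $-1.01$ while the gradient of the non-saturating loss is $-100$, which is one reason most implementations train the generator with targets of $1.0$ in the binary cross entropy loss.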
Mode Collapse
The real-world data to be generated by generative adversarial networks is usually multimodal. Data from the same mode share some unique high-level features. For example, the MNIST dataset consists of at least 10 modes, one for each digit.
Mode collapse is a term that developers often hear during generative adversarial network training. When mode collapse happens, the generator only learns to generate one or a few modes of the data. Therefore, even though the noise inputs to the generator are different every time, the generated data are all very similar to each other.
The reason why mode collapse happens during generative adversarial network training is that the discriminator is also being learned, and it is not always good enough, or updated fast enough, to find out that the data generated by the generator is not real. If the generator finds that synthesizing data of a certain mode, regardless of the input noise, can always fool the discriminator, it will exploit generating data of that mode, which results in overfitting. Even if the discriminator later catches up, it is sometimes difficult for the generator to jump out of this local minimum.