### Lei Mao

Machine Learning, Artificial Intelligence, Computer Science.

# Data Parallelism VS Model Parallelism in Distributed Deep Learning Training

### Introduction

The number of parameters in modern deep learning models keeps growing, and datasets are also becoming dramatically larger. To train a sophisticated modern deep learning model on a large dataset, one often has to use multi-GPU or multi-node training; otherwise, training takes impractically long. Two approaches that frequently come up in distributed deep learning training are data parallelism and model parallelism. In this blog post, I am going to talk about the theory and logic behind these two approaches, as well as some potentially misleading points about them.

### Data Parallelism

In modern deep learning, the dataset is too large to fit into memory, so we can only perform stochastic gradient descent on batches. For example, if the training dataset has 10K data points, we might only be able to use 16 data points at a time to estimate the gradients. Any more and the GPU would run out of memory.
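The following is a minimal sketch of mini-batch stochastic gradient descent in plain Python. It assumes a toy one-parameter linear model $\hat{y} = w x$ with squared-error loss. The dataset, learning rate, and function names are hypothetical. Each update estimates the gradient from only 16 data points, mirroring the 10K-point example above.

```python
import random

def make_dataset(n, true_w=3.0, seed=0):
    """Hypothetical toy dataset: y = true_w * x plus small Gaussian noise."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, true_w * x + rng.gauss(0.0, 0.01)) for x in xs]

def batch_gradient(w, batch):
    """Gradient w.r.t. w of the mean squared error (1/m) * sum (w*x - y)^2."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def minibatch_sgd(data, batch_size=16, lr=0.1, epochs=3, seed=0):
    rng = random.Random(seed)
    w, steps = 0.0, 0
    for _ in range(epochs):
        rng.shuffle(data)  # visit the data in a new random order each epoch
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]  # only 16 points per gradient estimate
            w -= lr * batch_gradient(w, batch)
            steps += 1
    return w, steps

data = make_dataset(10_000)
w, steps = minibatch_sgd(data)
print(f"w = {w:.3f} after {steps} steps")
```

With 10,000 points and a batch size of 16, each epoch takes 625 gradient steps. No single step ever sees the full dataset.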

The shortcoming of stochastic gradient descent is that a gradient estimate computed from a small batch might not accurately represent the true gradient computed over the full dataset. Therefore, training may take much longer to converge.

A natural way to obtain a more accurate gradient estimate is to use a larger batch, or even the full dataset. A large batch may not fit into the memory of a single GPU. To work around this, the large batch is split into small batches and each GPU computes the gradients of its own small batch. The final gradient estimate is the weighted average of the gradients from all the small batches.

Mathematically, data parallelism is valid because

$$
\begin{aligned}
\frac{\partial \text{ Loss}}{\partial w} &= \frac{\partial \Big[ \frac{1}{n} \sum_{i=1}^{n} f(x_i, y_i) \Big] }{\partial w} \\
&= \frac{1}{n} \sum_{i=1}^{n} \frac{\partial f(x_i, y_i) }{\partial w} \\
&= \frac{m_1}{n} \frac{\partial \Big[ \frac{1}{m_1} \sum_{i=1}^{m_1} f(x_i, y_i) \Big] }{\partial w} + \frac{m_2}{n} \frac{\partial \Big[ \frac{1}{m_2} \sum_{i=m_1 + 1}^{m_1 + m_2} f(x_i, y_i) \Big] }{\partial w} + \cdots + \frac{m_k}{n} \frac{\partial \Big[ \frac{1}{m_k} \sum_{i=n - m_k + 1}^{n} f(x_i, y_i) \Big] }{\partial w} \\
&= \frac{m_1}{n} \frac{\partial l_1}{\partial w} + \frac{m_2}{n} \frac{\partial l_2}{\partial w} + \cdots + \frac{m_k}{n} \frac{\partial l_k}{\partial w}
\end{aligned}
$$

Where

$w$ denotes the parameters of the model,

$\frac{\partial \text{ Loss}}{\partial w}$ is the true gradient of the big batch of size $n$,

$\frac{\partial l_j}{\partial w}$ is the gradient of the small batch on GPU/node $j$, where $l_j$ is the average loss over that small batch,

$x_i$ and $y_i$ are the features and labels of data point $i$,

$f(x_i, y_i)$ is the loss for data point $i$ calculated from the forward propagation,

$n$ is the total number of data points in the big batch (the full dataset, if the full dataset is used),

$k$ is the total number of GPUs/nodes,

$m_j$ is the number of data points assigned to GPU/node $j$,

$m_1 + m_2 + \cdots + m_k = n$.
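The identity above can be checked numerically. The sketch below assumes a toy per-example loss $f(x_i, y_i) = (w x_i - y_i)^2$ for a scalar parameter $w$. The data and shard sizes are made up for illustration. Using Python's `fractions.Fraction` keeps the arithmetic exact, so the two results should match exactly rather than only up to floating-point error. The first result is the full-batch gradient. The second is the weighted average of the per-shard gradients.

```python
from fractions import Fraction

def per_example_grad(w, x, y):
    """d/dw of the toy loss f(x, y) = (w*x - y)^2."""
    return 2 * (w * x - y) * x

def mean_grad(w, batch):
    """Gradient of the average loss over a batch, i.e. dl_j/dw for one shard."""
    return sum(per_example_grad(w, x, y) for x, y in batch) / len(batch)

# Hypothetical big batch of n = 10 data points, split unevenly over k = 3 "GPUs".
w = Fraction(1, 2)
batch = [(Fraction(i), Fraction(3 * i + 1)) for i in range(1, 11)]
shard_sizes = [2, 3, 5]  # m_1, m_2, m_3 with m_1 + m_2 + m_3 = n
n = len(batch)

shards, start = [], 0
for m in shard_sizes:
    shards.append(batch[start:start + m])  # consecutive indices, as in the derivation
    start += m

full = mean_grad(w, batch)  # left-hand side: gradient of the big batch
weighted = sum(Fraction(len(s), n) * mean_grad(w, s) for s in shards)  # right-hand side
print(full, weighted, full == weighted)
```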

When $m_1 = m_2 = \cdots = m_k = \frac{n}{k}$, we further have

$$
\begin{aligned} \frac{\partial \text{ Loss}}{\partial w} &= \frac{1}{k} \Big[ \frac{\partial l_1}{\partial w} + \frac{\partial l_2}{\partial w} + \cdots + \frac{\partial l_k}{\partial w} \Big] \end{aligned}
$$
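This simplification only holds when the small batches are equally sized. The sketch below reuses the same hypothetical toy loss and exact `Fraction` arithmetic. A plain, unweighted average of per-shard gradients matches the full-batch gradient for equal shards. For unequal shards it is biased, which is why the weights $\frac{m_1}{n}, \ldots, \frac{m_k}{n}$ matter in general.

```python
from fractions import Fraction

def mean_grad(w, batch):
    """Gradient of the average toy loss (1/m) * sum (w*x - y)^2 over a batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def split(batch, sizes):
    """Split a batch into consecutive shards of the given sizes."""
    shards, start = [], 0
    for m in sizes:
        shards.append(batch[start:start + m])
        start += m
    return shards

def simple_average(w, shards):
    """Unweighted average of the per-shard gradients: (1/k) * sum of dl_j/dw."""
    return sum(mean_grad(w, s) for s in shards) / len(shards)

w = Fraction(1, 2)
batch = [(Fraction(i), Fraction(3 * i + 1)) for i in range(1, 13)]  # n = 12, hypothetical data
full = mean_grad(w, batch)

equal = simple_average(w, split(batch, [4, 4, 4]))    # m_1 = m_2 = m_3 = n / k
unequal = simple_average(w, split(batch, [2, 3, 7]))  # unequal shards, same n
print(equal == full, unequal == full)  # prints: True False
```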

Here, every GPU/node uses the same model/parameters for the forward propagation. We send a different small batch of data to each node, compute the gradients as usual, and send the gradients back to the main node. This step is asynchronous because each GPU/node runs at a slightly different speed. Once we have received all the gradients (this is where synchronization happens), we calculate the (weighted) average of the gradients and use it to update the model/parameters. The updated parameters are then sent to every node, and we move on to the next iteration.
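One iteration of this scheme can be simulated on a single machine. The sketch below is an illustration under stated assumptions, not how any particular framework implements it:

- Python threads stand in for GPUs/nodes.
- A toy linear model with squared loss stands in for the network.
- All names are hypothetical.

Each worker computes its shard's gradient concurrently. Calling `concurrent.futures.wait` on all futures is the synchronization point. The main node then takes the weighted average, updates $w$, and every worker receives the updated $w$ in the next iteration.

```python
from concurrent.futures import ThreadPoolExecutor, wait
import random

def shard_gradient(w, shard):
    """Worker side: gradient of the mean squared error on this worker's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)

def data_parallel_step(pool, w, shards, lr):
    n = sum(len(s) for s in shards)
    # Send the same parameters w and a different shard of data to each worker.
    futures = [pool.submit(shard_gradient, w, s) for s in shards]
    wait(futures)  # synchronization: block until every worker has finished
    # Weighted average of the gradients, with weights m_j / n.
    grad = sum(len(s) / n * f.result() for s, f in zip(shards, futures))
    return w - lr * grad  # updated parameters, used by all workers next iteration

rng = random.Random(0)
data = [(x, 3.0 * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(1000))]
shards = [data[0:200], data[200:500], data[500:1000]]  # k = 3 hypothetical workers

w = 0.0
with ThreadPoolExecutor(max_workers=len(shards)) as pool:
    for _ in range(200):
        w = data_parallel_step(pool, w, shards, lr=0.5)
print(f"w = {w:.4f}")
```

Because of Python's global interpreter lock, these threads do not actually speed up the pure-Python arithmetic. The point of the sketch is the control flow: concurrent gradient computation, a synchronization barrier, then a single update on the main node.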

### Model Parallelism

Model parallelism sounds terrifying to me, but it actually has nothing to do with math. It is an intuitive way of allocating computing resources. Sometimes everything does not fit into (GPU) memory because our deep learning model has too many layers and parameters. Therefore, we could divide the deep learning model into pieces, put a few consecutive layers on a single node, and calculate their gradients there. This reduces the number of parameters on each node and leaves more memory for data. We can then train with more data per batch and get more accurate gradients.
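Dividing a model into pieces of consecutive layers is essentially a partitioning problem. The hypothetical helper below is a minimal sketch that assigns layer indices to devices as contiguous, nearly equal-sized ranges. Real systems may instead balance by parameter count, memory, or compute, which this sketch ignores.

```python
def partition_layers(num_layers, num_devices):
    """Split layers 0..num_layers-1 into num_devices contiguous, nearly equal ranges."""
    if not 1 <= num_devices <= num_layers:
        raise ValueError("need 1 <= num_devices <= num_layers")
    base, extra = divmod(num_layers, num_devices)
    ranges, start = [], 0
    for d in range(num_devices):
        size = base + (1 if d < extra else 0)  # the first `extra` devices take one more layer
        ranges.append(range(start, start + size))
        start += size
    return ranges

# Usage: 50 layers over 10 GPUs, printed with 1-based layer numbers.
for device, layers in enumerate(partition_layers(50, 10), start=1):
    print(f"GPU #{device}: layers {layers.start + 1}-{layers.stop}")
```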

For example, suppose we have 10 GPUs and want to train a ResNet50 model. We could assign the first 5 layers to GPU #1, the next 5 layers to GPU #2, and so on, with the last 5 layers on GPU #10. During training, in each iteration, the forward propagation has to be done on GPU #1 first. GPU #2 waits for the output from GPU #1, GPU #3 waits for the output from GPU #2, and so on. Once the forward propagation is done, we calculate the gradients for the last layers, which reside on GPU #10, and update the model parameters for those layers on GPU #10. Then the gradients are backpropagated to the previous layers on GPU #9, and so on. Each GPU/node is like a station on a factory production line: it waits for the products from the previous station and sends its own products to the next station.
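The production-line behavior can be sketched with a toy model. The sketch works as follows:

- Each hypothetical "GPU" is a stage that owns a few scalar layers $h \mapsto a h$, a made-up stand-in for real layers.
- A log records which stage is active.
- The forward pass visits stages in order and the backward pass visits them in reverse.
- Each stage updates only its own parameters.
- At no point are two stages working at the same time.

```python
class Stage:
    """A hypothetical pipeline stage (one "GPU") owning consecutive scalar layers h -> a * h."""

    def __init__(self, device_id, weights):
        self.device_id = device_id
        self.weights = list(weights)
        self.inputs = []  # activations saved during forward, needed for backward

    def forward(self, h, log):
        log.append(("forward", self.device_id))
        self.inputs = []
        for a in self.weights:
            self.inputs.append(h)
            h = a * h
        return h

    def backward(self, grad_out, lr, log):
        """Receive dLoss/d(output), update own weights, return dLoss/d(input)."""
        log.append(("backward", self.device_id))
        for idx in reversed(range(len(self.weights))):
            a, h_in = self.weights[idx], self.inputs[idx]
            grad_a = grad_out * h_in  # dLoss/da for the layer h_out = a * h_in
            grad_out = grad_out * a   # propagate to the previous layer using the old a
            self.weights[idx] = a - lr * grad_a
        return grad_out

def train_step(stages, x, y, lr, log):
    h = x
    for stage in stages:  # each stage must wait for the previous stage's output
        h = stage.forward(h, log)
    grad = 2.0 * (h - y)  # d/dh of the squared loss (h - y)^2
    for stage in reversed(stages):  # gradients flow back from the last stage
        grad = stage.backward(grad, lr, log)
    return (h - y) ** 2

stages = [Stage(d, [1.0, 1.0]) for d in range(1, 4)]  # 3 "GPUs", 2 layers each
log = []
loss = train_step(stages, x=1.0, y=2.0, lr=0.01, log=log)
print(log)
```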

### Final Remarks

In my opinion, the name “model parallelism” is misleading, and it should not be considered an example of parallel computing. A better name could be “Model Serialization”, since it takes a serial approach rather than a parallel one. However, some neural networks, such as Siamese networks, contain layers that are actually “parallel” because they do not depend on each other's outputs. In such cases, model parallelism can behave like real parallel computing to some extent. Data parallelism, however, is 100% parallel computing.

Finally, here is an interesting question: are some reinforcement learning algorithms, such as A3C, data parallelism or model parallelism? I would say A3C is closer to data parallelism, although it is not exactly the same as the data parallelism described above. A3C updates the weights on the main node asynchronously throughout the whole training process. The process described above, by contrast, has clear iteration boundaries and synchronizes after every iteration. You could disagree with me :)