Lei Mao

Introduction

In my previous blog post “Convolution and Transposed Convolution as Matrix Multiplication”, I discussed how to understand convolution and transposed convolution as matrix multiplications.


It turns out that transposed convolutions can also be implemented as (usually slower) convolutions. In this blog post, I would like to discuss how to understand transposed convolution as convolution.

Transposed Convolution as Convolution

Consider an extremely simple transposed convolution whose kernel $W$ has shape $(2, 2)$ and whose stride is $1$, with no input padding or output padding. We apply this transposed convolution to an input tensor $X$ of shape $(2, 2)$, which results in an output tensor $Y$ of shape $(3, 3)$. Concretely,


\[\begin{gather} X = \begin{bmatrix} x_{1,1} & x_{1,2} \\ x_{2,1} & x_{2,2} \\ \end{bmatrix} \\ W = \begin{bmatrix} w_{1,1} & w_{1,2} \\ w_{2,1} & w_{2,2} \\ \end{bmatrix} \\ Y = W \star X = \begin{bmatrix} x_{1,1}w_{1,1} & x_{1,1}w_{1,2} + x_{1,2}w_{1,1} & x_{1,2}w_{1,2} \\ x_{1,1}w_{2,1} + x_{2,1}w_{1,1} & \begin{split} & x_{1,1}w_{2,2} + x_{1,2}w_{2,1} \\ & + x_{2,1}w_{1,2} + x_{2,2}w_{1,1} \\ \end{split} & x_{1,2}w_{2,2} + x_{2,2}w_{1,2}\\ x_{2,1}w_{2,1} & x_{2,1}w_{2,2} + x_{2,2}w_{2,1} & x_{2,2}w_{2,2} \\ \end{bmatrix} \\ \end{gather}\]
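As a quick numeric sanity check of the closed-form output above, the same result can be reproduced with PyTorch's `nn.ConvTranspose2d`. This is a minimal sketch; the concrete values chosen for $X$ and $W$ are arbitrary examples of mine, not values from the derivation.

```python
import torch
from torch import nn

# Arbitrary example values for the (2, 2) input X and kernel W.
X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
W = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

conv_transposed = nn.ConvTranspose2d(in_channels=1, out_channels=1,
                                     kernel_size=2, stride=1, bias=False)
# ConvTranspose2d weight has shape (in_channels, out_channels, kH, kW).
conv_transposed.weight.data = W.reshape(1, 1, 2, 2)

with torch.no_grad():
    Y = conv_transposed(X.reshape(1, 1, 2, 2)).squeeze()

# Assemble the expected (3, 3) output from the closed-form entries above.
x11, x12, x21, x22 = X.flatten().tolist()
w11, w12, w21, w22 = W.flatten().tolist()
Y_expected = torch.tensor([
    [x11 * w11, x11 * w12 + x12 * w11, x12 * w12],
    [x11 * w21 + x21 * w11,
     x11 * w22 + x12 * w21 + x21 * w12 + x22 * w11,
     x12 * w22 + x22 * w12],
    [x21 * w21, x21 * w22 + x22 * w21, x22 * w22],
])
assert torch.allclose(Y, Y_expected)
print(Y)
```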

This is equivalent to applying a convolution, whose kernel $W^{\prime}$ is simply the flip of the kernel from the transposed convolution, with stride $1$ and no input padding, to an input tensor $X^{\prime}$, which is simply a zero-padded $X$.

\[\begin{gather} X^{\prime} = \begin{bmatrix} 0 & 0 & 0 & 0 \\ 0 & x_{1,1} & x_{1,2} & 0 \\ 0 & x_{2,1} & x_{2,2} & 0 \\ 0 & 0 & 0 & 0 \\ \end{bmatrix} \\ W^{\prime} = \begin{bmatrix} w_{2,2} & w_{2,1} \\ w_{1,2} & w_{1,1} \\ \end{bmatrix} \\ Y^{\prime} = W^{\prime} \ast X^{\prime} = \begin{bmatrix} x_{1,1}w_{1,1} & x_{1,1}w_{1,2} + x_{1,2}w_{1,1} & x_{1,2}w_{1,2} \\ x_{1,1}w_{2,1} + x_{2,1}w_{1,1} & \begin{split} & x_{1,1}w_{2,2} + x_{1,2}w_{2,1} \\ & + x_{2,1}w_{1,2} + x_{2,2}w_{1,1} \\ \end{split} & x_{1,2}w_{2,2} + x_{2,2}w_{1,2}\\ x_{2,1}w_{2,1} & x_{2,1}w_{2,2} + x_{2,2}w_{2,1} & x_{2,2}w_{2,2} \\ \end{bmatrix} \\ \end{gather}\]
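This equivalence can also be checked directly with NumPy: zero-pad $X$ by $k - 1 = 1$ on every side, flip $W$ along both spatial dimensions, and slide the flipped kernel over the padded input with stride $1$. A minimal sketch, using arbitrary example values of my own choosing:

```python
import numpy as np

# Arbitrary example values for the (2, 2) input X and kernel W.
X = np.array([[1.0, 2.0], [3.0, 4.0]])
W = np.array([[5.0, 6.0], [7.0, 8.0]])

X_prime = np.pad(X, pad_width=1)  # zero-pad X by k - 1 = 1 on every side
W_prime = np.flip(W)              # flip W along both spatial dimensions

# Valid stride-1 sliding-window product of W' over X'.
Y_prime = np.zeros((3, 3))
for r in range(3):
    for c in range(3):
        Y_prime[r, c] = np.sum(W_prime * X_prime[r:r + 2, c:c + 2])

print(Y_prime)
```

Each entry of `Y_prime` matches the corresponding closed-form entry of $Y$ above.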

The equivalence also holds when the transposed convolution becomes more complicated. In the paper “A Guide to Convolution Arithmetic for Deep Learning”, the authors summarize the relationship between square transposed convolution and square convolution in Relationship 14. Concretely,


The application of a transposed convolution, whose kernel size is $k$, stride is $s$, and output padding is $p$, on an input tensor of size $i^{\prime}$, producing an output tensor of size $i$, has an associated application of a convolution, whose kernel size is $k^{\prime} = k$, stride is $s^{\prime} = 1$, and padding is $p^{\prime} = k - p - 1$, on an input tensor that is stretched from the transposed convolution input by inserting $s - 1$ zeros between adjacent input units, with $a = (i + 2p - k) \bmod s$ rows/columns of zeros added to the bottom and right edges of the stretched input.
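To make the arithmetic concrete, here is a small sketch of my own that evaluates Relationship 14 for the square configuration used in the implementation below ($k = 3$, $s = 3$, $p = 1$, $i^{\prime} = 3$, which yields $i = 8$):

```python
# Square example: k = 3, s = 3, p = 1, i' = 3, with output size i = 8.
k, s, p = 3, 3, 1
i_prime = 3
i = 8

# Associated convolution parameters from Relationship 14.
k_conv = k          # k' = k = 3
s_conv = 1          # s' = 1
p_conv = k - p - 1  # p' = k - p - 1 = 1

# Stretch the input: s - 1 zeros between units, then a extra rows/columns.
stretched = (i_prime - 1) * (s - 1) + i_prime  # (3 - 1) * 2 + 3 = 7
a = (i + 2 * p - k) % s                        # (8 + 2 - 3) % 3 = 1
conv_input_size = stretched + a                # 8

# A stride-1 convolution on the stretched input must recover i.
conv_output_size = conv_input_size + 2 * p_conv - k_conv + 1
assert conv_output_size == i
print(k_conv, s_conv, p_conv, a, conv_output_size)
```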

Implementing Transposed Convolution as Convolution

In this section, we implement not only the square transposed convolution as a square convolution but also the non-square transposed convolution as a non-square convolution. The implementation is verified using random unit tests.

# transposed_conv_as_conv.py
import numpy as np
from tqdm import tqdm
import torch
from torch import nn


# Square Special-Case Conv
@torch.no_grad()
def square_conv():

    in_channels = 2
    out_channels = 3

    k = 3
    s = 3
    p = 1

    i_prime = 3
    i_tilde_prime = (i_prime - 1) * (s - 1) + i_prime  # stretched input size (unused; for reference)
    k_prime = k
    s_prime = 1
    p_prime = k - p - 1

    conv_transposed_input = torch.rand(1, in_channels, i_prime, i_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k, k),
                                         padding=0,
                                         output_padding=p,
                                         stride=s,
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p != 0:
        conv_transposed_output = conv_transposed_output[:, :, p:-p, p:-p]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i = conv_transposed_output.shape[2]

    a = (i + 2 * p - k) % s

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_prime, k_prime),
                                        padding=(p_prime, p_prime),
                                        stride=s_prime,
                                        bias=False)

    # The input and output channel dimensions need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()
    # Add zero spacings between each element in the spatial dimensions.
    for m in range(s - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    # Add additional zero paddings to the bottom and right edges.
    for _ in range(a):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    print(f"i = {i}, k = {k}, s = {s}, p = {p}")
    print(f"i' = {i_prime}, k' = {k_prime}, s' = {s_prime}, p' = {p_prime}")

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output)

    return


# Non-Square General Conv
@torch.no_grad()
def non_square_conv():

    in_channels = 2
    out_channels = 3

    k_1 = 3
    k_2 = 4
    s_1 = 3
    s_2 = 4
    p_1 = 1
    p_2 = 2

    i_1_prime = 3
    i_2_prime = 4
    i_1_tilde_prime = (i_1_prime - 1) * (s_1 - 1) + i_1_prime  # stretched height (unused; for reference)
    i_2_tilde_prime = (i_2_prime - 1) * (s_2 - 1) + i_2_prime  # stretched width (unused; for reference)
    k_1_prime = k_1
    k_2_prime = k_2
    s_1_prime = 1
    s_2_prime = 1
    p_1_prime = k_1 - p_1 - 1
    p_2_prime = k_2 - p_2 - 1

    conv_transposed_input = torch.rand(1, in_channels, i_1_prime, i_2_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k_1, k_2),
                                         padding=0,
                                         output_padding=(p_1, p_2),
                                         stride=(s_1, s_2),
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p_1 != 0:
        conv_transposed_output = conv_transposed_output[:, :, p_1:-p_1, :]
    if p_2 != 0:
        conv_transposed_output = conv_transposed_output[:, :, :, p_2:-p_2]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i_1 = conv_transposed_output.shape[2]
    i_2 = conv_transposed_output.shape[3]

    a_1 = (i_1 + 2 * p_1 - k_1) % s_1
    a_2 = (i_2 + 2 * p_2 - k_2) % s_2

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_1_prime, k_2_prime),
                                        padding=(p_1_prime, p_2_prime),
                                        stride=(s_1_prime, s_2_prime),
                                        bias=False)

    # The input and output channel dimensions need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()

    # Add zero spacings between each element in the spatial dimensions.
    for m in range(s_1 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
    for m in range(s_2 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    # Add additional zero paddings to the bottom and right edges in the spatial dimensions.
    for _ in range(a_1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
    for _ in range(a_2):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    print(
        f"i = {(i_1, i_2)}, k = {(k_1, k_2)}, s = {(s_1, s_2)}, p = {(p_1, p_2)}"
    )
    print(f"i' = {(i_1_prime, i_2_prime)}, k' = {(k_1_prime, k_2_prime)}, "
          f"s' = {(s_1_prime, s_2_prime)}, p' = {(p_1_prime, p_2_prime)}")

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output)

    return


@torch.no_grad()
def random_test(in_channels, out_channels, k_1, k_2, s_1, s_2, p_1, p_2,
                i_1_prime, i_2_prime):

    i_1_tilde_prime = (i_1_prime - 1) * (s_1 - 1) + i_1_prime  # stretched height (unused; for reference)
    i_2_tilde_prime = (i_2_prime - 1) * (s_2 - 1) + i_2_prime  # stretched width (unused; for reference)
    k_1_prime = k_1
    k_2_prime = k_2
    s_1_prime = 1
    s_2_prime = 1
    p_1_prime = k_1 - p_1 - 1
    p_2_prime = k_2 - p_2 - 1

    conv_transposed_input = torch.rand(1, in_channels, i_1_prime, i_2_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k_1, k_2),
                                         padding=0,
                                         output_padding=(p_1, p_2),
                                         stride=(s_1, s_2),
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p_1 != 0:
        conv_transposed_output = conv_transposed_output[:, :, p_1:-p_1, :]
    if p_2 != 0:
        conv_transposed_output = conv_transposed_output[:, :, :, p_2:-p_2]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i_1 = conv_transposed_output.shape[2]
    i_2 = conv_transposed_output.shape[3]

    a_1 = (i_1 + 2 * p_1 - k_1) % s_1
    a_2 = (i_2 + 2 * p_2 - k_2) % s_2

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_1_prime, k_2_prime),
                                        padding=(p_1_prime, p_2_prime),
                                        stride=(s_1_prime, s_2_prime),
                                        bias=False)

    # The input and output channel dimensions need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()

    # Add zero spacings between each element in the spatial dimensions.
    for m in range(s_1 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
    for m in range(s_2 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    # Add additional zero paddings to the bottom and right edges in the spatial dimensions.
    for _ in range(a_1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
    for _ in range(a_2):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output,
                          rtol=1e-05,
                          atol=1e-06)

    return


if __name__ == "__main__":

    np.random.seed(0)
    torch.manual_seed(0)

    print("Square ConvTranspose2D as Conv:")
    square_conv()
    print("Non-Square ConvTranspose2D as Conv:")
    non_square_conv()

    print("Running Random Unit Tests...")
    for _ in tqdm(range(1000)):

        in_channels = np.random.randint(1, 10)
        out_channels = np.random.randint(1, 10)

        k_1 = np.random.randint(1, 5)
        k_2 = np.random.randint(1, 5)
        s_1 = np.random.randint(1, 5)
        s_2 = np.random.randint(1, 5)
        p_1 = np.random.randint(0, min(s_1, k_1))
        p_2 = np.random.randint(0, min(s_2, k_2))
        i_1_prime = np.random.randint(1, 50)
        i_2_prime = np.random.randint(1, 50)

        random_test(in_channels=in_channels,
                    out_channels=out_channels,
                    k_1=k_1,
                    k_2=k_2,
                    s_1=s_1,
                    s_2=s_2,
                    p_1=p_1,
                    p_2=p_2,
                    i_1_prime=i_1_prime,
                    i_2_prime=i_2_prime)
    print("Random Unit Tests Passed.")

We can see that all of the random unit tests passed, suggesting that the implementation of transposed convolution using convolution is correct.

$ python transposed_conv_as_conv.py 
Square ConvTranspose2D as Conv:
i = 8, k = 3, s = 3, p = 1
i' = 3, k' = 3, s' = 1, p' = 1
Non-Square ConvTranspose2D as Conv:
i = (8, 14), k = (3, 4), s = (3, 4), p = (1, 2)
i' = (3, 4), k' = (3, 4), s' = (1, 1), p' = (1, 1)
Running Random Unit Tests...
100%|████████████████████████████████████| 1000/1000 [00:01<00:00, 517.76it/s]
Random Unit Tests Passed.

References