Transposed Convolution as Convolution

Introduction

In my previous blog post “Convolution and Transposed Convolution as Matrix Multiplication”, I discussed how to understand convolution and transposed convolution as matrix multiplications.

It turns out that a transposed convolution can also be implemented as an ordinary, albeit slower, convolution. In this blog post, I would like to discuss how to understand transposed convolution as convolution.

Transposed Convolution as Convolution

Consider an extremely simple transposed convolution whose kernel $W$ is of shape $(2, 2)$ and stride is $1$. No input padding or output padding. We apply this transposed convolution onto an input tensor $X$ whose shape is $(2, 2)$. This results in an output tensor $Y$ whose shape is $(3, 3)$. Concretely,

$$
\begin{gather}
X = \begin{bmatrix}
x_{1,1} & x_{1,2} \\
x_{2,1} & x_{2,2} \\
\end{bmatrix} \\
W = \begin{bmatrix}
w_{1,1} & w_{1,2} \\
w_{2,1} & w_{2,2} \\
\end{bmatrix} \\
Y = W \star X = \begin{bmatrix}
x_{1,1}w_{1,1} & x_{1,1}w_{1,2} + x_{1,2}w_{1,1} & x_{1,2}w_{1,2} \\
x_{1,1}w_{2,1} + x_{2,1}w_{1,1} & \begin{split} & x_{1,1}w_{2,2} + x_{1,2}w_{2,1} \\ & + x_{2,1}w_{1,2} + x_{2,2}w_{1,1} \\ \end{split} & x_{1,2}w_{2,2} + x_{2,2}w_{1,2}\\
x_{2,1}w_{2,1} & x_{2,1}w_{2,2} + x_{2,2}w_{2,1} & x_{2,2}w_{2,2} \\
\end{bmatrix} \\
\end{gather}
$$

This is equivalent as applying a convolution, whose kernel $W^{\prime}$ is just the flip of the kernel from the transposed convolution, stride is $1$, no input padding, onto an input tensor $X^{\prime}$, which is just a zero-padded $X$.

$$
\begin{gather}
X^{\prime} = \begin{bmatrix}
0 & 0 & 0 & 0 \\
0 & x_{1,1} & x_{1,2} & 0 \\
0 & x_{2,1} & x_{2,2} & 0 \\
0 & 0 & 0 & 0 \\
\end{bmatrix} \\
W^{\prime} = \begin{bmatrix}
w_{2,2} & w_{2,1} \\
w_{1,2} & w_{1,1} \\
\end{bmatrix} \\
Y^{\prime} = W^{\prime} \ast X^{\prime} = \begin{bmatrix}
x_{1,1}w_{1,1} & x_{1,1}w_{1,2} + x_{1,2}w_{1,1} & x_{1,2}w_{1,2} \\
x_{1,1}w_{2,1} + x_{2,1}w_{1,1} & \begin{split} & x_{1,1}w_{2,2} + x_{1,2}w_{2,1} \\ & + x_{2,1}w_{1,2} + x_{2,2}w_{1,1} \\ \end{split} & x_{1,2}w_{2,2} + x_{2,2}w_{1,2}\\
x_{2,1}w_{2,1} & x_{2,1}w_{2,2} + x_{2,2}w_{2,1} & x_{2,2}w_{2,2} \\
\end{bmatrix} \\
\end{gather}
$$
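
The equivalence above can be checked numerically. The following is a minimal sketch using PyTorch, where the zero-padding of $X$ is expressed via the `padding=1` argument of `Conv2d` rather than by materializing $X^{\prime}$ explicitly; the tensor names are chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Random 2 x 2 input and 2 x 2 kernel (shapes are N, C, H, W).
x = torch.rand(1, 1, 2, 2)
w = torch.rand(1, 1, 2, 2)

# Transposed convolution with kernel W, stride 1, no padding.
conv_transposed = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=1, bias=False)
conv_transposed.weight.data = w
with torch.no_grad():
    y = conv_transposed(x)  # shape (1, 1, 3, 3)

# Equivalent convolution: spatially flipped kernel W' applied to the
# zero-padded input X' (padding=1 pads one ring of zeros around X).
conv = nn.Conv2d(1, 1, kernel_size=2, stride=1, padding=1, bias=False)
conv.weight.data = w.flip(-1, -2)
with torch.no_grad():
    y_prime = conv(x)

assert torch.allclose(y, y_prime)
```

With a single input and output channel, the channel transpose of the weight tensor is a no-op, so only the spatial flip matters here; the multi-channel case is handled in the full implementation below.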

The equivalence also holds when the transposed convolution becomes more complicated. In the white paper “A Guide to Convolution Arithmetic for Deep Learning”, the authors summarized the relationship between a square transposed convolution and its equivalent square convolution in Relationship 14. Concretely,

The application of a transposed convolution, whose kernel size is $k$, stride is $s$ and padding is $p$, on an input tensor of size $i^{\prime}$ producing an output tensor of size $i$, has an associated application of a convolution, whose kernel size is $k^{\prime} = k$, stride is $s^{\prime}=1$ and padding is $p^{\prime} = k - p - 1$, on an input tensor which is stretched from the transposed convolution input by adding $s-1$ zeros between each pair of adjacent input units, with $a = (i + 2p - k) \bmod s$ rows/columns of zeros added to the bottom and right edges of the stretched input.
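
The parameter mapping in Relationship 14 is purely arithmetic, so it can be written as a small helper. The function below is a hypothetical sketch (not part of the implementation in the next section); it takes the transposed convolution's $k$, $s$, $p$ and output size $i$, and returns the equivalent convolution's parameters.

```python
def equivalent_conv_params(k, s, p, i):
    """Hypothetical helper: given a transposed convolution's kernel size k,
    stride s, padding p, and output size i, return the parameters
    (k', s', p', a) of the equivalent stride-1 convolution per
    Relationship 14."""
    k_prime = k              # same kernel size; the weights are flipped
    s_prime = 1              # the equivalent convolution always has stride 1
    p_prime = k - p - 1      # padding of the equivalent convolution
    a = (i + 2 * p - k) % s  # zeros appended to the bottom/right edges
    return k_prime, s_prime, p_prime, a
```

For the square example in the next section ($k = 3$, $s = 3$, $p = 1$, $i = 8$), this yields $k^{\prime} = 3$, $s^{\prime} = 1$, $p^{\prime} = 1$, $a = 1$, matching the values printed by the program below.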

Implementing Transposed Convolution as Convolution

In this section, we implement not only the square transposed convolution as a square convolution but also the non-square transposed convolution as a non-square convolution. The implementation is verified using random unit tests.

transposed_conv_as_conv.py
import numpy as np
from tqdm import tqdm
import torch
from torch import nn


# Square Special-Case Conv
@torch.no_grad()
def square_conv():

    in_channels = 2
    out_channels = 3

    k = 3
    s = 3
    p = 1

    i_prime = 3
    i_tao_prime = (i_prime - 1) * (s - 1) + i_prime
    k_prime = k
    s_prime = 1
    p_prime = k - p - 1

    conv_transposed_input = torch.rand(1, in_channels, i_prime, i_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k, k),
                                         padding=0,
                                         output_padding=p,
                                         stride=s,
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p != 0:
        conv_transposed_output = conv_transposed_output[:, :, p:-p, p:-p]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i = conv_transposed_output.shape[2]

    a = (i + 2 * p - k) % s

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_prime, k_prime),
                                        padding=(p_prime, p_prime),
                                        stride=s_prime,
                                        bias=False)

    # The input and output channels need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()
    for m in range(s - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    for _ in range(a):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    print(f"i = {i}, k = {k}, s = {s}, p = {p}")
    print(f"i' = {i_prime}, k' = {k_prime}, s' = {s_prime}, p' = {p_prime}")

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output)

    return


# Non-Square General Conv
@torch.no_grad()
def non_square_conv():

    in_channels = 2
    out_channels = 3

    k_1 = 3
    k_2 = 4
    s_1 = 3
    s_2 = 4
    p_1 = 1
    p_2 = 2

    i_1_prime = 3
    i_2_prime = 4
    i_1_tao_prime = (i_1_prime - 1) * (s_1 - 1) + i_1_prime
    i_2_tao_prime = (i_2_prime - 1) * (s_2 - 1) + i_2_prime
    k_1_prime = k_1
    k_2_prime = k_2
    s_1_prime = 1
    s_2_prime = 1
    p_1_prime = k_1 - p_1 - 1
    p_2_prime = k_2 - p_2 - 1

    conv_transposed_input = torch.rand(1, in_channels, i_1_prime, i_2_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k_1, k_2),
                                         padding=0,
                                         output_padding=(p_1, p_2),
                                         stride=(s_1, s_2),
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p_1 != 0:
        conv_transposed_output = conv_transposed_output[:, :, p_1:-p_1, :]
    if p_2 != 0:
        conv_transposed_output = conv_transposed_output[:, :, :, p_2:-p_2]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i_1 = conv_transposed_output.shape[2]
    i_2 = conv_transposed_output.shape[3]

    a_1 = (i_1 + 2 * p_1 - k_1) % s_1
    a_2 = (i_2 + 2 * p_2 - k_2) % s_2

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_1_prime, k_2_prime),
                                        padding=(p_1_prime, p_2_prime),
                                        stride=(s_1_prime, s_2_prime),
                                        bias=False)

    # The input and output channels need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()

    # Add zero spacings between each element in the spatial dimensions.
    for m in range(s_1 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
    for m in range(s_2 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    # Add additional zero paddings to bottom and right in the spatial dimensions.
    for _ in range(a_1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
    for _ in range(a_2):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    print(
        f"i = {(i_1, i_2)}, k = {(k_1, k_2)}, s = {(s_1, s_2)}, p = {(p_1, p_2)}"
    )
    print(f"i' = {(i_1_prime, i_2_prime)}, k' = {(k_1_prime, k_2_prime)}, "
          f"s' = {(s_1_prime, s_2_prime)}, p' = {(p_1_prime, p_2_prime)}")

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output)

    return


@torch.no_grad()
def random_test(in_channels, out_channels, k_1, k_2, s_1, s_2, p_1, p_2,
                i_1_prime, i_2_prime):

    i_1_tao_prime = (i_1_prime - 1) * (s_1 - 1) + i_1_prime
    i_2_tao_prime = (i_2_prime - 1) * (s_2 - 1) + i_2_prime
    k_1_prime = k_1
    k_2_prime = k_2
    s_1_prime = 1
    s_2_prime = 1
    p_1_prime = k_1 - p_1 - 1
    p_2_prime = k_2 - p_2 - 1

    conv_transposed_input = torch.rand(1, in_channels, i_1_prime, i_2_prime)
    conv_transposed = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=(k_1, k_2),
                                         padding=0,
                                         output_padding=(p_1, p_2),
                                         stride=(s_1, s_2),
                                         bias=False)
    conv_transposed_output = conv_transposed(conv_transposed_input)

    if p_1 != 0:
        conv_transposed_output = conv_transposed_output[:, :, p_1:-p_1, :]
    if p_2 != 0:
        conv_transposed_output = conv_transposed_output[:, :, :, p_2:-p_2]

    # We have to know the ConvTranspose2D output
    # to configure the Conv2D for ConvTranspose2D.
    # This can be pre-computed using the ConvTranspose2D configurations
    # without executing the ConvTranspose2D kernel.
    # For simplicity, here we just use the shape information of ConvTranspose2D
    # output as if we have never executed the ConvTranspose2D kernel.
    i_1 = conv_transposed_output.shape[2]
    i_2 = conv_transposed_output.shape[3]

    a_1 = (i_1 + 2 * p_1 - k_1) % s_1
    a_2 = (i_2 + 2 * p_2 - k_2) % s_2

    conv_transposed_as_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=(k_1_prime, k_2_prime),
                                        padding=(p_1_prime, p_2_prime),
                                        stride=(s_1_prime, s_2_prime),
                                        bias=False)

    # The input and output channels need to be transposed.
    # The spatial dimensions need to be flipped.
    conv_transposed_as_conv.weight.data = conv_transposed.weight.data.transpose(
        0, 1).flip(-1, -2)

    conv_transposed_as_conv_input = conv_transposed_input.numpy()

    # Add zero spacings between each element in the spatial dimensions.
    for m in range(s_1 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[2], m + 1)),
            0,
            axis=2)
    for m in range(s_2 - 1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            list(range(1, conv_transposed_as_conv_input.shape[3], m + 1)),
            0,
            axis=3)

    # Add additional zero paddings to bottom and right in the spatial dimensions.
    for _ in range(a_1):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[2],
            0,
            axis=2)
    for _ in range(a_2):
        conv_transposed_as_conv_input = np.insert(
            conv_transposed_as_conv_input,
            conv_transposed_as_conv_input.shape[3],
            0,
            axis=3)

    conv_transposed_as_conv_input = torch.from_numpy(
        conv_transposed_as_conv_input)

    conv_transposed_as_conv_output = conv_transposed_as_conv(
        conv_transposed_as_conv_input)

    assert torch.allclose(conv_transposed_output,
                          conv_transposed_as_conv_output,
                          rtol=1e-05,
                          atol=1e-06)

    return


if __name__ == "__main__":

np.random.seed(0)
torch.manual_seed(0)

print("Square ConvTranspose2D as Conv:")
square_conv()
print("Non-Square ConvTranspose2D as Conv:")
non_square_conv()

print("Running Random Unit Tests...")
for _ in tqdm(range(1000)):

in_channels = np.random.randint(1, 10)
out_channels = np.random.randint(1, 10)

k_1 = np.random.randint(1, 5)
k_2 = np.random.randint(1, 5)
s_1 = np.random.randint(1, 5)
s_2 = np.random.randint(1, 5)
p_1 = np.random.randint(0, min(s_1, k_1))
p_2 = np.random.randint(0, min(s_2, k_2))
i_1_prime = np.random.randint(1, 50)
i_2_prime = np.random.randint(1, 50)

random_test(in_channels=in_channels,
out_channels=out_channels,
k_1=k_1,
k_2=k_2,
s_1=s_1,
s_2=s_2,
p_1=p_1,
p_2=p_2,
i_1_prime=i_1_prime,
i_2_prime=i_2_prime)
print("Random Unit Tests Passed.")

We can see that all the random unit tests have passed, suggesting that the implementation of transposed convolution as convolution is correct.

$ python transposed_conv_as_conv.py 
Square ConvTranspose2D as Conv:
i = 8, k = 3, s = 3, p = 1
i' = 3, k' = 3, s' = 1, p' = 1
Non-Square ConvTranspose2D as Conv:
i = (8, 14), k = (3, 4), s = (3, 4), p = (1, 2)
i' = (3, 4), k' = (3, 4), s' = (1, 1), p' = (1, 1)
Running Random Unit Tests...
100%|████████████████████████████████████| 1000/1000 [00:01<00:00, 517.76it/s]
Random Unit Tests Passed.

References

Vincent Dumoulin and Francesco Visin, “A Guide to Convolution Arithmetic for Deep Learning,” arXiv:1603.07285, 2016.

Author

Lei Mao

Posted on

11-22-2021

Updated on

11-22-2021
