PyTorch Variational Autoencoder

Introduction

In a previous article, “Variational Autoencoder”, I discussed the mathematics of optimizing probabilistic latent-variable models with a variational autoencoder (VAE).

In this blog post, I will demonstrate how to implement a variational autoencoder model in PyTorch, train the model on the MNIST dataset, and generate images using the trained model.

PyTorch Variational Autoencoder

The variational autoencoder is implemented in PyTorch and trained on the MNIST dataset. Its decoder is then used as the generative model that produces MNIST images from samples drawn in the latent space.

PyTorch Implementation

In the PyTorch implementation of the variational autoencoder, the approximate posterior $q_{\phi}(z|x)$ is modeled as a multivariate Gaussian distribution with a full covariance matrix, parameterized by a simple two-layer MLP inference model. Samples from this Gaussian are drawn using the reparameterization trick. The conditional distribution $p_{\theta}(x|z)$ is modeled as a multivariate Bernoulli distribution, parameterized by a simple two-layer MLP generative model.
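
As a reading aid, the sampling step of the full-covariance posterior can be written out as below. The symbols $M$ (a strictly lower-triangular 0/1 mask), $\tilde{L}$ (the unmasked square matrix produced by the encoder), and $\sigma$ (the exponentiated log standard deviations) are notation introduced here for illustration; they are meant to mirror `mask_const`, `unmasked_lower_triangular`, and `std` in the listing.

$$
\begin{aligned}
\epsilon &\sim \mathcal{N}(0, I) \\
L &= M \odot \tilde{L} + \operatorname{diag}(\sigma) \\
z &= \mu + L \epsilon \\
z \mid x &\sim \mathcal{N}(\mu, L L^{\top})
\end{aligned}
$$

Because $L$ is lower triangular with a strictly positive diagonal, $L L^{\top}$ is a valid positive-definite covariance matrix, and the randomness is isolated in $\epsilon$, so gradients can flow to $\mu$ and $L$ through the deterministic map $z = \mu + L \epsilon$.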

The number of latent variables is adjustable. A larger latent space usually lets the model capture more complex patterns in the data and generate sharper images. For demonstration purposes, we use only a two-dimensional latent space here.
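
To get a feel for how the latent dimensionality affects model size, the following sketch counts the encoder parameters for an encoder with one hidden layer and three output heads (mean, log standard deviation, and a flattened $k \times k$ matrix), which is how the listing is structured. The helper names `linear_params` and `encoder_params` are hypothetical and exist only for this illustration.

```python
def linear_params(in_features, out_features):
    """Number of weights plus biases in one fully connected layer."""
    return in_features * out_features + out_features


def encoder_params(num_observed_dims, num_latent_dims, num_hidden_dims):
    """Parameter count of a one-hidden-layer encoder with three heads:
    mean (k), log standard deviation (k), and a flattened k x k matrix."""
    k = num_latent_dims
    hidden = linear_params(num_observed_dims, num_hidden_dims)
    mean_head = linear_params(num_hidden_dims, k)
    log_std_head = linear_params(num_hidden_dims, k)
    triangular_head = linear_params(num_hidden_dims, k * k)
    return hidden + mean_head + log_std_head + triangular_head


# The k x k head grows quadratically with the number of latent variables.
for k in (2, 8, 32):
    print(k, encoder_params(784, k, 1024))
# 2 812040
# 8 885840
# 32 1919040
```

With a two-dimensional latent space the full-covariance head is negligible, but it becomes the dominant output head as the latent space grows.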

main.py
import math
import os
import random
import statistics

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision


def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)


class VariationalEncoder(nn.Module):

    def __init__(self,
                 num_observed_dims=784,
                 num_latent_dims=8,
                 num_hidden_dims=512):
        super(VariationalEncoder, self).__init__()

        self.num_observed_dims = num_observed_dims
        self.num_latent_dims = num_latent_dims
        self.num_hidden_dims = num_hidden_dims

        self.fc = nn.Linear(in_features=self.num_observed_dims,
                            out_features=self.num_hidden_dims)
        # Encoded variational mean.
        self.fc_mean = nn.Linear(in_features=self.num_hidden_dims,
                                 out_features=self.num_latent_dims)
        # Encoded variational log standard deviation.
        self.fc_log_std = nn.Linear(in_features=self.num_hidden_dims,
                                    out_features=self.num_latent_dims)
        # Encoded flattened unmasked lower triangular matrix.
        self.fc_unmasked_lower_triangular_flatten = nn.Linear(
            in_features=self.num_hidden_dims,
            out_features=self.num_latent_dims * self.num_latent_dims)
        # Constant mask for the strictly lower triangular part of the matrix.
        self.mask = torch.tril(torch.ones(self.num_latent_dims,
                                          self.num_latent_dims),
                               diagonal=-1)
        self.register_buffer('mask_const', self.mask)

        # Using MultivariateNormal for sampling is awkward in PyTorch as of PyTorch 2.2,
        # because it always produces samples on CPU.
        # from torch.distributions.multivariate_normal import MultivariateNormal
        # self.std_normal_mu = torch.zeros(self.num_latent_dims)
        # self.std_normal_std = torch.eye(self.num_latent_dims)
        # self.register_buffer('std_normal_mu_const', self.std_normal_mu)
        # self.register_buffer('std_normal_std_const', self.std_normal_std)
        # self.multivariate_std_normal = MultivariateNormal(self.std_normal_mu_const, self.std_normal_std_const)

    def encode(self, x):

        h = torch.relu(self.fc(x))
        mu = self.fc_mean(h)
        log_std = self.fc_log_std(h)
        unmasked_lower_triangular_flatten = self.fc_unmasked_lower_triangular_flatten(
            h)
        unmasked_lower_triangular = unmasked_lower_triangular_flatten.view(
            -1, self.num_latent_dims, self.num_latent_dims)

        return mu, log_std, unmasked_lower_triangular

    def reparameterize(self, mu, log_std, unmasked_lower_triangular):

        # Perform one sampling operation for each sample in the batch,
        # using a full-covariance Gaussian posterior.
        std = torch.exp(log_std)
        # torch.diag_embed turns each vector in the batch into a diagonal matrix.
        lower_triangular = unmasked_lower_triangular * self.mask_const + torch.diag_embed(
            std)
        # Sample from standard normal in batch.
        # eps = self.multivariate_std_normal.sample(sample_shape=torch.Size([mu.shape[0]]))
        # The variables in the multivariate standard normal distribution are independent
        # and each follows the univariate standard normal distribution.
        # Thus we can sample each component independently with torch.randn_like.
        eps = torch.randn_like(std)
        z = mu + torch.bmm(lower_triangular,
                           eps.view(-1, self.num_latent_dims, 1)).view(
                               -1, self.num_latent_dims)

        return z, eps

    def forward(self, x):

        mu, log_std, unmasked_lower_triangular = self.encode(x)
        z, eps = self.reparameterize(mu, log_std, unmasked_lower_triangular)

        return z, eps, log_std


class Decoder(nn.Module):

    def __init__(self,
                 num_observed_dims=784,
                 num_latent_dims=8,
                 num_hidden_dims=512):
        super(Decoder, self).__init__()

        self.num_observed_dims = num_observed_dims
        self.num_latent_dims = num_latent_dims
        self.num_hidden_dims = num_hidden_dims

        self.fc = nn.Linear(in_features=self.num_latent_dims,
                            out_features=self.num_hidden_dims)
        self.fc_out = nn.Linear(in_features=self.num_hidden_dims,
                                out_features=self.num_observed_dims)

    def decode(self, z):

        h = torch.relu(self.fc(z))
        x = torch.sigmoid(self.fc_out(h))

        return x

    def forward(self, z):

        return self.decode(z)


class VAE(nn.Module):

    def __init__(self,
                 num_observed_dims=784,
                 num_latent_dims=8,
                 num_hidden_dims=512):
        super(VAE, self).__init__()

        self.num_observed_dims = num_observed_dims
        self.num_latent_dims = num_latent_dims
        self.num_hidden_dims = num_hidden_dims

        self.encoder = VariationalEncoder(
            num_observed_dims=self.num_observed_dims,
            num_latent_dims=self.num_latent_dims,
            num_hidden_dims=self.num_hidden_dims)
        self.decoder = Decoder(num_observed_dims=self.num_observed_dims,
                               num_latent_dims=self.num_latent_dims,
                               num_hidden_dims=self.num_hidden_dims)

    def forward(self, x):

        z, eps, log_std = self.encoder(x)
        x_reconstructed = self.decoder(z)

        return x_reconstructed, z, eps, log_std


def compute_negative_evidence_lower_bound(x, x_reconstructed, z, eps, log_std):

    pi = torch.tensor(math.pi).to(x.device)

    # Reconstruction loss.
    # E[log p(x|z)]
    log_px = -torch.nn.functional.binary_cross_entropy(
        x_reconstructed, x, reduction="sum")
    # E[log q(z|x)]
    # For q(z|x) = N(mu, L L^T) with diag(L) = exp(log_std),
    # log|L L^T| = 2 * sum(log_std), hence the factor of 2 on log_std.
    log_qz = -0.5 * torch.sum(eps**2 + 2 * log_std + torch.log(2 * pi))
    # E[log p(z)]
    # Assuming standard normal prior.
    log_pz = -0.5 * torch.sum(z**2 + torch.log(2 * pi))
    # ELBO
    elbo = log_px + log_pz - log_qz
    negative_elbo = -elbo

    # Compute average negative ELBO.
    batch_size = x.shape[0]
    negative_elbo_avg = negative_elbo / batch_size

    return negative_elbo_avg


class BinarizeTransform(object):

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, x):
        return (x > self.threshold).float()


def prepare_cifar10_dataset(root="data"):

    # NOTE: Despite its name, this function prepares the MNIST dataset.
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # Binarize the input using some threshold.
        # This will improve the performance of the model.
        BinarizeTransform(threshold=0.5),
    ])

    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        BinarizeTransform(threshold=0.5),
    ])

    # Use the root argument instead of a hard-coded directory.
    train_set = torchvision.datasets.MNIST(root=root,
                                           train=True,
                                           download=True,
                                           transform=train_transform)

    test_set = torchvision.datasets.MNIST(root=root,
                                          train=False,
                                          download=True,
                                          transform=test_transform)

    class_names = train_set.classes

    return train_set, test_set, class_names


def prepare_cifar10_dataloader(train_set,
                               test_set,
                               train_batch_size=128,
                               eval_batch_size=256,
                               num_workers=2):

    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=train_batch_size,
                                               sampler=train_sampler,
                                               num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=eval_batch_size,
                                              sampler=test_sampler,
                                              num_workers=num_workers)

    return train_loader, test_loader


def train(model,
          device,
          train_loader,
          loss_func,
          optimizer,
          epoch,
          log_interval=10):

    model.train()
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        image_height = x.shape[2]
        image_width = x.shape[3]
        x = x.to(device)
        x = x.view(-1, image_height * image_width)
        optimizer.zero_grad()
        x_reconstructed, z, eps, log_std = model(x)
        loss = loss_func(x, x_reconstructed, z, eps, log_std)
        loss.backward()
        train_loss += loss.item() * len(x)
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(x)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
    avg_train_loss = train_loss / len(train_loader.dataset)
    print(f"====> Epoch: {epoch} Average Loss: {avg_train_loss:.4f}")


def test(model, device, num_samples, test_loader, loss_func, epoch,
         results_dir):

    image_dir = os.path.join(results_dir, "reconstruction")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (x, _) in enumerate(test_loader):
            x = x.to(device)
            x = x.view(-1, model.num_observed_dims)
            x_reconstructed, z, eps, log_std = model(x)
            loss = loss_func(x, x_reconstructed, z, eps, log_std)
            test_loss += loss.item() * len(x)
            if i == 0:
                n = min(x.size(0), num_samples)
                comparison = torch.cat([
                    x.view(x.size(0), 1, 28, 28)[:n],
                    x_reconstructed.view(x.size(0), 1, 28, 28)[:n]
                ])
                torchvision.utils.save_image(
                    comparison.cpu(),
                    os.path.join(image_dir, f"reconstruction_{epoch}.png"),
                    nrow=n)
    avg_test_loss = test_loss / len(test_loader.dataset)
    print(f"====> Test set loss: {avg_test_loss:.4f}")


def sample_random_images_using_std_normal_prior(model, device, num_samples,
                                                epoch, results_dir):

    image_dir = os.path.join(results_dir, "sample_using_std_normal_prior")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    model.eval()
    with torch.no_grad():
        sample = torch.randn(num_samples, model.num_latent_dims).to(device)
        sample = model.decoder(sample).cpu()
        torchvision.utils.save_image(
            sample.view(num_samples, 1, 28, 28),
            os.path.join(image_dir,
                         f"sample_using_std_normal_prior_{epoch}.png"))


def sample_random_images_using_2d_std_normal_prior_inverse_cdf(
        model, device, num_samples, epoch, results_dir):

    image_dir = os.path.join(results_dir,
                             "sample_using_2d_std_normal_prior_inverse_cdf")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    num_samples_per_dimension = int(math.sqrt(num_samples))
    cumulative_probability_samples = np.linspace(start=0.0001,
                                                 stop=0.9999,
                                                 num=num_samples_per_dimension)
    quantile_samples = torch.tensor([
        statistics.NormalDist(mu=0.0, sigma=1.0).inv_cdf(cp)
        for cp in cumulative_probability_samples
    ],
                                    dtype=torch.float32)

    model.eval()
    # Collect samples.
    samples = []
    with torch.no_grad():
        # Get z1 and z2.
        for i in range(num_samples_per_dimension):
            for j in range(num_samples_per_dimension):
                sample = torch.tensor(
                    [quantile_samples[i], quantile_samples[j]]).to(device)
                sample = sample.view(1, model.num_latent_dims)
                sample = model.decoder(sample).cpu()
                samples.append(sample)
    # Concatenate samples.
    samples = torch.cat(samples)
    # Save images. num_samples_per_dimension rows and columns.
    torchvision.utils.save_image(
        samples.view(num_samples, 1, 28, 28),
        os.path.join(
            image_dir,
            f"sample_using_2d_std_normal_prior_inverse_cdf_{epoch}.png"),
        nrow=num_samples_per_dimension)


def sample_random_images_using_reference_images(model, device, data_set,
                                                num_samples, epoch,
                                                results_dir):

    image_dir = os.path.join(results_dir, "sample_using_reference")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    reference_image_dir = os.path.join(results_dir, "reference")
    if not os.path.exists(reference_image_dir):
        os.makedirs(reference_image_dir)

    model.eval()
    with torch.no_grad():
        indices = np.random.choice(len(data_set), num_samples, replace=False)
        reference = torch.stack([data_set[i][0] for i in indices])
        reference = reference.to(device)
        reference = reference.view(-1, model.num_observed_dims)
        sample, _, _, _ = model(reference)
        torchvision.utils.save_image(
            sample.view(num_samples, 1, 28, 28),
            os.path.join(image_dir,
                         f"sample_using_reference_images_{epoch}.png"))
        torchvision.utils.save_image(
            reference.view(num_samples, 1, 28, 28),
            os.path.join(reference_image_dir, f"reference_images_{epoch}.png"))


def sample_ground_truth_images(data_set, num_samples, results_dir):

    indices = np.random.choice(len(data_set), num_samples, replace=False)
    sample = torch.stack([data_set[i][0] for i in indices])
    torchvision.utils.save_image(
        sample, os.path.join(results_dir, "ground_truth_sample.png"))


def main():

    # Use the first CUDA device if available; otherwise fall back to the CPU.
    cuda_device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    results_dir = "results"
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    model_dir = "models"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    data_dir = "data"
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    random_seed = 0
    set_random_seeds(random_seed=random_seed)

    mnist_image_height = 28
    mnist_image_width = 28

    num_observed_dims = mnist_image_height * mnist_image_width
    # This is a parameter to tune.
    # It should neither be too small nor too large.
    num_latent_dims = 2
    num_hidden_dims = 1024

    # 30 epochs is sufficient for MNIST and 2D manifold.
    num_epochs = 30
    learning_rate = 1e-3
    log_interval = 10

    train_set, test_set, class_names = prepare_cifar10_dataset(root=data_dir)

    sample_ground_truth_images(data_set=train_set,
                               num_samples=64,
                               results_dir=results_dir)

    train_loader, test_loader = prepare_cifar10_dataloader(
        train_set=train_set,
        test_set=test_set,
        train_batch_size=128,
        eval_batch_size=256,
        num_workers=2)

    model = VAE(num_observed_dims=num_observed_dims,
                num_latent_dims=num_latent_dims,
                num_hidden_dims=num_hidden_dims)
    model.to(cuda_device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):

        train(model=model,
              device=cuda_device,
              train_loader=train_loader,
              loss_func=compute_negative_evidence_lower_bound,
              optimizer=optimizer,
              epoch=epoch,
              log_interval=log_interval)
        test(model=model,
             device=cuda_device,
             num_samples=16,
             test_loader=test_loader,
             loss_func=compute_negative_evidence_lower_bound,
             epoch=epoch,
             results_dir=results_dir)
        sample_random_images_using_std_normal_prior(model=model,
                                                    device=cuda_device,
                                                    num_samples=64,
                                                    epoch=epoch,
                                                    results_dir=results_dir)
        sample_random_images_using_reference_images(model=model,
                                                    device=cuda_device,
                                                    data_set=train_set,
                                                    num_samples=64,
                                                    epoch=epoch,
                                                    results_dir=results_dir)
        if num_latent_dims == 2:
            sample_random_images_using_2d_std_normal_prior_inverse_cdf(
                model=model,
                device=cuda_device,
                num_samples=400,
                epoch=epoch,
                results_dir=results_dir)

    # Save the model.
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pth"))
    # Export the decoder to ONNX using Opset 13.
    # The decoder input name should be "input" and output name should be "output".
    z = torch.randn(1, num_latent_dims).to(cuda_device)
    model.decoder.eval()
    torch.onnx.export(model.decoder,
                      z,
                      os.path.join(model_dir, "decoder.onnx"),
                      input_names=["input"],
                      output_names=["output"],
                      opset_version=13)


if __name__ == "__main__":

    main()
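
The Gaussian log-density behind the $\mathbb{E}[\log q_{\phi}(z|x)]$ term can be sanity-checked numerically without PyTorch. The sketch below, written for a two-dimensional latent space, compares the density computed from the noise `eps` and `log_std` against a direct evaluation of $\log \mathcal{N}(z; \mu, L L^{\top})$. The function names and the test values are hypothetical and chosen only for illustration.

```python
import math
import random


def log_q_from_eps(eps, log_std):
    """log N(z; mu, L L^T) written in terms of eps, where z = mu + L eps and
    diag(L) = exp(log_std), so that log|L L^T| = 2 * sum(log_std)."""
    k = len(eps)
    return -0.5 * (sum(e * e for e in eps) + 2.0 * sum(log_std) +
                   k * math.log(2.0 * math.pi))


def log_q_direct_2d(z, mu, lower):
    """Direct evaluation of log N(z; mu, Sigma) with Sigma = L L^T for a
    2 x 2 lower-triangular L = [[l11, 0], [l21, l22]]."""
    l11, l21, l22 = lower[0][0], lower[1][0], lower[1][1]
    s11, s12, s22 = l11 * l11, l11 * l21, l21 * l21 + l22 * l22
    det = s11 * s22 - s12 * s12
    d0, d1 = z[0] - mu[0], z[1] - mu[1]
    # Quadratic form d^T Sigma^{-1} d using the closed-form 2 x 2 inverse.
    quad = (s22 * d0 * d0 - 2.0 * s12 * d0 * d1 + s11 * d1 * d1) / det
    return -0.5 * (quad + math.log(det) + 2.0 * math.log(2.0 * math.pi))


rng = random.Random(0)
mu = [0.3, -1.2]
log_std = [-0.5, 0.2]
lower = [[math.exp(log_std[0]), 0.0], [0.7, math.exp(log_std[1])]]
eps = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
# Reparameterization: z = mu + L eps.
z = [mu[0] + lower[0][0] * eps[0],
     mu[1] + lower[1][0] * eps[0] + lower[1][1] * eps[1]]

from_eps = log_q_from_eps(eps, log_std)
direct = log_q_direct_2d(z, mu, lower)
print(from_eps, direct)
```

The two values agree up to floating-point error, which confirms that the off-diagonal entries of $L$ do not enter the log-determinant and that the diagonal contributes $2 \sum_i \log \sigma_i$.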

Training Performance

Intermediate performance during training was recorded using image reconstructions (which include sampling noise from the latent space) and latent space sampling. Training the model for 30 epochs on the MNIST dataset was sufficient. The MNIST images were binarized using a threshold of 0.5, which makes the optimization problem easier to solve.
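
To illustrate why binarized pixels pair naturally with a Bernoulli likelihood, here is a minimal pure-Python sketch of thresholding followed by the Bernoulli log-likelihood, which is the negative of the summed binary cross entropy. The function names are hypothetical, and the small clamping constant `eps` is an assumption added here to avoid `log(0)`; it is not taken from the listing.

```python
import math


def binarize(pixels, threshold=0.5):
    """Map each pixel intensity in [0, 1] to 0.0 or 1.0."""
    return [1.0 if p > threshold else 0.0 for p in pixels]


def bernoulli_log_likelihood(targets, probs, eps=1e-12):
    """Sum over pixels of x * log(p) + (1 - x) * log(1 - p)."""
    total = 0.0
    for x, p in zip(targets, probs):
        p = min(max(p, eps), 1.0 - eps)  # Clamp to keep the logs finite.
        total += x * math.log(p) + (1.0 - x) * math.log(1.0 - p)
    return total


pixels = [0.0, 0.2, 0.6, 1.0]
targets = binarize(pixels)
decoder_probs = [0.1, 0.1, 0.9, 0.9]  # Made-up decoder outputs.
print(targets, bernoulli_log_likelihood(targets, decoder_probs))
```

With binary targets, every pixel is a genuine Bernoulli outcome, so the reconstruction term is an exact log-likelihood rather than a cross entropy against fractional intensities.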

MNIST Dataset Sampled Images at Epoch 1
MNIST Dataset Sampled Images Reconstructed at Epoch 1

MNIST Dataset Sampled Images at Epoch 30
MNIST Dataset Sampled Images Reconstructed at Epoch 30

Visualizations of Learned Data Manifold for Generative Model with Two-Dimensional Latent Space
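
The manifold visualization decodes a regular grid of latent points obtained by pushing evenly spaced cumulative probabilities through the inverse CDF of the standard normal distribution. The sketch below reproduces only the grid construction with the standard library; the function name `inverse_cdf_grid` is hypothetical, and the probability range matches the values used in the listing.

```python
import statistics


def inverse_cdf_grid(num_per_dim, low=0.0001, high=0.9999):
    """Return num_per_dim ** 2 latent points (z1, z2) whose coordinates are
    standard normal quantiles at evenly spaced cumulative probabilities.
    Requires num_per_dim >= 2."""
    std_normal = statistics.NormalDist(mu=0.0, sigma=1.0)
    step = (high - low) / (num_per_dim - 1)
    probabilities = [low + i * step for i in range(num_per_dim)]
    quantiles = [std_normal.inv_cdf(p) for p in probabilities]
    return [(z1, z2) for z1 in quantiles for z2 in quantiles]


grid = inverse_cdf_grid(20)
print(len(grid), grid[0], grid[-1])
```

Spacing the points by probability rather than by distance places more grid points near the origin, where the prior has most of its mass, so the rendered grid spends most of its cells on latent codes the decoder is likely to see.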

Generate MNIST Images Using Variational Autoencoder

Using the decoder of the variational autoencoder trained on the MNIST dataset, we can generate images by sampling from the latent space. In our case, the latent space is two-dimensional. The model is served using ONNX Runtime JavaScript.
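
For readers who prefer to query the exported decoder offline, the following is a minimal sketch using the ONNX Runtime Python package rather than the JavaScript runtime mentioned above. It assumes that `onnxruntime` and `numpy` are installed and that the file `models/decoder.onnx` was produced by the export step in `main.py` with the tensor names `input` and `output`; the function names are hypothetical.

```python
def make_latent_input(z1, z2):
    """Pack two latent values into a batch of size one: [[z1, z2]]."""
    return [[float(z1), float(z2)]]


def generate_image(z1, z2, model_path="models/decoder.onnx"):
    """Decode one latent point into a 28 x 28 array of pixel probabilities."""
    # Imported lazily so that make_latent_input has no third-party dependency.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path,
                                   providers=["CPUExecutionProvider"])
    z = np.asarray(make_latent_input(z1, z2), dtype=np.float32)
    # The export names the decoder input "input" and the output "output".
    (output,) = session.run(["output"], {"input": z})
    return output.reshape(28, 28)
```

Calling `generate_image(0.0, 0.0)` would decode the center of the latent space; the batch dimension is fixed at one because the decoder was exported with a single example input.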


Please change the values of the latent variables to see how different values affect the generated image.

Author: Lei Mao

Posted on: 06-14-2024

Updated on: 06-14-2024
