# vae.py — Variational Autoencoder (VAE) on Fashion-MNIST, TensorFlow 2.
  1. import os
  2. import tensorflow as tf
  3. import numpy as np
  4. from tensorflow import keras
  5. from tensorflow.keras import Sequential, layers
  6. from PIL import Image
  7. from matplotlib import pyplot as plt
  8. tf.random.set_seed(22)
  9. np.random.seed(22)
  10. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  11. assert tf.__version__.startswith('2.')
  12. def save_images(imgs, name):
  13. new_im = Image.new('L', (280, 280))
  14. index = 0
  15. for i in range(0, 280, 28):
  16. for j in range(0, 280, 28):
  17. im = imgs[index]
  18. im = Image.fromarray(im, mode='L')
  19. new_im.paste(im, (i, j))
  20. index += 1
  21. new_im.save(name)
  22. h_dim = 20
  23. batchsz = 512
  24. lr = 1e-3
  25. (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
  26. x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
  27. # we do not need label
  28. train_db = tf.data.Dataset.from_tensor_slices(x_train)
  29. train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
  30. test_db = tf.data.Dataset.from_tensor_slices(x_test)
  31. test_db = test_db.batch(batchsz)
  32. print(x_train.shape, y_train.shape)
  33. print(x_test.shape, y_test.shape)
  34. z_dim = 10
  35. class VAE(keras.Model):
  36. def __init__(self):
  37. super(VAE, self).__init__()
  38. # Encoder
  39. self.fc1 = layers.Dense(128)
  40. self.fc2 = layers.Dense(z_dim) # get mean prediction
  41. self.fc3 = layers.Dense(z_dim)
  42. # Decoder
  43. self.fc4 = layers.Dense(128)
  44. self.fc5 = layers.Dense(784)
  45. def encoder(self, x):
  46. h = tf.nn.relu(self.fc1(x))
  47. # get mean
  48. mu = self.fc2(h)
  49. # get variance
  50. log_var = self.fc3(h)
  51. return mu, log_var
  52. def decoder(self, z):
  53. out = tf.nn.relu(self.fc4(z))
  54. out = self.fc5(out)
  55. return out
  56. def reparameterize(self, mu, log_var):
  57. eps = tf.random.normal(log_var.shape)
  58. std = tf.exp(log_var*0.5)
  59. z = mu + std * eps
  60. return z
  61. def call(self, inputs, training=None):
  62. # [b, 784] => [b, z_dim], [b, z_dim]
  63. mu, log_var = self.encoder(inputs)
  64. # reparameterization trick
  65. z = self.reparameterize(mu, log_var)
  66. x_hat = self.decoder(z)
  67. return x_hat, mu, log_var
  68. model = VAE()
  69. model.build(input_shape=(4, 784))
  70. optimizer = tf.optimizers.Adam(lr)
  71. for epoch in range(1000):
  72. for step, x in enumerate(train_db):
  73. x = tf.reshape(x, [-1, 784])
  74. with tf.GradientTape() as tape:
  75. x_rec_logits, mu, log_var = model(x)
  76. rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
  77. rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]
  78. # compute kl divergence (mu, var) ~ N (0, 1)
  79. # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
  80. kl_div = -0.5 * (log_var + 1 - mu**2 - tf.exp(log_var))
  81. kl_div = tf.reduce_sum(kl_div) / x.shape[0]
  82. loss = rec_loss + 1. * kl_div
  83. grads = tape.gradient(loss, model.trainable_variables)
  84. optimizer.apply_gradients(zip(grads, model.trainable_variables))
  85. if step % 100 == 0:
  86. print(epoch, step, 'kl div:', float(kl_div), 'rec loss:', float(rec_loss))
  87. # evaluation
  88. z = tf.random.normal((batchsz, z_dim))
  89. logits = model.decoder(z)
  90. x_hat = tf.sigmoid(logits)
  91. x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() *255.
  92. x_hat = x_hat.astype(np.uint8)
  93. save_images(x_hat, 'vae_images/sampled_epoch%d.png'%epoch)
  94. x = next(iter(test_db))
  95. x = tf.reshape(x, [-1, 784])
  96. x_hat_logits, _, _ = model(x)
  97. x_hat = tf.sigmoid(x_hat_logits)
  98. x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() *255.
  99. x_hat = x_hat.astype(np.uint8)
  100. save_images(x_hat, 'vae_images/rec_epoch%d.png'%epoch)