# autoencoder.py — fully-connected autoencoder trained on Fashion-MNIST
  1. import os
  2. import tensorflow as tf
  3. import numpy as np
  4. from tensorflow import keras
  5. from tensorflow.keras import Sequential, layers
  6. from PIL import Image
  7. from matplotlib import pyplot as plt
  8. tf.random.set_seed(22)
  9. np.random.seed(22)
  10. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  11. assert tf.__version__.startswith('2.')
  12. def save_images(imgs, name):
  13. new_im = Image.new('L', (280, 280))
  14. index = 0
  15. for i in range(0, 280, 28):
  16. for j in range(0, 280, 28):
  17. im = imgs[index]
  18. im = Image.fromarray(im, mode='L')
  19. new_im.paste(im, (i, j))
  20. index += 1
  21. new_im.save(name)
  22. h_dim = 20
  23. batchsz = 512
  24. lr = 1e-3
  25. (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
  26. x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
  27. # we do not need label
  28. train_db = tf.data.Dataset.from_tensor_slices(x_train)
  29. train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
  30. test_db = tf.data.Dataset.from_tensor_slices(x_test)
  31. test_db = test_db.batch(batchsz)
  32. print(x_train.shape, y_train.shape)
  33. print(x_test.shape, y_test.shape)
  34. class AE(keras.Model):
  35. def __init__(self):
  36. super(AE, self).__init__()
  37. # Encoders
  38. self.encoder = Sequential([
  39. layers.Dense(256, activation=tf.nn.relu),
  40. layers.Dense(128, activation=tf.nn.relu),
  41. layers.Dense(h_dim)
  42. ])
  43. # Decoders
  44. self.decoder = Sequential([
  45. layers.Dense(128, activation=tf.nn.relu),
  46. layers.Dense(256, activation=tf.nn.relu),
  47. layers.Dense(784)
  48. ])
  49. def call(self, inputs, training=None):
  50. # [b, 784] => [b, 10]
  51. h = self.encoder(inputs)
  52. # [b, 10] => [b, 784]
  53. x_hat = self.decoder(h)
  54. return x_hat
  55. model = AE()
  56. model.build(input_shape=(None, 784))
  57. model.summary()
  58. optimizer = tf.optimizers.Adam(lr=lr)
  59. for epoch in range(100):
  60. for step, x in enumerate(train_db):
  61. #[b, 28, 28] => [b, 784]
  62. x = tf.reshape(x, [-1, 784])
  63. with tf.GradientTape() as tape:
  64. x_rec_logits = model(x)
  65. rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
  66. rec_loss = tf.reduce_mean(rec_loss)
  67. grads = tape.gradient(rec_loss, model.trainable_variables)
  68. optimizer.apply_gradients(zip(grads, model.trainable_variables))
  69. if step % 100 ==0:
  70. print(epoch, step, float(rec_loss))
  71. # evaluation
  72. x = next(iter(test_db))
  73. logits = model(tf.reshape(x, [-1, 784]))
  74. x_hat = tf.sigmoid(logits)
  75. # [b, 784] => [b, 28, 28]
  76. x_hat = tf.reshape(x_hat, [-1, 28, 28])
  77. # [b, 28, 28] => [2b, 28, 28]
  78. x_concat = tf.concat([x, x_hat], axis=0)
  79. x_concat = x_hat
  80. x_concat = x_concat.numpy() * 255.
  81. x_concat = x_concat.astype(np.uint8)
  82. save_images(x_concat, 'ae_images/rec_epoch_%d.png'%epoch)