# wgan_train.py — WGAN-GP training script for anime-face generation.
  1. import os
  2. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow import keras
  6. from PIL import Image
  7. import glob
  8. from gan import Generator, Discriminator
  9. from dataset import make_anime_dataset
  10. def save_result(val_out, val_block_size, image_path, color_mode):
  11. def preprocess(img):
  12. img = ((img + 1.0) * 127.5).astype(np.uint8)
  13. # img = img.astype(np.uint8)
  14. return img
  15. preprocesed = preprocess(val_out)
  16. final_image = np.array([])
  17. single_row = np.array([])
  18. for b in range(val_out.shape[0]):
  19. # concat image into a row
  20. if single_row.size == 0:
  21. single_row = preprocesed[b, :, :, :]
  22. else:
  23. single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
  24. # concat image row to final_image
  25. if (b+1) % val_block_size == 0:
  26. if final_image.size == 0:
  27. final_image = single_row
  28. else:
  29. final_image = np.concatenate((final_image, single_row), axis=0)
  30. # reset single row
  31. single_row = np.array([])
  32. if final_image.shape[2] == 1:
  33. final_image = np.squeeze(final_image, axis=2)
  34. Image.fromarray(final_image).save(image_path)
  35. def celoss_ones(logits):
  36. # [b, 1]
  37. # [b] = [1, 1, 1, 1,]
  38. # loss = tf.keras.losses.categorical_crossentropy(y_pred=logits,
  39. # y_true=tf.ones_like(logits))
  40. return - tf.reduce_mean(logits)
  41. def celoss_zeros(logits):
  42. # [b, 1]
  43. # [b] = [1, 1, 1, 1,]
  44. # loss = tf.keras.losses.categorical_crossentropy(y_pred=logits,
  45. # y_true=tf.zeros_like(logits))
  46. return tf.reduce_mean(logits)
  47. def gradient_penalty(discriminator, batch_x, fake_image):
  48. batchsz = batch_x.shape[0]
  49. # [b, h, w, c]
  50. t = tf.random.uniform([batchsz, 1, 1, 1])
  51. # [b, 1, 1, 1] => [b, h, w, c]
  52. t = tf.broadcast_to(t, batch_x.shape)
  53. interplate = t * batch_x + (1 - t) * fake_image
  54. with tf.GradientTape() as tape:
  55. tape.watch([interplate])
  56. d_interplote_logits = discriminator(interplate, training=True)
  57. grads = tape.gradient(d_interplote_logits, interplate)
  58. # grads:[b, h, w, c] => [b, -1]
  59. grads = tf.reshape(grads, [grads.shape[0], -1])
  60. gp = tf.norm(grads, axis=1) #[b]
  61. gp = tf.reduce_mean( (gp-1)**2 )
  62. return gp
  63. def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
  64. # 1. treat real image as real
  65. # 2. treat generated image as fake
  66. fake_image = generator(batch_z, is_training)
  67. d_fake_logits = discriminator(fake_image, is_training)
  68. d_real_logits = discriminator(batch_x, is_training)
  69. d_loss_real = celoss_ones(d_real_logits)
  70. d_loss_fake = celoss_zeros(d_fake_logits)
  71. gp = gradient_penalty(discriminator, batch_x, fake_image)
  72. loss = d_loss_real + d_loss_fake + 10. * gp
  73. return loss, gp
  74. def g_loss_fn(generator, discriminator, batch_z, is_training):
  75. fake_image = generator(batch_z, is_training)
  76. d_fake_logits = discriminator(fake_image, is_training)
  77. loss = celoss_ones(d_fake_logits)
  78. return loss
  79. def main():
  80. tf.random.set_seed(233)
  81. np.random.seed(233)
  82. assert tf.__version__.startswith('2.')
  83. # hyper parameters
  84. z_dim = 100
  85. epochs = 3000000
  86. batch_size = 512
  87. learning_rate = 0.0005
  88. is_training = True
  89. img_path = glob.glob(r'C:\Users\Jackie\Downloads\faces\*.jpg')
  90. assert len(img_path) > 0
  91. dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
  92. print(dataset, img_shape)
  93. sample = next(iter(dataset))
  94. print(sample.shape, tf.reduce_max(sample).numpy(),
  95. tf.reduce_min(sample).numpy())
  96. dataset = dataset.repeat()
  97. db_iter = iter(dataset)
  98. generator = Generator()
  99. generator.build(input_shape = (None, z_dim))
  100. discriminator = Discriminator()
  101. discriminator.build(input_shape=(None, 64, 64, 3))
  102. z_sample = tf.random.normal([100, z_dim])
  103. g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
  104. d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
  105. for epoch in range(epochs):
  106. for _ in range(5):
  107. batch_z = tf.random.normal([batch_size, z_dim])
  108. batch_x = next(db_iter)
  109. # train D
  110. with tf.GradientTape() as tape:
  111. d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
  112. grads = tape.gradient(d_loss, discriminator.trainable_variables)
  113. d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
  114. batch_z = tf.random.normal([batch_size, z_dim])
  115. with tf.GradientTape() as tape:
  116. g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
  117. grads = tape.gradient(g_loss, generator.trainable_variables)
  118. g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
  119. if epoch % 100 == 0:
  120. print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss),
  121. 'gp:', float(gp))
  122. z = tf.random.normal([100, z_dim])
  123. fake_image = generator(z, training=False)
  124. img_path = os.path.join('images', 'wgan-%d.png'%epoch)
  125. save_result(fake_image.numpy(), 10, img_path, color_mode='P')
  126. if __name__ == '__main__':
  127. main()