gan_train.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. from tensorflow import keras
  5. from scipy.misc import toimage
  6. import glob
  7. from gan import Generator, Discriminator
  8. from dataset import make_anime_dataset
  9. def save_result(val_out, val_block_size, image_path, color_mode):
  10. def preprocess(img):
  11. img = ((img + 1.0) * 127.5).astype(np.uint8)
  12. # img = img.astype(np.uint8)
  13. return img
  14. preprocesed = preprocess(val_out)
  15. final_image = np.array([])
  16. single_row = np.array([])
  17. for b in range(val_out.shape[0]):
  18. # concat image into a row
  19. if single_row.size == 0:
  20. single_row = preprocesed[b, :, :, :]
  21. else:
  22. single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
  23. # concat image row to final_image
  24. if (b+1) % val_block_size == 0:
  25. if final_image.size == 0:
  26. final_image = single_row
  27. else:
  28. final_image = np.concatenate((final_image, single_row), axis=0)
  29. # reset single row
  30. single_row = np.array([])
  31. if final_image.shape[2] == 1:
  32. final_image = np.squeeze(final_image, axis=2)
  33. toimage(final_image).save(image_path)
  34. def celoss_ones(logits):
  35. # 计算属于与标签为1的交叉熵
  36. y = tf.ones_like(logits)
  37. loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
  38. return tf.reduce_mean(loss)
  39. def celoss_zeros(logits):
  40. # 计算属于与便签为0的交叉熵
  41. y = tf.zeros_like(logits)
  42. loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
  43. return tf.reduce_mean(loss)
  44. def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
  45. # 计算判别器的误差函数
  46. # 采样生成图片
  47. fake_image = generator(batch_z, is_training)
  48. # 判定生成图片
  49. d_fake_logits = discriminator(fake_image, is_training)
  50. # 判定真实图片
  51. d_real_logits = discriminator(batch_x, is_training)
  52. # 真实图片与1之间的误差
  53. d_loss_real = celoss_ones(d_real_logits)
  54. # 生成图片与0之间的误差
  55. d_loss_fake = celoss_zeros(d_fake_logits)
  56. # 合并误差
  57. loss = d_loss_fake + d_loss_real
  58. return loss
  59. def g_loss_fn(generator, discriminator, batch_z, is_training):
  60. # 采样生成图片
  61. fake_image = generator(batch_z, is_training)
  62. # 在训练生成网络时,需要迫使生成图片判定为真
  63. d_fake_logits = discriminator(fake_image, is_training)
  64. # 计算生成图片与1之间的误差
  65. loss = celoss_ones(d_fake_logits)
  66. return loss
  67. def main():
  68. tf.random.set_seed(3333)
  69. np.random.seed(3333)
  70. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  71. assert tf.__version__.startswith('2.')
  72. z_dim = 100 # 隐藏向量z的长度
  73. epochs = 3000000 # 训练步数
  74. batch_size = 64 # batch size
  75. learning_rate = 0.0002
  76. is_training = True
  77. # 获取数据集路径
  78. # C:\Users\z390\Downloads\anime-faces
  79. # r'C:\Users\z390\Downloads\faces\*.jpg'
  80. img_path = glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.jpg') + \
  81. glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.png')
  82. # img_path = glob.glob(r'C:\Users\z390\Downloads\getchu_aligned_with_label\GetChu_aligned2\*.jpg')
  83. # img_path.extend(img_path2)
  84. print('images num:', len(img_path))
  85. # 构建数据集对象
  86. dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)
  87. print(dataset, img_shape)
  88. sample = next(iter(dataset)) # 采样
  89. print(sample.shape, tf.reduce_max(sample).numpy(),
  90. tf.reduce_min(sample).numpy())
  91. dataset = dataset.repeat(100) # 重复循环
  92. db_iter = iter(dataset)
  93. generator = Generator() # 创建生成器
  94. generator.build(input_shape = (4, z_dim))
  95. discriminator = Discriminator() # 创建判别器
  96. discriminator.build(input_shape=(4, 64, 64, 3))
  97. # 分别为生成器和判别器创建优化器
  98. g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
  99. d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
  100. generator.load_weights('generator.ckpt')
  101. discriminator.load_weights('discriminator.ckpt')
  102. print('Loaded chpt!!')
  103. d_losses, g_losses = [],[]
  104. for epoch in range(epochs): # 训练epochs次
  105. # 1. 训练判别器
  106. for _ in range(1):
  107. # 采样隐藏向量
  108. batch_z = tf.random.normal([batch_size, z_dim])
  109. batch_x = next(db_iter) # 采样真实图片
  110. # 判别器前向计算
  111. with tf.GradientTape() as tape:
  112. d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
  113. grads = tape.gradient(d_loss, discriminator.trainable_variables)
  114. d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
  115. # 2. 训练生成器
  116. # 采样隐藏向量
  117. batch_z = tf.random.normal([batch_size, z_dim])
  118. batch_x = next(db_iter) # 采样真实图片
  119. # 生成器前向计算
  120. with tf.GradientTape() as tape:
  121. g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
  122. grads = tape.gradient(g_loss, generator.trainable_variables)
  123. g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
  124. if epoch % 100 == 0:
  125. print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))
  126. # 可视化
  127. z = tf.random.normal([100, z_dim])
  128. fake_image = generator(z, training=False)
  129. img_path = os.path.join('gan_images', 'gan-%d.png'%epoch)
  130. save_result(fake_image.numpy(), 10, img_path, color_mode='P')
  131. d_losses.append(float(d_loss))
  132. g_losses.append(float(g_loss))
  133. if epoch % 10000 == 1:
  134. # print(d_losses)
  135. # print(g_losses)
  136. generator.save_weights('generator.ckpt')
  137. discriminator.save_weights('discriminator.ckpt')
  138. if __name__ == '__main__':
  139. main()