| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
import glob
import os

import numpy as np
import tensorflow as tf
# NOTE(review): scipy.misc.toimage was removed in SciPy >= 1.2; this import
# breaks on modern SciPy and should eventually be dropped once no code uses it.
from scipy.misc import toimage
from tensorflow import keras

from dataset import make_anime_dataset
from gan import Generator, Discriminator
def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a square grid and save it.

    Args:
        val_out: numpy array of shape (batch, h, w, c) with float values in
            [-1, 1] (tanh output range) — assumed from the ((x+1)*127.5)
            rescale below; TODO confirm against the Generator's final layer.
        val_block_size: images per grid row; trailing images beyond a full
            multiple of this are dropped (same as the original loop).
        image_path: destination path for the tiled image file.
        color_mode: unused; kept only for interface compatibility with callers.
    """
    def _to_uint8(img):
        # Map [-1.0, 1.0] floats onto [0, 255] uint8 pixel values.
        return ((img + 1.0) * 127.5).astype(np.uint8)

    pixels = _to_uint8(val_out)
    grid_rows = []
    current_row = []
    for b in range(val_out.shape[0]):
        current_row.append(pixels[b])
        # Every val_block_size images, close out one row of the grid.
        if (b + 1) % val_block_size == 0:
            grid_rows.append(np.concatenate(current_row, axis=1))
            current_row = []
    final_image = np.concatenate(grid_rows, axis=0)
    # scipy.misc.toimage was removed in SciPy >= 1.2; write with the Keras
    # image saver instead (keras is already imported at the top of this file).
    # scale=False because the array is already uint8 in [0, 255]; save_img
    # accepts (h, w, 1) grayscale directly, so no squeeze is needed.
    keras.preprocessing.image.save_img(image_path, final_image, scale=False)
def celoss_ones(logits):
    """Mean sigmoid cross-entropy of `logits` against an all-ones target.

    Used where the discriminator output should be pushed toward label 1
    (i.e. "real").
    """
    return tf.reduce_mean(
        keras.losses.binary_crossentropy(tf.ones_like(logits), logits,
                                         from_logits=True))
def celoss_zeros(logits):
    """Mean sigmoid cross-entropy of `logits` against an all-zeros target.

    Used where the discriminator output should be pushed toward label 0
    (i.e. "fake").
    """
    return tf.reduce_mean(
        keras.losses.binary_crossentropy(tf.zeros_like(logits), logits,
                                         from_logits=True))
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Compute the discriminator loss for one training step.

    Real images are scored against label 1 and generated images against
    label 0; the two cross-entropy terms are summed.

    Args:
        generator: model mapping latent vectors to images.
        discriminator: model producing real/fake logits for images.
        batch_z: batch of latent vectors sampled for fake images.
        batch_x: batch of real training images.
        is_training: training-mode flag forwarded to both models.
    """
    # Generate fakes and score both fakes and reals with the discriminator.
    synthetic = generator(batch_z, is_training)
    logits_fake = discriminator(synthetic, is_training)
    logits_real = discriminator(batch_x, is_training)
    # Reals should look like 1, fakes like 0; total loss is the sum.
    return celoss_ones(logits_real) + celoss_zeros(logits_fake)
def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Compute the generator loss for one training step.

    When training the generator, its samples must be pushed toward a "real"
    (label 1) verdict from the discriminator.

    Args:
        generator: model mapping latent vectors to images.
        discriminator: model producing real/fake logits for images.
        batch_z: batch of latent vectors to generate from.
        is_training: training-mode flag forwarded to both models.
    """
    synthetic = generator(batch_z, is_training)
    # Cross-entropy between the discriminator's verdict on fakes and label 1.
    return celoss_ones(discriminator(synthetic, is_training))
def main():
    """Train a DCGAN on the anime-faces dataset.

    Each epoch runs one discriminator step and one generator step, prints
    losses and saves a 10x10 sample grid every 100 epochs, and checkpoints
    the models periodically.
    """
    tf.random.set_seed(3333)
    np.random.seed(3333)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    z_dim = 100          # length of the latent vector z
    epochs = 3000000     # number of training steps
    batch_size = 64      # batch size
    learning_rate = 0.0002
    is_training = True

    # Collect dataset image paths (both jpg and png).
    img_path = glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.jpg') + \
               glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.png')
    print('images num:', len(img_path))

    # Build the tf.data pipeline; images are resized to 64x64.
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)
    print(dataset, img_shape)
    sample = next(iter(dataset))  # sanity-check one batch and its value range
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # NOTE(review): repeat(100) is finite, so next(db_iter) will eventually
    # raise StopIteration long before 3,000,000 epochs — presumably intended
    # as "long enough"; confirm, or use dataset.repeat() for an endless stream.
    dataset = dataset.repeat(100)
    db_iter = iter(dataset)

    generator = Generator()                    # create the generator
    generator.build(input_shape=(4, z_dim))
    discriminator = Discriminator()            # create the discriminator
    discriminator.build(input_shape=(4, 64, 64, 3))
    # Separate Adam optimizers for the generator and the discriminator.
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # Resume from checkpoints only when they exist: save_weights('x.ckpt')
    # writes 'x.ckpt.index' plus data shards, and the original unconditional
    # load_weights crashed on a fresh run with no checkpoint on disk.
    if (os.path.exists('generator.ckpt.index')
            and os.path.exists('discriminator.ckpt.index')):
        generator.load_weights('generator.ckpt')
        discriminator.load_weights('discriminator.ckpt')
        print('Loaded ckpt!!')

    # Ensure the sample-image output directory exists before the first save.
    os.makedirs('gan_images', exist_ok=True)

    d_losses, g_losses = [], []
    for epoch in range(epochs):
        # 1. train the discriminator (one step per epoch)
        for _ in range(1):
            batch_z = tf.random.normal([batch_size, z_dim])  # latent sample
            batch_x = next(db_iter)                          # real images
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(generator, discriminator, batch_z,
                                   batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(
                zip(grads, discriminator.trainable_variables))

        # 2. train the generator (one step per epoch)
        batch_z = tf.random.normal([batch_size, z_dim])
        batch_x = next(db_iter)  # drawn to keep parity with the original loop
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            # Visualize a 10x10 grid of generated samples.
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            img_path = os.path.join('gan_images', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')
            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))

        # Checkpoint at the epoch level: epoch % 10000 == 1 can never hold
        # inside the epoch % 100 == 0 branch, so a nested save would never run.
        if epoch % 10000 == 1:
            generator.save_weights('generator.ckpt')
            discriminator.save_weights('discriminator.ckpt')
|