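"""Train a WGAN-GP on an anime-face dataset with TensorFlow 2.

The Generator and Discriminator come from gan.py and the input pipeline
from dataset.make_anime_dataset. Training uses the Wasserstein loss with
a gradient penalty, and every 100 iterations saves a 10x10 grid of
generated samples under images/.
"""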
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow info/warning logs

import glob

import numpy as np
import tensorflow as tf
from PIL import Image

from gan import Generator, Discriminator
from dataset import make_anime_dataset

def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a square grid and save it.

    val_out: [b, h, w, c] generator output in [-1, 1].
    val_block_size: number of images per grid row/column.
    color_mode: accepted for API compatibility; PIL infers the mode
    from the array itself.
    """
    def preprocess(img):
        # map [-1, 1] -> [0, 255] uint8
        return ((img + 1.0) * 127.5).astype(np.uint8)

    preprocessed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concatenate images into a row
        if single_row.size == 0:
            single_row = preprocessed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocessed[b, :, :, :]), axis=1)
        # append the completed row to the final image
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)
            # reset the row buffer
            single_row = np.array([])

    # drop the trailing channel axis for grayscale images
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)

def celoss_ones(logits):
    # WGAN critic loss for samples that should score high (real images
    # for D, fake images for G): maximize the mean score by minimizing
    # its negation. The name is kept from the cross-entropy loss that
    # this Wasserstein loss replaces.
    return -tf.reduce_mean(logits)


def celoss_zeros(logits):
    # WGAN critic loss for samples that should score low (fake images
    # for D): minimize the mean score directly.
    return tf.reduce_mean(logits)

def gradient_penalty(discriminator, batch_x, fake_image):
    # WGAN-GP penalty: E[(||grad_x D(x_hat)||_2 - 1)^2], where
    # x_hat = t * x_real + (1 - t) * x_fake with t ~ U(0, 1).
    batchsz = batch_x.shape[0]

    # one interpolation coefficient per sample: [b, 1, 1, 1] => [b, h, w, c]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    t = tf.broadcast_to(t, batch_x.shape)

    interpolate = t * batch_x + (1 - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch([interpolate])
        d_interpolate_logits = discriminator(interpolate, training=True)
    grads = tape.gradient(d_interpolate_logits, interpolate)

    # grads: [b, h, w, c] => [b, -1], then per-sample L2 norm
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    gp = tf.reduce_mean((gp - 1) ** 2)
    return gp

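# A minimal, hypothetical self-test for gradient_penalty (not part of the
# original training flow): a toy linear critic is enough to confirm the
# penalty comes out as a single non-negative scalar.
def _gp_sanity_check():
    dense = tf.keras.layers.Dense(1)
    critic = lambda x, training=True: dense(tf.reshape(x, [x.shape[0], -1]))
    x_real = tf.random.normal([4, 8, 8, 3])
    x_fake = tf.random.normal([4, 8, 8, 3])
    gp = gradient_penalty(critic, x_real, x_fake)
    assert gp.shape == ()  # the penalty is a scalar
    return gp
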
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # Critic objective: E[D(fake)] - E[D(real)] + 10 * gradient penalty.
    # 1. real images should score high
    # 2. generated images should score low
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    loss = d_loss_real + d_loss_fake + 10. * gp
    return loss, gp

def g_loss_fn(generator, discriminator, batch_z, is_training):
    # Generator objective: -E[D(fake)], i.e. push fake scores up.
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)
    return loss

def main():
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.')

    # hyperparameters
    z_dim = 100
    epochs = 3000000  # outer training iterations, not passes over the data
    batch_size = 512
    learning_rate = 0.0005
    is_training = True

    img_path = glob.glob(r'C:\Users\Jackie\Downloads\faces\*.jpg')
    assert len(img_path) > 0

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    # fixed latents so sample grids stay comparable across checkpoints
    z_sample = tf.random.normal([100, z_dim])

    # beta_1=0.5 is the usual Adam setting for GAN training
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    os.makedirs('images', exist_ok=True)  # output directory for sample grids

    for epoch in range(epochs):
        # train the critic 5 times per generator update (WGAN-GP schedule)
        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)

            # train D
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # train G
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))
            # sample from the fixed latents defined above
            fake_image = generator(z_sample, training=False)
            save_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, save_path, color_mode='P')


if __name__ == '__main__':
    main()
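
# Usage note (assumptions, not part of the original script): gan.py must
# provide Generator/Discriminator Keras models working on 64x64x3 images
# in [-1, 1], and dataset.py must provide make_anime_dataset returning
# (dataset, image_shape, _). Point the glob pattern in main() at your own
# face crops, then run this file directly, e.g.:
#
#   python wgan_train.py   # hypothetical filename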