wgan.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import tensorflow as tf
  2. from tensorflow import keras
  3. from tensorflow.keras import layers
  4. class Generator(keras.Model):
  5. def __init__(self):
  6. super(Generator, self).__init__()
  7. # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
  8. self.fc = layers.Dense(3*3*512)
  9. self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
  10. self.bn1 = layers.BatchNormalization()
  11. self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
  12. self.bn2 = layers.BatchNormalization()
  13. self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')
  14. def call(self, inputs, training=None):
  15. # [z, 100] => [z, 3*3*512]
  16. x = self.fc(inputs)
  17. x = tf.reshape(x, [-1, 3, 3, 512])
  18. x = tf.nn.leaky_relu(x)
  19. #
  20. x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
  21. x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
  22. x = self.conv3(x)
  23. x = tf.tanh(x)
  24. return x
  25. class Discriminator(keras.Model):
  26. def __init__(self):
  27. super(Discriminator, self).__init__()
  28. # [b, 64, 64, 3] => [b, 1]
  29. self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
  30. self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
  31. self.bn2 = layers.BatchNormalization()
  32. self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
  33. self.bn3 = layers.BatchNormalization()
  34. # [b, h, w ,c] => [b, -1]
  35. self.flatten = layers.Flatten()
  36. self.fc = layers.Dense(1)
  37. def call(self, inputs, training=None):
  38. x = tf.nn.leaky_relu(self.conv1(inputs))
  39. x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
  40. x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
  41. # [b, h, w, c] => [b, -1]
  42. x = self.flatten(x)
  43. # [b, -1] => [b, 1]
  44. logits = self.fc(x)
  45. return logits
  46. def main():
  47. d = Discriminator()
  48. g = Generator()
  49. x = tf.random.normal([2, 64, 64, 3])
  50. z = tf.random.normal([2, 100])
  51. prob = d(x)
  52. print(prob)
  53. x_hat = g(z)
  54. print(x_hat.shape)
  55. if __name__ == '__main__':
  56. main()