# layer_model.py (2.3 KB)
  1. import tensorflow as tf
  2. from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
  3. from tensorflow import keras
  4. def preprocess(x, y):
  5. """
  6. x is a simple image, not a batch
  7. """
  8. x = tf.cast(x, dtype=tf.float32) / 255.
  9. x = tf.reshape(x, [28*28])
  10. y = tf.cast(y, dtype=tf.int32)
  11. y = tf.one_hot(y, depth=10)
  12. return x,y
  13. batchsz = 128
  14. (x, y), (x_val, y_val) = datasets.mnist.load_data()
  15. print('datasets:', x.shape, y.shape, x.min(), x.max())
  16. db = tf.data.Dataset.from_tensor_slices((x,y))
  17. db = db.map(preprocess).shuffle(60000).batch(batchsz)
  18. ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
  19. ds_val = ds_val.map(preprocess).batch(batchsz)
  20. sample = next(iter(db))
  21. print(sample[0].shape, sample[1].shape)
  22. network = Sequential([layers.Dense(256, activation='relu'),
  23. layers.Dense(128, activation='relu'),
  24. layers.Dense(64, activation='relu'),
  25. layers.Dense(32, activation='relu'),
  26. layers.Dense(10)])
  27. network.build(input_shape=(None, 28*28))
  28. network.summary()
  29. class MyDense(layers.Layer):
  30. def __init__(self, inp_dim, outp_dim):
  31. super(MyDense, self).__init__()
  32. self.kernel = self.add_weight('w', [inp_dim, outp_dim])
  33. self.bias = self.add_weight('b', [outp_dim])
  34. def call(self, inputs, training=None):
  35. out = inputs @ self.kernel + self.bias
  36. return out
  37. class MyModel(keras.Model):
  38. def __init__(self):
  39. super(MyModel, self).__init__()
  40. self.fc1 = MyDense(28*28, 256)
  41. self.fc2 = MyDense(256, 128)
  42. self.fc3 = MyDense(128, 64)
  43. self.fc4 = MyDense(64, 32)
  44. self.fc5 = MyDense(32, 10)
  45. def call(self, inputs, training=None):
  46. x = self.fc1(inputs)
  47. x = tf.nn.relu(x)
  48. x = self.fc2(x)
  49. x = tf.nn.relu(x)
  50. x = self.fc3(x)
  51. x = tf.nn.relu(x)
  52. x = self.fc4(x)
  53. x = tf.nn.relu(x)
  54. x = self.fc5(x)
  55. return x
  56. network = MyModel()
  57. network.compile(optimizer=optimizers.Adam(lr=0.01),
  58. loss=tf.losses.CategoricalCrossentropy(from_logits=True),
  59. metrics=['accuracy']
  60. )
  61. network.fit(db, epochs=5, validation_data=ds_val,
  62. validation_freq=2)
  63. network.evaluate(ds_val)
  64. sample = next(iter(ds_val))
  65. x = sample[0]
  66. y = sample[1] # one-hot
  67. pred = network.predict(x) # [b, 10]
  68. # convert back to number
  69. y = tf.argmax(y, axis=1)
  70. pred = tf.argmax(pred, axis=1)
  71. print(pred)
  72. print(y)