# train_evalute_test.py (1.9 KB)

  1. import tensorflow as tf
  2. from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
  3. def preprocess(x, y):
  4. """
  5. x is a simple image, not a batch
  6. """
  7. x = tf.cast(x, dtype=tf.float32) / 255.
  8. x = tf.reshape(x, [28*28])
  9. y = tf.cast(y, dtype=tf.int32)
  10. y = tf.one_hot(y, depth=10)
  11. return x,y
  12. batchsz = 128
  13. (x, y), (x_test, y_test) = datasets.mnist.load_data()
  14. print('datasets:', x.shape, y.shape, x.min(), x.max())
  15. idx = tf.range(60000)
  16. idx = tf.random.shuffle(idx)
  17. x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
  18. x_val, y_val = tf.gather(x, idx[-10000:]) , tf.gather(y, idx[-10000:])
  19. print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
  20. db_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
  21. db_train = db_train.map(preprocess).shuffle(50000).batch(batchsz)
  22. db_val = tf.data.Dataset.from_tensor_slices((x_val,y_val))
  23. db_val = db_val.map(preprocess).shuffle(10000).batch(batchsz)
  24. db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  25. db_test = db_test.map(preprocess).batch(batchsz)
  26. sample = next(iter(db_train))
  27. print(sample[0].shape, sample[1].shape)
  28. network = Sequential([layers.Dense(256, activation='relu'),
  29. layers.Dense(128, activation='relu'),
  30. layers.Dense(64, activation='relu'),
  31. layers.Dense(32, activation='relu'),
  32. layers.Dense(10)])
  33. network.build(input_shape=(None, 28*28))
  34. network.summary()
  35. network.compile(optimizer=optimizers.Adam(lr=0.01),
  36. loss=tf.losses.CategoricalCrossentropy(from_logits=True),
  37. metrics=['accuracy']
  38. )
  39. network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
  40. print('Test performance:')
  41. network.evaluate(db_test)
  42. sample = next(iter(db_test))
  43. x = sample[0]
  44. y = sample[1] # one-hot
  45. pred = network.predict(x) # [b, 10]
  46. # convert back to number
  47. y = tf.argmax(y, axis=1)
  48. pred = tf.argmax(pred, axis=1)
  49. print(pred)
  50. print(y)