mnist_tensor.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #%%
  2. import matplotlib
  3. from matplotlib import pyplot as plt
  4. # Default parameters for plots
  5. matplotlib.rcParams['font.size'] = 20
  6. matplotlib.rcParams['figure.titlesize'] = 20
  7. matplotlib.rcParams['figure.figsize'] = [9, 7]
  8. matplotlib.rcParams['font.family'] = ['STKaiTi']
  9. matplotlib.rcParams['axes.unicode_minus']=False
  10. import tensorflow as tf
  11. from tensorflow import keras
  12. from tensorflow.keras import datasets, layers, optimizers
  13. import os
  14. os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
  15. print(tf.__version__)
  16. def preprocess(x, y):
  17. # [b, 28, 28], [b]
  18. print(x.shape,y.shape)
  19. x = tf.cast(x, dtype=tf.float32) / 255.
  20. x = tf.reshape(x, [-1, 28*28])
  21. y = tf.cast(y, dtype=tf.int32)
  22. y = tf.one_hot(y, depth=10)
  23. return x,y
  24. #%%
  25. (x, y), (x_test, y_test) = datasets.mnist.load_data()
  26. print('x:', x.shape, 'y:', y.shape, 'x test:', x_test.shape, 'y test:', y_test)
  27. #%%
  28. batchsz = 512
  29. train_db = tf.data.Dataset.from_tensor_slices((x, y))
  30. train_db = train_db.shuffle(1000)
  31. train_db = train_db.batch(batchsz)
  32. train_db = train_db.map(preprocess)
  33. train_db = train_db.repeat(20)
  34. #%%
  35. test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  36. test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
  37. x,y = next(iter(train_db))
  38. print('train sample:', x.shape, y.shape)
  39. # print(x[0], y[0])
  40. #%%
  41. def main():
  42. # learning rate
  43. lr = 1e-2
  44. accs,losses = [], []
  45. # 784 => 512
  46. w1, b1 = tf.Variable(tf.random.normal([784, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
  47. # 512 => 256
  48. w2, b2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1)), tf.Variable(tf.zeros([128]))
  49. # 256 => 10
  50. w3, b3 = tf.Variable(tf.random.normal([128, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))
  51. for step, (x,y) in enumerate(train_db):
  52. # [b, 28, 28] => [b, 784]
  53. x = tf.reshape(x, (-1, 784))
  54. with tf.GradientTape() as tape:
  55. # layer1.
  56. h1 = x @ w1 + b1
  57. h1 = tf.nn.relu(h1)
  58. # layer2
  59. h2 = h1 @ w2 + b2
  60. h2 = tf.nn.relu(h2)
  61. # output
  62. out = h2 @ w3 + b3
  63. # out = tf.nn.relu(out)
  64. # compute loss
  65. # [b, 10] - [b, 10]
  66. loss = tf.square(y-out)
  67. # [b, 10] => scalar
  68. loss = tf.reduce_mean(loss)
  69. grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
  70. for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
  71. p.assign_sub(lr * g)
  72. # print
  73. if step % 80 == 0:
  74. print(step, 'loss:', float(loss))
  75. losses.append(float(loss))
  76. if step %80 == 0:
  77. # evaluate/test
  78. total, total_correct = 0., 0
  79. for x, y in test_db:
  80. # layer1.
  81. h1 = x @ w1 + b1
  82. h1 = tf.nn.relu(h1)
  83. # layer2
  84. h2 = h1 @ w2 + b2
  85. h2 = tf.nn.relu(h2)
  86. # output
  87. out = h2 @ w3 + b3
  88. # [b, 10] => [b]
  89. pred = tf.argmax(out, axis=1)
  90. # convert one_hot y to number y
  91. y = tf.argmax(y, axis=1)
  92. # bool type
  93. correct = tf.equal(pred, y)
  94. # bool tensor => int tensor => numpy
  95. total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
  96. total += x.shape[0]
  97. print(step, 'Evaluate Acc:', total_correct/total)
  98. accs.append(total_correct/total)
  99. plt.figure()
  100. x = [i*80 for i in range(len(losses))]
  101. plt.plot(x, losses, color='C0', marker='s', label='训练')
  102. plt.ylabel('MSE')
  103. plt.xlabel('Step')
  104. plt.legend()
  105. plt.savefig('train.svg')
  106. plt.figure()
  107. plt.plot(x, accs, color='C1', marker='s', label='测试')
  108. plt.ylabel('准确率')
  109. plt.xlabel('Step')
  110. plt.legend()
  111. plt.savefig('test.svg')
  112. if __name__ == '__main__':
  113. main()