# train_scratch.py — train a DenseNet121-based classifier on the Pokémon dataset.
  1. import matplotlib
  2. from matplotlib import pyplot as plt
  3. matplotlib.rcParams['font.size'] = 18
  4. matplotlib.rcParams['figure.titlesize'] = 18
  5. matplotlib.rcParams['figure.figsize'] = [9, 7]
  6. matplotlib.rcParams['font.family'] = ['KaiTi']
  7. matplotlib.rcParams['axes.unicode_minus']=False
  8. import os
  9. import tensorflow as tf
  10. import numpy as np
  11. from tensorflow import keras
  12. from tensorflow.keras import layers,optimizers,losses
  13. from tensorflow.keras.callbacks import EarlyStopping
  14. tf.random.set_seed(1234)
  15. np.random.seed(1234)
  16. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  17. assert tf.__version__.startswith('2.')
  18. from pokemon import load_pokemon,normalize
  19. def preprocess(x,y):
  20. # x: 图片的路径,y:图片的数字编码
  21. x = tf.io.read_file(x)
  22. x = tf.image.decode_jpeg(x, channels=3) # RGBA
  23. x = tf.image.resize(x, [244, 244])
  24. x = tf.image.random_flip_left_right(x)
  25. x = tf.image.random_flip_up_down(x)
  26. x = tf.image.random_crop(x, [224,224,3])
  27. # x: [0,255]=> -1~1
  28. x = tf.cast(x, dtype=tf.float32) / 255.
  29. x = normalize(x)
  30. y = tf.convert_to_tensor(y)
  31. y = tf.one_hot(y, depth=5)
  32. return x, y
# Mini-batch size shared by the train/val/test pipelines.
batchsz = 32
# Training Dataset: shuffle the (path, label) pairs, decode + preprocess, batch.
images, labels, table = load_pokemon('pokemon',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# Validation Dataset. NOTE(review): `preprocess` applies random flips and a
# random crop, so validation images are augmented too — likely unintended;
# confirm whether val/test should use a deterministic resize instead.
images2, labels2, table = load_pokemon('pokemon',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# Test Dataset (same augmentation caveat as validation).
images3, labels3, table = load_pokemon('pokemon',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
  46. # 加载DenseNet网络模型,并去掉最后一层全连接层,最后一个池化层设置为max pooling
  47. net = keras.applications.DenseNet121(include_top=False, pooling='max')
  48. # 设计为不参与优化,即MobileNet这部分参数固定不动
  49. net.trainable = True
  50. newnet = keras.Sequential([
  51. net, # 去掉最后一层的DenseNet121
  52. layers.Dense(1024, activation='relu'), # 追加全连接层
  53. layers.BatchNormalization(), # 追加BN层
  54. layers.Dropout(rate=0.5), # 追加Dropout层,防止过拟合
  55. layers.Dense(5) # 根据宝可梦数据的任务,设置最后一层输出节点数为5
  56. ])
  57. newnet.build(input_shape=(4,224,224,3))
  58. newnet.summary()
  59. # 创建Early Stopping类,连续3次不下降则终止
  60. early_stopping = EarlyStopping(
  61. monitor='val_accuracy',
  62. min_delta=0.001,
  63. patience=3
  64. )
  65. newnet.compile(optimizer=optimizers.Adam(lr=1e-3),
  66. loss=losses.CategoricalCrossentropy(from_logits=True),
  67. metrics=['accuracy'])
  68. history = newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
  69. callbacks=[early_stopping])
  70. history = history.history
  71. print(history.keys())
  72. print(history['val_accuracy'])
  73. print(history['accuracy'])
  74. test_acc = newnet.evaluate(db_test)
  75. plt.figure()
  76. returns = history['val_accuracy']
  77. plt.plot(np.arange(len(returns)), returns, label='验证准确率')
  78. plt.plot(np.arange(len(returns)), returns, 's')
  79. returns = history['accuracy']
  80. plt.plot(np.arange(len(returns)), returns, label='训练准确率')
  81. plt.plot(np.arange(len(returns)), returns, 's')
  82. plt.plot([len(returns)-1],[test_acc[-1]], 'D', label='测试准确率')
  83. plt.legend()
  84. plt.xlabel('Epoch')
  85. plt.ylabel('准确率')
  86. plt.savefig('scratch.svg')