| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
import os

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it starts logging,
# so set it *before* `import tensorflow` for it to take effect reliably.
# '2' filters out INFO and WARNING messages from the C++ backend.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping

from pokemon import load_pokemon, normalize

# Global plot styling. The plot labels are Chinese, so a CJK font (KaiTi) is
# selected; unicode_minus=False draws minus signs with the ASCII hyphen
# instead of U+2212, which such fonts may not contain.
matplotlib.rcParams['font.size'] = 18
matplotlib.rcParams['figure.titlesize'] = 18
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['KaiTi']
matplotlib.rcParams['axes.unicode_minus'] = False

# Fixed seeds so weight init and dataset shuffling are repeatable.
tf.random.set_seed(2222)
np.random.seed(2222)

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not tf.__version__.startswith('2.'):
    raise RuntimeError(f'TensorFlow 2.x is required, found {tf.__version__}')
def preprocess(x, y, training=True):
    """Load one image from disk and one-hot encode its label.

    Args:
        x: path of the image file.
        y: numeric class code of the image.
        training: when True (default), apply random flips and a random
            224x224 crop for data augmentation. When False, use a
            deterministic center crop so validation/test metrics are
            reproducible.

    Returns:
        (image, label): the image resized to 244x244, cropped to 224x224x3,
        scaled to [0, 1] and then passed through the project's `normalize`,
        plus the label as a one-hot vector of depth 5.
    """
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)  # channels=3 -> RGB output
    x = tf.image.resize(x, [244, 244])
    if training:  # Python bool, resolved once when tf.data traces this function
        x = tf.image.random_flip_left_right(x)
        x = tf.image.random_flip_up_down(x)
        x = tf.image.random_crop(x, [224, 224, 3])
    else:
        # Same output size as the training crop, but taken from the center.
        x = tf.image.resize_with_crop_or_pad(x, 224, 224)
    # Scale pixel values from [0, 255] to [0, 1], then apply the dataset's
    # normalization from the pokemon module.
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)
    return x, y


def _preprocess_eval(x, y):
    """Deterministic preprocessing (no augmentation) for validation/test data."""
    return preprocess(x, y, training=False)


batchsz = 32
# Training Dataset: shuffled and augmented.
images, labels, table = load_pokemon('pokemon', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# Validation Dataset: no augmentation, so val metrics are stable across epochs.
images2, labels2, table = load_pokemon('pokemon', mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(_preprocess_eval).batch(batchsz)
# Test Dataset: no augmentation, so the reported test accuracy is repeatable.
images3, labels3, table = load_pokemon('pokemon', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(_preprocess_eval).batch(batchsz)
# Load DenseNet121 with ImageNet weights and without its final fully-connected
# classifier (include_top=False). pooling='max' ends the backbone with global
# max pooling, so it outputs one feature vector per image.
net = keras.applications.DenseNet121(weights='imagenet', include_top=False, pooling='max')
# Freeze the pretrained DenseNet121 backbone so that only the new head below is
# optimized. Set FREEZE_BACKBONE = False to fine-tune the whole network instead
# (a smaller learning rate is typically advisable in that case).
FREEZE_BACKBONE = True
net.trainable = not FREEZE_BACKBONE
newnet = keras.Sequential([
    net,                                    # DenseNet121 without its top layer
    layers.Dense(1024, activation='relu'),  # new fully-connected layer
    layers.BatchNormalization(),            # new BN layer
    layers.Dropout(rate=0.5),               # dropout to reduce overfitting
    layers.Dense(5),                        # 5 output logits, one per pokemon class
])
# Leave the batch dimension unspecified rather than fixing it to a constant.
newnet.build(input_shape=(None, 224, 224, 3))
newnet.summary()
# Stop training once val_accuracy has failed to improve by at least 0.001 for
# 3 consecutive epochs. mode='max' because higher accuracy is better (this is
# what the default 'auto' mode infers for an accuracy metric; stated explicitly).
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=3,
    mode='max',
)
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
# from_logits=True because the model's last Dense layer has no softmax.
newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
history = newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
                     callbacks=[early_stopping])
# Keep only the per-epoch metrics dict (e.g. 'accuracy', 'val_accuracy').
history = history.history
print(history.keys())
print(history['val_accuracy'])
print(history['accuracy'])
# With one compiled metric, evaluate() returns [loss, accuracy].
test_acc = newnet.evaluate(db_test)
# Plot per-epoch validation and training accuracy plus the final test accuracy,
# and save the figure as transfer.svg.
plt.figure()
val_acc = history['val_accuracy']
train_acc = history['accuracy']
# Draw each series with a single plot() call using marker='s', so a series'
# line and square markers share one colour. Separate marker-only plot() calls
# would each advance the colour cycle and get a different colour.
plt.plot(np.arange(len(val_acc)), val_acc, marker='s', label='验证准确率')
plt.plot(np.arange(len(train_acc)), train_acc, marker='s', label='训练准确率')
# test_acc is evaluate()'s [loss, accuracy]; mark the accuracy at the last epoch.
plt.plot([len(train_acc) - 1], [test_acc[-1]], 'D', label='测试准确率')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.savefig('transfer.svg')
|