# resnet.py — generic ResNet (BasicBlock-based) implementation in TensorFlow/Keras.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers
  4. class BasicBlock(layers.Layer):
  5. # 残差模块
  6. def __init__(self, filter_num, stride=1):
  7. super(BasicBlock, self).__init__()
  8. # 第一个卷积单元
  9. self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
  10. self.bn1 = layers.BatchNormalization()
  11. self.relu = layers.Activation('relu')
  12. # 第二个卷积单元
  13. self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
  14. self.bn2 = layers.BatchNormalization()
  15. if stride != 1:# 通过1x1卷积完成shape匹配
  16. self.downsample = Sequential()
  17. self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
  18. else:# shape匹配,直接短接
  19. self.downsample = lambda x:x
  20. def call(self, inputs, training=None):
  21. # [b, h, w, c],通过第一个卷积单元
  22. out = self.conv1(inputs)
  23. out = self.bn1(out)
  24. out = self.relu(out)
  25. # 通过第二个卷积单元
  26. out = self.conv2(out)
  27. out = self.bn2(out)
  28. # 通过identity模块
  29. identity = self.downsample(inputs)
  30. # 2条路径输出直接相加
  31. output = layers.add([out, identity])
  32. output = tf.nn.relu(output) # 激活函数
  33. return output
  34. class ResNet(keras.Model):
  35. # 通用的ResNet实现类
  36. def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2]
  37. super(ResNet, self).__init__()
  38. # 根网络,预处理
  39. self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
  40. layers.BatchNormalization(),
  41. layers.Activation('relu'),
  42. layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
  43. ])
  44. # 堆叠4个Block,每个block包含了多个BasicBlock,设置步长不一样
  45. self.layer1 = self.build_resblock(64, layer_dims[0])
  46. self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
  47. self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
  48. self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
  49. # 通过Pooling层将高宽降低为1x1
  50. self.avgpool = layers.GlobalAveragePooling2D()
  51. # 最后连接一个全连接层分类
  52. self.fc = layers.Dense(num_classes)
  53. def call(self, inputs, training=None):
  54. # 通过根网络
  55. x = self.stem(inputs)
  56. # 一次通过4个模块
  57. x = self.layer1(x)
  58. x = self.layer2(x)
  59. x = self.layer3(x)
  60. x = self.layer4(x)
  61. # 通过池化层
  62. x = self.avgpool(x)
  63. # 通过全连接层
  64. x = self.fc(x)
  65. return x
  66. def build_resblock(self, filter_num, blocks, stride=1):
  67. # 辅助函数,堆叠filter_num个BasicBlock
  68. res_blocks = Sequential()
  69. # 只有第一个BasicBlock的步长可能不为1,实现下采样
  70. res_blocks.add(BasicBlock(filter_num, stride))
  71. for _ in range(1, blocks):#其他BasicBlock步长都为1
  72. res_blocks.add(BasicBlock(filter_num, stride=1))
  73. return res_blocks
  74. def resnet18():
  75. # 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet
  76. return ResNet([2, 2, 2, 2])
  77. def resnet34():
  78. # 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet
  79. return ResNet([3, 4, 6, 3])