resnet.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. import tensorflow as tf
  3. import numpy as np
  4. from tensorflow import keras
  5. from tensorflow.keras import layers
  6. tf.random.set_seed(22)
  7. np.random.seed(22)
  8. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  9. assert tf.__version__.startswith('2.')
  10. class ResnetBlock(keras.Model):
  11. def __init__(self, channels, strides=1):
  12. super(ResnetBlock, self).__init__()
  13. self.channels = channels
  14. self.strides = strides
  15. self.conv1 = layers.Conv2D(channels, 3, strides=strides,
  16. padding=[[0,0],[1,1],[1,1],[0,0]])
  17. self.bn1 = keras.layers.BatchNormalization()
  18. self.conv2 = layers.Conv2D(channels, 3, strides=1,
  19. padding=[[0,0],[1,1],[1,1],[0,0]])
  20. self.bn2 = keras.layers.BatchNormalization()
  21. if strides!=1:
  22. self.down_conv = layers.Conv2D(channels, 1, strides=strides, padding='valid')
  23. self.down_bn = tf.keras.layers.BatchNormalization()
  24. def call(self, inputs, training=None):
  25. residual = inputs
  26. x = self.conv1(inputs)
  27. x = tf.nn.relu(x)
  28. x = self.bn1(x, training=training)
  29. x = self.conv2(x)
  30. x = tf.nn.relu(x)
  31. x = self.bn2(x, training=training)
  32. # 残差连接
  33. if self.strides!=1:
  34. residual = self.down_conv(inputs)
  35. residual = tf.nn.relu(residual)
  36. residual = self.down_bn(residual, training=training)
  37. x = x + residual
  38. x = tf.nn.relu(x)
  39. return x
  40. class ResNet(keras.Model):
  41. def __init__(self, num_classes, initial_filters=16, **kwargs):
  42. super(ResNet, self).__init__(**kwargs)
  43. self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
  44. self.blocks = keras.models.Sequential([
  45. ResnetBlock(initial_filters * 2, strides=3),
  46. ResnetBlock(initial_filters * 2, strides=1),
  47. # layers.Dropout(rate=0.5),
  48. ResnetBlock(initial_filters * 4, strides=3),
  49. ResnetBlock(initial_filters * 4, strides=1),
  50. ResnetBlock(initial_filters * 8, strides=2),
  51. ResnetBlock(initial_filters * 8, strides=1),
  52. ResnetBlock(initial_filters * 16, strides=2),
  53. ResnetBlock(initial_filters * 16, strides=1),
  54. ])
  55. self.final_bn = layers.BatchNormalization()
  56. self.avg_pool = layers.GlobalMaxPool2D()
  57. self.fc = layers.Dense(num_classes)
  58. def call(self, inputs, training=None):
  59. # print('x:',inputs.shape)
  60. out = self.stem(inputs)
  61. out = tf.nn.relu(out)
  62. # print('stem:',out.shape)
  63. out = self.blocks(out, training=training)
  64. # print('res:',out.shape)
  65. out = self.final_bn(out, training=training)
  66. # out = tf.nn.relu(out)
  67. out = self.avg_pool(out)
  68. # print('avg_pool:',out.shape)
  69. out = self.fc(out)
  70. # print('out:',out.shape)
  71. return out
  72. def main():
  73. num_classes = 5
  74. resnet18 = ResNet(5)
  75. resnet18.build(input_shape=(4,224,224,3))
  76. resnet18.summary()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()