# verify_cifar10_inception10.py
# Trains a small Inception-style (GoogLeNet-like) CNN on CIFAR-10, saves
# checkpoints and weights, and plots training/validation accuracy and loss.
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, \
    GlobalAveragePooling2D
from tensorflow.keras import Model

# Print full arrays without truncation when dumping weights to weights.txt below.
np.set_printoptions(threshold=np.inf)

# Load CIFAR-10 (50k train / 10k test, 32x32x3 uint8 images, integer labels)
# and scale pixel values from [0, 255] to [0, 1].
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
  12. class ConvBNRelu(Model):
  13. def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
  14. super(ConvBNRelu, self).__init__()
  15. self.model = tf.keras.models.Sequential([
  16. Conv2D(ch, kernelsz, strides=strides, padding=padding),
  17. BatchNormalization(),
  18. Activation('relu')
  19. ])
  20. def call(self, x):
  21. x = self.model(x, training=False) #在training=False时,BN通过整个训练集计算均值、方差去做批归一化,training=True时,通过当前batch的均值、方差去做批归一化。推理时 training=False效果好
  22. return x
  23. class InceptionBlk(Model):
  24. def __init__(self, ch, strides=1):
  25. super(InceptionBlk, self).__init__()
  26. self.ch = ch
  27. self.strides = strides
  28. self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  29. self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  30. self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
  31. self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  32. self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
  33. self.p4_1 = MaxPool2D(3, strides=1, padding='same')
  34. self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  35. def call(self, x):
  36. x1 = self.c1(x)
  37. x2_1 = self.c2_1(x)
  38. x2_2 = self.c2_2(x2_1)
  39. x3_1 = self.c3_1(x)
  40. x3_2 = self.c3_2(x3_1)
  41. x4_1 = self.p4_1(x)
  42. x4_2 = self.c4_2(x4_1)
  43. # concat along axis=channel
  44. x = tf.concat([x1, x2_2, x3_2, x4_2], axis=3)
  45. return x
  46. class Inception10(Model):
  47. def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
  48. super(Inception10, self).__init__(**kwargs)
  49. self.in_channels = init_ch
  50. self.out_channels = init_ch
  51. self.num_blocks = num_blocks
  52. self.init_ch = init_ch
  53. self.c1 = ConvBNRelu(init_ch)
  54. self.blocks = tf.keras.models.Sequential()
  55. for block_id in range(num_blocks):
  56. for layer_id in range(2):
  57. if layer_id == 0:
  58. block = InceptionBlk(self.out_channels, strides=2)
  59. else:
  60. block = InceptionBlk(self.out_channels, strides=1)
  61. self.blocks.add(block)
  62. # enlarger out_channels per block
  63. self.out_channels *= 2
  64. self.p1 = GlobalAveragePooling2D()
  65. self.f1 = Dense(num_classes, activation='softmax')
  66. def call(self, x):
  67. x = self.c1(x)
  68. x = self.blocks(x)
  69. x = self.p1(x)
  70. y = self.f1(x)
  71. return y
  72. model = Inception10(num_blocks=2, num_classes=10)
  73. model.compile(optimizer='adam',
  74. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
  75. metrics=['sparse_categorical_accuracy'])
  76. checkpoint_save_path = "./checkpoint/Inception10.ckpt"
  77. if os.path.exists(checkpoint_save_path + '.index'):
  78. print('-------------load the model-----------------')
  79. model.load_weights(checkpoint_save_path)
  80. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
  81. save_weights_only=True,
  82. save_best_only=True)
  83. history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
  84. callbacks=[cp_callback])
  85. model.summary()
  86. # print(model.trainable_variables)
  87. file = open('./weights.txt', 'w')
  88. for v in model.trainable_variables:
  89. file.write(str(v.name) + '\n')
  90. file.write(str(v.shape) + '\n')
  91. file.write(str(v.numpy()) + '\n')
  92. file.close()
  93. ############################################### show ###############################################
  94. # 显示训练集和验证集的acc和loss曲线
  95. acc = history.history['sparse_categorical_accuracy']
  96. val_acc = history.history['val_sparse_categorical_accuracy']
  97. loss = history.history['loss']
  98. val_loss = history.history['val_loss']
  99. plt.subplot(1, 2, 1)
  100. plt.plot(acc, label='Training Accuracy')
  101. plt.plot(val_acc, label='Validation Accuracy')
  102. plt.title('Training and Validation Accuracy')
  103. plt.legend()
  104. plt.subplot(1, 2, 2)
  105. plt.plot(loss, label='Training Loss')
  106. plt.plot(val_loss, label='Validation Loss')
  107. plt.title('Training and Validation Loss')
  108. plt.legend()
  109. plt.show()