verify_cifar10_inception10.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import tensorflow as tf
  2. import os
  3. import numpy as np
  4. from matplotlib import pyplot as plt
  5. from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, \
  6. GlobalAveragePooling2D
  7. from keras import Model
  8. from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
  9. np.set_printoptions(threshold=np.inf)
  10. cifar10 = tf.keras.datasets.cifar10
  11. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  12. x_train, x_test = x_train / 255.0, x_test / 255.0
  13. # 初始化参数
  14. scale = 0.01 # 正则化项偏置系数
  15. randseed = 5 # 投影矩阵生成随机数种子
  16. embed_dim = 768 # 密钥长度
  17. np.random.seed(5)
  18. b = np.random.randint(low=0, high=2, size=(1, embed_dim)) # 生成模拟随机密钥
  19. epoch = 25
  20. # 初始化水印正则化器
  21. watermark_regularizer = WatermarkRegularizer(scale, b)
  22. class ConvBNRelu(Model):
  23. def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
  24. super(ConvBNRelu, self).__init__()
  25. self.model = tf.keras.models.Sequential([
  26. Conv2D(ch, kernelsz, strides=strides, padding=padding),
  27. BatchNormalization(),
  28. Activation('relu')
  29. ])
  30. def call(self, x):
  31. x = self.model(x, training=False) #在training=False时,BN通过整个训练集计算均值、方差去做批归一化,training=True时,通过当前batch的均值、方差去做批归一化。推理时 training=False效果好
  32. return x
  33. class InceptionBlk(Model):
  34. def __init__(self, ch, strides=1):
  35. super(InceptionBlk, self).__init__()
  36. self.ch = ch
  37. self.strides = strides
  38. self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  39. self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  40. self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
  41. self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  42. self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
  43. self.p4_1 = MaxPool2D(3, strides=1, padding='same')
  44. self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)
  45. def call(self, x):
  46. x1 = self.c1(x)
  47. x2_1 = self.c2_1(x)
  48. x2_2 = self.c2_2(x2_1)
  49. x3_1 = self.c3_1(x)
  50. x3_2 = self.c3_2(x3_1)
  51. x4_1 = self.p4_1(x)
  52. x4_2 = self.c4_2(x4_1)
  53. # concat along axis=channel
  54. x = tf.concat([x1, x2_2, x3_2, x4_2], axis=3)
  55. return x
  56. class Inception10(Model):
  57. def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
  58. super(Inception10, self).__init__(**kwargs)
  59. self.in_channels = init_ch
  60. self.out_channels = init_ch
  61. self.num_blocks = num_blocks
  62. self.init_ch = init_ch
  63. self.c1 = ConvBNRelu(init_ch)
  64. self.blocks = tf.keras.models.Sequential()
  65. for block_id in range(num_blocks):
  66. for layer_id in range(2):
  67. if layer_id == 0:
  68. block = InceptionBlk(self.out_channels, strides=2)
  69. else:
  70. block = InceptionBlk(self.out_channels, strides=1)
  71. self.blocks.add(block)
  72. # enlarger out_channels per block
  73. self.out_channels *= 2
  74. self.p1 = GlobalAveragePooling2D()
  75. self.f1 = Dense(num_classes, activation='softmax')
  76. def call(self, x):
  77. x = self.c1(x)
  78. x = self.blocks(x)
  79. x = self.p1(x)
  80. y = self.f1(x)
  81. return y
  82. model = Inception10(num_blocks=2, num_classes=10)
  83. model.compile(optimizer='adam',
  84. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
  85. metrics=['sparse_categorical_accuracy'])
  86. checkpoint_save_path = "./checkpoint/Inception10.ckpt"
  87. if os.path.exists(checkpoint_save_path + '.index'):
  88. print('-------------load the model-----------------')
  89. model.load_weights(checkpoint_save_path)
  90. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
  91. save_weights_only=True,
  92. save_best_only=True)
  93. history = model.fit(x_train, y_train, batch_size=32, epochs=epoch, validation_data=(x_test, y_test), validation_freq=1,
  94. callbacks=[cp_callback])
  95. model.summary()
  96. # print(model.trainable_variables)
  97. file = open('./weights.txt', 'w')
  98. for v in model.trainable_variables:
  99. file.write(str(v.name) + '\n')
  100. file.write(str(v.shape) + '\n')
  101. file.write(str(v.numpy()) + '\n')
  102. file.close()
  103. ############################################### show ###############################################
  104. # 显示训练集和验证集的acc和loss曲线
  105. acc = history.history['sparse_categorical_accuracy']
  106. val_acc = history.history['val_sparse_categorical_accuracy']
  107. loss = history.history['loss']
  108. val_loss = history.history['val_loss']
  109. plt.subplot(1, 2, 1)
  110. plt.plot(acc, label='Training Accuracy')
  111. plt.plot(val_acc, label='Validation Accuracy')
  112. plt.title('Training and Validation Accuracy')
  113. plt.legend()
  114. plt.subplot(1, 2, 2)
  115. plt.plot(loss, label='Training Loss')
  116. plt.plot(val_loss, label='Validation Loss')
  117. plt.title('Training and Validation Loss')
  118. plt.legend()
  119. plt.show()