# verify_cifar10_vgg16.py
  1. import tensorflow as tf
  2. import os
  3. import numpy as np
  4. from matplotlib import pyplot as plt
  5. from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
  6. from keras import Model
  7. from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
  8. from tf_watermark.tf_watermark_utils import save_wmark_signatures, get_layer_weights_and_predicted
  9. np.set_printoptions(threshold=np.inf)
  10. cifar10 = tf.keras.datasets.cifar10
  11. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  12. x_train, x_test = x_train / 255.0, x_test / 255.0
  13. # 初始化参数
  14. scale = 0.01 # 正则化项偏置系数
  15. randseed = 5 # 投影矩阵生成随机数种子
  16. embed_dim = 768 # 密钥长度
  17. np.random.seed(5)
  18. b = np.random.randint(low=0, high=2, size=(1, embed_dim)) # 生成模拟随机密钥
  19. epoch = 25
  20. # 初始化水印正则化器
  21. watermark_regularizer = WatermarkRegularizer(scale, b)
  22. class VGG16(Model):
  23. def __init__(self):
  24. super(VGG16, self).__init__()
  25. self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same') # 卷积层1
  26. self.b1 = BatchNormalization() # BN层1
  27. self.a1 = Activation('relu') # 激活层1
  28. self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', )
  29. self.b2 = BatchNormalization() # BN层1
  30. self.a2 = Activation('relu') # 激活层1
  31. self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  32. self.d1 = Dropout(0.2) # dropout层
  33. self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
  34. self.b3 = BatchNormalization() # BN层1
  35. self.a3 = Activation('relu') # 激活层1
  36. self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
  37. self.b4 = BatchNormalization() # BN层1
  38. self.a4 = Activation('relu') # 激活层1
  39. self.p2 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  40. self.d2 = Dropout(0.2) # dropout层
  41. self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
  42. self.b5 = BatchNormalization() # BN层1
  43. self.a5 = Activation('relu') # 激活层1
  44. self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding='same', kernel_regularizer=watermark_regularizer)
  45. self.b6 = BatchNormalization() # BN层1
  46. self.a6 = Activation('relu') # 激活层1
  47. self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
  48. self.b7 = BatchNormalization()
  49. self.a7 = Activation('relu')
  50. self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  51. self.d3 = Dropout(0.2)
  52. self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  53. self.b8 = BatchNormalization() # BN层1
  54. self.a8 = Activation('relu') # 激活层1
  55. self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  56. self.b9 = BatchNormalization() # BN层1
  57. self.a9 = Activation('relu') # 激活层1
  58. self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  59. self.b10 = BatchNormalization()
  60. self.a10 = Activation('relu')
  61. self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  62. self.d4 = Dropout(0.2)
  63. self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  64. self.b11 = BatchNormalization() # BN层1
  65. self.a11 = Activation('relu') # 激活层1
  66. self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  67. self.b12 = BatchNormalization() # BN层1
  68. self.a12 = Activation('relu') # 激活层1
  69. self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  70. self.b13 = BatchNormalization()
  71. self.a13 = Activation('relu')
  72. self.p5 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  73. self.d5 = Dropout(0.2)
  74. self.flatten = Flatten()
  75. self.f1 = Dense(512, activation='relu')
  76. self.d6 = Dropout(0.2)
  77. self.f2 = Dense(512, activation='relu')
  78. self.d7 = Dropout(0.2)
  79. self.f3 = Dense(10, activation='softmax')
  80. def call(self, x):
  81. x = self.c1(x)
  82. x = self.b1(x)
  83. x = self.a1(x)
  84. x = self.c2(x)
  85. x = self.b2(x)
  86. x = self.a2(x)
  87. x = self.p1(x)
  88. x = self.d1(x)
  89. x = self.c3(x)
  90. x = self.b3(x)
  91. x = self.a3(x)
  92. x = self.c4(x)
  93. x = self.b4(x)
  94. x = self.a4(x)
  95. x = self.p2(x)
  96. x = self.d2(x)
  97. x = self.c5(x)
  98. x = self.b5(x)
  99. x = self.a5(x)
  100. x = self.c6(x)
  101. x = self.b6(x)
  102. x = self.a6(x)
  103. x = self.c7(x)
  104. x = self.b7(x)
  105. x = self.a7(x)
  106. x = self.p3(x)
  107. x = self.d3(x)
  108. x = self.c8(x)
  109. x = self.b8(x)
  110. x = self.a8(x)
  111. x = self.c9(x)
  112. x = self.b9(x)
  113. x = self.a9(x)
  114. x = self.c10(x)
  115. x = self.b10(x)
  116. x = self.a10(x)
  117. x = self.p4(x)
  118. x = self.d4(x)
  119. x = self.c11(x)
  120. x = self.b11(x)
  121. x = self.a11(x)
  122. x = self.c12(x)
  123. x = self.b12(x)
  124. x = self.a12(x)
  125. x = self.c13(x)
  126. x = self.b13(x)
  127. x = self.a13(x)
  128. x = self.p5(x)
  129. x = self.d5(x)
  130. x = self.flatten(x)
  131. x = self.f1(x)
  132. x = self.d6(x)
  133. x = self.f2(x)
  134. x = self.d7(x)
  135. y = self.f3(x)
  136. return y
  137. model = VGG16()
  138. model.compile(optimizer='adam',
  139. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
  140. metrics=['sparse_categorical_accuracy'])
  141. checkpoint_save_path = "./checkpoint/VGG16.ckpt"
  142. if os.path.exists(checkpoint_save_path + '.index'):
  143. print('-------------load the model-----------------')
  144. model.load_weights(checkpoint_save_path)
  145. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
  146. save_weights_only=True,
  147. save_best_only=True)
  148. history = model.fit(x_train, y_train, batch_size=64, epochs=epoch, validation_data=(x_test, y_test), validation_freq=1,
  149. callbacks=[cp_callback])
  150. model.summary()
  151. ############################################### verify watermarker ###################################
  152. # 保存投影矩阵和密钥
  153. save_wmark_signatures(model)
  154. target_layer = model.get_layer(index=19)
  155. layer_weights, pred_bparam = get_layer_weights_and_predicted(target_layer)
  156. print("b_param:")
  157. print(b)
  158. print("pred_bparam:")
  159. print(pred_bparam)
  160. print(np.sum(b != pred_bparam))
  161. ############################################### show ###############################################
  162. # 显示训练集和验证集的acc和loss曲线
  163. acc = history.history['sparse_categorical_accuracy']
  164. val_acc = history.history['val_sparse_categorical_accuracy']
  165. loss = history.history['loss']
  166. val_loss = history.history['val_loss']
  167. plt.subplot(1, 2, 1)
  168. plt.plot(acc, label='Training Accuracy')
  169. plt.plot(val_acc, label='Validation Accuracy')
  170. plt.title('Training and Validation Accuracy')
  171. plt.legend()
  172. plt.subplot(1, 2, 2)
  173. plt.plot(loss, label='Training Loss')
  174. plt.plot(val_loss, label='Validation Loss')
  175. plt.title('Training and Validation Loss')
  176. plt.legend()
  177. plt.show()