
Modify the AlexNet test code

liyan 1 year ago
commit db09f69d95
1 changed file with 11 additions and 8 deletions

verify_cifar10_alexnet8.py  +11 −8

@@ -2,7 +2,7 @@ import tensorflow as tf
 import os
 import numpy as np
 from matplotlib import pyplot as plt
-from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
+from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, AveragePooling2D
 from keras import Model

 from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
@@ -19,10 +19,12 @@ x_train, x_test = x_train / 255.0, x_test / 255.0
 scale = 0.01  # regularization coefficient
 randseed = 5  # random seed for generating the projection matrix
 embed_dim = 768  # length of the embedded key
+# np.random.seed(5)
 b = np.random.randint(low=0, high=2, size=(1, embed_dim))  # generate a simulated random key
+epoch = 30

 # Initialize the watermark regularizer
-watermark_regularizer = WatermarkRegularizer(scale, b, randseed=randseed)
+watermark_regularizer = WatermarkRegularizer(scale, b)


 class AlexNet8(Model):
@@ -36,7 +38,8 @@ class AlexNet8(Model):
         self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
         self.b2 = BatchNormalization()
         self.a2 = Activation('relu')
-        self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)
+        # self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)
+        self.p2 = AveragePooling2D(pool_size=(2, 2), strides=2)

         self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                          activation='relu')
@@ -46,7 +49,8 @@ class AlexNet8(Model):

         self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
                          activation='relu')
-        self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)
+        # self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)
+        self.p3 = AveragePooling2D(pool_size=(2, 2), strides=2)

         self.flatten = Flatten()
         self.f1 = Dense(2048, activation='relu')
@@ -97,7 +101,7 @@ cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                  save_weights_only=True,
                                                  save_best_only=True)

-history = model.fit(x_train, y_train, batch_size=128, epochs=30, validation_data=(x_test, y_test), validation_freq=1,
+history = model.fit(x_train, y_train, batch_size=128, epochs=epoch, validation_data=(x_test, y_test), validation_freq=1,
                     callbacks=[cp_callback])
 model.summary()

@@ -113,13 +117,12 @@ model.summary()
 save_wmark_signatures(model)

 ###############################################    verify watermarker ###################################
-
-layer_weights, pred_bparam = get_layer_weights_and_predicted(model, checkpoint_save_path, 9)
+target_layer = model.get_layer(index=9)
+layer_weights, pred_bparam = get_layer_weights_and_predicted(target_layer)
 print("b_param:")
 print("b_param:")
 print(b)
 print(b)
 print("pred_bparam:")
 print("pred_bparam:")
 print(pred_bparam)
 print(pred_bparam)
-print(b == pred_bparam)
 print(np.sum(b != pred_bparam))

 ###############################################    show   ###############################################
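
Note: the pooling swap can also change the spatial size of the downstream feature maps, since a 3×3/stride-2 max pool and a 2×2/stride-2 average pool need not produce the same output shape under 'valid' padding. A minimal standalone sketch, not part of this commit; the 12×12×256 input is an assumed feature-map shape for illustration only (the real size depends on layers this diff does not show):

# Illustrative only: compare the removed MaxPool2D(3x3, stride 2) with the
# added AveragePooling2D(2x2, stride 2) on an assumed 12x12x256 feature map.
import tensorflow as tf
from keras.layers import MaxPool2D, AveragePooling2D

x = tf.zeros((1, 12, 12, 256))  # assumed shape, for illustration
print(MaxPool2D(pool_size=(3, 3), strides=2)(x).shape)         # (1, 5, 5, 256)
print(AveragePooling2D(pool_size=(2, 2), strides=2)(x).shape)  # (1, 6, 6, 256)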
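
The verification at the end of the diff reduces to a bit-level comparison between the embedded key b and the bits predicted from the watermarked layer. A minimal sketch of that check; pred_bparam below is a random stand-in, since the real values come from get_layer_weights_and_predicted, which is not part of this file:

# Illustrative only: bit-error count and rate between the embedded key and the
# extracted bits (a random stand-in replaces the real extraction here).
import numpy as np

embed_dim = 768
b = np.random.randint(low=0, high=2, size=(1, embed_dim))            # embedded key, as in the script
pred_bparam = np.random.randint(low=0, high=2, size=(1, embed_dim))  # stand-in for the extracted bits
bit_errors = np.sum(b != pred_bparam)      # same quantity the script prints
print(bit_errors, bit_errors / embed_dim)  # 0 errors (rate 0.0) would mean a perfect watermark match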