|
@@ -2,7 +2,7 @@ import tensorflow as tf
|
|
import os
|
|
import os
|
|
import numpy as np
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib import pyplot as plt
|
|
-from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
|
|
|
|
|
|
+from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, AveragePooling2D
|
|
from keras import Model
|
|
from keras import Model
|
|
|
|
|
|
from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
|
|
from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
|
|
@@ -19,10 +19,12 @@ x_train, x_test = x_train / 255.0, x_test / 255.0
|
|
scale = 0.01 # 正则化项偏置系数
|
|
scale = 0.01 # 正则化项偏置系数
|
|
randseed = 5 # 投影矩阵生成随机数种子
|
|
randseed = 5 # 投影矩阵生成随机数种子
|
|
embed_dim = 768 # 密钥长度
|
|
embed_dim = 768 # 密钥长度
|
|
|
|
+# np.random.seed(5)
|
|
b = np.random.randint(low=0, high=2, size=(1, embed_dim)) # 生成模拟随机密钥
|
|
b = np.random.randint(low=0, high=2, size=(1, embed_dim)) # 生成模拟随机密钥
|
|
|
|
+epoch = 30
|
|
|
|
|
|
# 初始化水印正则化器
|
|
# 初始化水印正则化器
|
|
-watermark_regularizer = WatermarkRegularizer(scale, b, randseed=randseed)
|
|
|
|
|
|
+watermark_regularizer = WatermarkRegularizer(scale, b)
|
|
|
|
|
|
|
|
|
|
class AlexNet8(Model):
|
|
class AlexNet8(Model):
|
|
@@ -36,7 +38,8 @@ class AlexNet8(Model):
|
|
self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
|
|
self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
|
|
self.b2 = BatchNormalization()
|
|
self.b2 = BatchNormalization()
|
|
self.a2 = Activation('relu')
|
|
self.a2 = Activation('relu')
|
|
- self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)
|
|
|
|
|
|
+ # self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)
|
|
|
|
+ self.p2 = AveragePooling2D(pool_size=(2, 2), strides=2)
|
|
|
|
|
|
self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
|
|
self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
|
|
activation='relu')
|
|
activation='relu')
|
|
@@ -46,7 +49,8 @@ class AlexNet8(Model):
|
|
|
|
|
|
self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
|
|
self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
|
|
activation='relu')
|
|
activation='relu')
|
|
- self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)
|
|
|
|
|
|
+ # self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)
|
|
|
|
+ self.p3 = AveragePooling2D(pool_size=(2, 2), strides=2)
|
|
|
|
|
|
self.flatten = Flatten()
|
|
self.flatten = Flatten()
|
|
self.f1 = Dense(2048, activation='relu')
|
|
self.f1 = Dense(2048, activation='relu')
|
|
@@ -97,7 +101,7 @@ cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
|
|
save_weights_only=True,
|
|
save_weights_only=True,
|
|
save_best_only=True)
|
|
save_best_only=True)
|
|
|
|
|
|
-history = model.fit(x_train, y_train, batch_size=128, epochs=30, validation_data=(x_test, y_test), validation_freq=1,
|
|
|
|
|
|
+history = model.fit(x_train, y_train, batch_size=128, epochs=epoch, validation_data=(x_test, y_test), validation_freq=1,
|
|
callbacks=[cp_callback])
|
|
callbacks=[cp_callback])
|
|
model.summary()
|
|
model.summary()
|
|
|
|
|
|
@@ -113,13 +117,12 @@ model.summary()
|
|
save_wmark_signatures(model)
|
|
save_wmark_signatures(model)
|
|
|
|
|
|
############################################### verify watermarker ###################################
|
|
############################################### verify watermarker ###################################
|
|
-
|
|
|
|
-layer_weights, pred_bparam = get_layer_weights_and_predicted(model, checkpoint_save_path, 9)
|
|
|
|
|
|
+target_layer = model.get_layer(index=9)
|
|
|
|
+layer_weights, pred_bparam = get_layer_weights_and_predicted(target_layer)
|
|
print("b_param:")
|
|
print("b_param:")
|
|
print(b)
|
|
print(b)
|
|
print("pred_bparam:")
|
|
print("pred_bparam:")
|
|
print(pred_bparam)
|
|
print(pred_bparam)
|
|
-print(b == pred_bparam)
|
|
|
|
print(np.sum(b != pred_bparam))
|
|
print(np.sum(b != pred_bparam))
|
|
|
|
|
|
############################################### show ###############################################
|
|
############################################### show ###############################################
|