liyan 1 рік тому
батько
коміт
77b8d67585
2 змінених файлів з 20 додано та 2 видалено
  1. 6 1
      tf_watermark/tf_watermark_utils.py
  2. 14 1
      verify_cifar10_inception10.py

+ 6 - 1
tf_watermark/tf_watermark_utils.py

@@ -1,4 +1,3 @@
-
 import numpy as np
 
 from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
@@ -26,6 +25,12 @@ def save_wmark_signatures(model):
         np.save(fname_b, wmark_regularizer.get_signature())
 
 
+def save_wmark_signatures_by_layer(layer):
+    wmark_regularizer = layer.kernel_regularizer
+    np.save(fname_x, wmark_regularizer.get_matrix())
+    np.save(fname_b, wmark_regularizer.get_signature())
+
+
 # 获取指定层权重和密钥预测值
 def get_layer_weights_and_predicted(target_layer):
     x = np.load(fname_x)

+ 14 - 1
verify_cifar10_inception10.py

@@ -6,12 +6,25 @@ from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Drop
     GlobalAveragePooling2D
 from keras import Model
 
+from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
+
 np.set_printoptions(threshold=np.inf)
 
 cifar10 = tf.keras.datasets.cifar10
 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 x_train, x_test = x_train / 255.0, x_test / 255.0
 
+# 初始化参数
+scale = 0.01  # 正则化项偏置系数
+randseed = 5  # 投影矩阵生成随机数种子
+embed_dim = 768  # 密钥长度
+np.random.seed(5)
+b = np.random.randint(low=0, high=2, size=(1, embed_dim))  # 生成模拟随机密钥
+epoch = 25
+
+# 初始化水印正则化器
+watermark_regularizer = WatermarkRegularizer(scale, b)
+
 
 class ConvBNRelu(Model):
     def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
@@ -97,7 +110,7 @@ cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                  save_weights_only=True,
                                                  save_best_only=True)
 
-history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
+history = model.fit(x_train, y_train, batch_size=32, epochs=epoch, validation_data=(x_test, y_test), validation_freq=1,
                     callbacks=[cp_callback])
 model.summary()