123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import tensorflow as tf
- import os
- import numpy as np
- from matplotlib import pyplot as plt
- from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, AveragePooling2D
- from keras import Model
- from tf_watermark.tf_watermark_regularizers import WatermarkRegularizer
- from tf_watermark.tf_watermark_utils import save_wmark_signatures, get_layer_weights_and_predicted
# Print numpy arrays in full instead of truncating them with "...".
np.set_printoptions(threshold=np.inf)

# Load CIFAR-10 and scale pixel values from 0..255 down to [0, 1].
# Cast to float32 before dividing: uint8 / 255.0 produces float64 arrays,
# doubling the memory held by the image tensors, while Keras layers compute
# in float32 by default anyway.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Hyperparameters.
scale = 0.01      # weight (coefficient) of the watermark regularization term
randseed = 5      # intended random seed for generating the projection matrix
# NOTE(review): randseed is never used in this file -- it is not passed to
# WatermarkRegularizer, so the projection-matrix seed is whatever that class
# does internally. Confirm the regularizer's signature before wiring it in.
embed_dim = 768   # watermark key length, in bits

# Simulated random watermark key: shape (1, embed_dim), entries 0 or 1.
# NOTE(review): no seed is set, so b differs on every run, while the training
# code below may resume weights from an earlier run's checkpoint that embedded
# a different key. Call np.random.seed(...) before this line for a
# reproducible key.
b = np.random.randint(low=0, high=2, size=(1, embed_dim))
epoch = 30        # number of training epochs

# Watermark regularizer attached to the target conv layer's kernel.
watermark_regularizer = WatermarkRegularizer(scale, b)
# Sentinel meaning "use the module-level watermark_regularizer", so that an
# explicit None can mean "train without a watermark".
_MODULE_WATERMARK = object()


class AlexNet8(Model):
    """AlexNet-style classifier with 8 weight layers (5 conv + 3 dense).

    The fourth convolution (``c4``) carries the watermark regularizer as its
    ``kernel_regularizer``.

    Args:
        wm_regularizer: Regularizer attached to ``c4``'s kernel. Defaults to
            the module-level ``watermark_regularizer`` (the original
            behaviour); pass ``None`` to build the network without a
            watermark.
        num_classes: Number of units in the softmax output layer (default 10).
    """

    def __init__(self, wm_regularizer=_MODULE_WATERMARK, num_classes=10):
        super().__init__()
        if wm_regularizer is _MODULE_WATERMARK:
            # Looked up at construction time, exactly like the original code.
            wm_regularizer = watermark_regularizer

        # NOTE: attribute assignment order defines the order of model.layers,
        # which callers may index into (e.g. get_layer(index=9) is c4).
        # Keep this order stable.

        # Stage 1: conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.c1 = Conv2D(filters=96, kernel_size=(3, 3))
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPool2D(pool_size=(3, 3), strides=2)

        # Stage 2: conv -> BN -> ReLU -> 2x2/2 average-pool.
        self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
        self.b2 = BatchNormalization()
        self.a2 = Activation('relu')
        self.p2 = AveragePooling2D(pool_size=(2, 2), strides=2)

        # Stage 3: three same-padded convs; c4 holds the watermark.
        self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.c4 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu', kernel_regularizer=wm_regularizer)
        self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.p3 = AveragePooling2D(pool_size=(2, 2), strides=2)

        # Classifier head: two dropout-regularized dense layers + softmax.
        self.flatten = Flatten()
        self.f1 = Dense(2048, activation='relu')
        self.d1 = Dropout(0.5)
        self.f2 = Dense(2048, activation='relu')
        self.d2 = Dropout(0.5)
        self.f3 = Dense(num_classes, activation='softmax')

    def call(self, x, training=None):
        """Forward pass returning class probabilities.

        ``training`` is forwarded explicitly to the BatchNorm and Dropout
        layers. When it is None, Keras falls back to its own training context,
        as it did before this argument existed.
        """
        x = self.c1(x)
        x = self.b1(x, training=training)
        x = self.a1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.b2(x, training=training)
        x = self.a2(x)
        x = self.p2(x)

        x = self.c3(x)
        x = self.c4(x)
        x = self.c5(x)
        x = self.p3(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d1(x, training=training)
        x = self.f2(x)
        x = self.d2(x, training=training)
        return self.f3(x)
# Build the network. Its last layer is a softmax, so the loss consumes
# probabilities (from_logits=False), not logits.
model = AlexNet8()
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# TensorFlow-format checkpoint prefix. The presence of its '.index' file means
# a previous run saved weights, which are restored so training resumes.
checkpoint_save_path = "./checkpoint/AlexNet8.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Overwrite the checkpoint only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# NOTE(review): the test split doubles as validation data here.
history = model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=epoch,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()
# Save the projection matrix and the key (the watermark signatures).
save_wmark_signatures(model)

# ------------------------------------------------------- verify watermark
# Recover the key from the watermarked layer and compare it with b.
# Reference the layer by attribute instead of get_layer(index=9): that index
# only pointed at c4 because of the attribute assignment order in
# AlexNet8.__init__, and would silently pick another layer if it changed.
target_layer = model.c4
layer_weights, pred_bparam = get_layer_weights_and_predicted(target_layer)
print("b_param:")
print(b)
print("pred_bparam:")
print(pred_bparam)
# Count of mismatched key bits; 0 means the key was recovered exactly.
# (Assumes pred_bparam is comparable elementwise with b -- confirm against
# get_layer_weights_and_predicted.)
print(np.sum(b != pred_bparam))
# ------------------------------------------------------------------- show
# Plot training vs. validation accuracy (left) and loss (right) per epoch.
metrics_by_name = history.history
acc = metrics_by_name['sparse_categorical_accuracy']
val_acc = metrics_by_name['val_sparse_categorical_accuracy']
loss = metrics_by_name['loss']
val_loss = metrics_by_name['val_loss']

# One entry per subplot: (train series, val series, train label, val label, title).
panel_specs = (
    (acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
     'Training and Validation Accuracy'),
    (loss, val_loss, 'Training Loss', 'Validation Loss',
     'Training and Validation Loss'),
)
for panel_index, spec in enumerate(panel_specs, start=1):
    train_series, val_series, train_label, val_label, panel_title = spec
    plt.subplot(1, 2, panel_index)
    plt.plot(train_series, label=train_label)
    plt.plot(val_series, label=val_label)
    plt.title(panel_title)
    plt.legend()
plt.show()
|