# tf_watermark_regularizers.py
#
# Keras regularizer that embeds a binary watermark key into convolution weights.
  1. from tensorflow.python.keras.losses import binary_crossentropy
  2. from tensorflow.python.keras.regularizers import Regularizer
  3. import numpy as np
  4. import tensorflow as tf
  5. class WatermarkRegularizer(Regularizer):
  6. def __init__(self, scale, b, randseed='none'):
  7. self.scale = tf.constant(scale, dtype=tf.float32)
  8. self.b = b # 密钥
  9. self.x = None # 投影矩阵
  10. self.v_x = None # 投影矩阵设为可训练参数
  11. def __call__(self, weight):
  12. if self.x is None:
  13. x_rows = np.prod(weight.shape[0:3])
  14. x_cols = self.b.shape[1]
  15. self.x = np.random.randn(x_rows, x_cols)
  16. self.v_x = tf.Variable(self.x, trainable=False, dtype=tf.float32)
  17. self.b = tf.constant(self.b, dtype=tf.float32)
  18. weight_mean = tf.reduce_mean(weight, axis=3)
  19. w = tf.reshape(weight_mean, (1, tf.reduce_prod(weight_mean.shape)))
  20. regularized_loss = self.scale * tf.reduce_sum(binary_crossentropy(tf.sigmoid(tf.matmul(w, self.v_x))), self.b)
  21. return regularized_loss