Jelajahi Sumber

添加水印正则化器代码

liyan 1 tahun lalu
induk
melakukan
b236fcaad5
1 mengubah file dengan 25 tambahan dan 0 penghapusan
  1. 25 0
      tensorflow/tf_watermark_regularizers.py

+ 25 - 0
tensorflow/tf_watermark_regularizers.py

@@ -0,0 +1,25 @@
+from tensorflow.python.keras.losses import binary_crossentropy
+from tensorflow.python.keras.regularizers import Regularizer
+import numpy as np
+import tensorflow as tf
+
+
+class WatermarkRegularizer(Regularizer):
+
+    def __init__(self, scale, b, randseed='none'):
+        self.scale = tf.constant(scale, dtype=tf.float32)
+        self.b = b  # 密钥
+        self.x = None  # 投影矩阵
+        self.v_x = None  # 投影矩阵设为可训练参数
+
+    def __call__(self, weight):
+        if self.x is None:
+            x_rows = np.prod(weight.shape[0:3])
+            x_cols = self.b.shape[1]
+            self.x = np.random.randn(x_rows, x_cols)
+            self.v_x = tf.Variable(self.x, trainable=False, dtype=tf.float32)
+            self.b = tf.constant(self.b, dtype=tf.float32)
+        weight_mean = tf.reduce_mean(weight, axis=3)
+        w = tf.reshape(weight_mean, (1, tf.reduce_prod(weight_mean.shape)))
+        regularized_loss = self.scale * tf.reduce_sum(binary_crossentropy(tf.sigmoid(tf.matmul(w, self.v_x))), self.b)
+        return regularized_loss