Browse Source

修改水印工具类

liyan 1 year ago
parent
commit
26d0c987b5
1 changed files with 3 additions and 5 deletions
  1. 3 5
      tf_watermark/tf_watermark_utils.py

+ 3 - 5
tf_watermark/tf_watermark_utils.py

@@ -26,12 +26,10 @@ def save_wmark_signatures(model):
         np.save(fname_b, wmark_regularizer.get_signature())
 
 
-# 从文件中获取
-def get_layer_weights_and_predicted(model, checkpoint_save_path, target_blk_id):
+# 获取指定层权重和密钥预测值
+def get_layer_weights_and_predicted(target_layer):
     x = np.load(fname_x)
-    model.load_weights(checkpoint_save_path)
-    # get signature from model weight and matrix
-    target_layer = model.get_layer(index=target_blk_id)
+    # get signature from model layer and matrix
     layer_weights = target_layer.get_weights()
     weight = (np.array(layer_weights[0])).mean(axis=3)
     pred_bparam = np.dot(weight.reshape(1, weight.size), x)  # dot product