Ver Fonte

新增将选择的加密层保存在权重文件中

liyan há 1 ano atrás
pai
commit
adecf0ae1b
2 ficheiros alterados com 2 adições e 5 exclusões
  1. 1 0
      block/train_with_watermark.py
  2. 1 5
      predict_pt_embed.py

+ 1 - 0
block/train_with_watermark.py

@@ -25,6 +25,7 @@ def train_embed(args, data_dict, model_dict, loss, secret):
         if isinstance(module, nn.Conv2d):
             conv_list.append(module)
     conv_list = conv_list[0:2]
+    model_dict['enc_layers'] = conv_list
     encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
     # 学习率
     optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))

+ 1 - 5
predict_pt_embed.py

@@ -40,11 +40,7 @@ def predict_pt(args):
     print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
 
     # 选择加密层并初始化白盒水印编码器
-    conv_list = []
-    for module in model.modules():
-        if isinstance(module, nn.Conv2d):
-            conv_list.append(module)
-    conv_list = conv_list[0:2]
+    conv_list = model_dict['enc_layers']
     decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device)  # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
     secret_extract = decoder.decode()  # 提取密码标签
     result = secret_get.verify_secret(secret_extract)