|
@@ -25,7 +25,7 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
if isinstance(module, nn.Conv2d):
|
|
|
conv_list.append(module)
|
|
|
conv_list = conv_list[0:2]
|
|
|
- model_dict['enc_layers'] = conv_list
|
|
|
+ model_dict['enc_layers'] = conv_list # 将加密层保存至权重文件中
|
|
|
encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
|
|
|
|
|
|
# 数据集
|