|
@@ -25,6 +25,7 @@ def train_embed(args, data_dict, model_dict, loss, secret):
|
|
|
if isinstance(module, nn.Conv2d):
|
|
|
conv_list.append(module)
|
|
|
conv_list = conv_list[0:2]
|
|
|
+ model_dict['enc_layers'] = conv_list
|
|
|
encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
|
|
|
# 学习率
|
|
|
optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
|