@@ -15,7 +15,7 @@ from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust


-def train_embed(args, data_dict, model_dict, loss, secret):
+def train_embed(args, model_dict, loss, secret):
     # Load the model
     model = model_dict['model'].to(args.device, non_blocking=args.latch)
     print(model)
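For reference, a minimal sketch of the checkpoint dict that train_embed consumes. The keys are inferred from the reads and writes visible in this patch; the real checkpoint layout is not shown here and may carry additional fields.

# Hypothetical layout of model_dict, inferred from this patch (not the authoritative definition).
model_dict = {
    'model': None,                  # torch.nn.Module being trained
    'epoch_finished': 0,            # number of epochs already completed
    'optimizer_state_dict': None,   # restored into the Adam optimizer when present
    'ema_updates': 0,               # EMA update counter, restored when args.ema is set
    'enc_layers': [],               # filled with the selected conv layers below
    'train_loss': 0.0,
    'val_loss': 0.0,
    'val_accuracy': 0.0,
}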
@@ -27,18 +27,7 @@ def train_embed(args, data_dict, model_dict, loss, secret):
     conv_list = conv_list[0:2]
     model_dict['enc_layers'] = conv_list
     encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
-    # Learning rate
-    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
-    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
-    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # steps per epoch
-    print(len(data_dict['train']) // args.batch)
-    print(step_epoch)
-    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
-    optimizer = optimizer_adjust(optimizer)  # learning-rate initialization
-    # Adjust the parameters with an exponential moving average (EMA); ema must not be stored in args, otherwise model saving breaks
-    ema = model_ema(model) if args.ema else None
-    if args.ema:
-        ema.updates = model_dict['ema_updates']
+
     # Datasets
     print("Loading the training set into memory...")
     train_transform = transforms.Compose([
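The train-side dataset and loader are built in lines elided from this hunk; the sketch below shows the usual PyTorch pattern that mirrors the val_dataloader visible in the next hunk. It is illustrative only: the dataset class, sampler construction, shuffle, and drop_last handling in the repository may differ.

# Illustrative sketch only; mirrors the val_dataloader construction shown below.
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_train_loader(train_dataset, args):
    # In distributed training a sampler shards the data across ranks,
    # in which case the DataLoader itself must not shuffle.
    sampler = DistributedSampler(train_dataset) if args.distributed else None
    return DataLoader(train_dataset, batch_size=args.batch, shuffle=(sampler is None),
                      drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                      sampler=sampler)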
@@ -65,6 +54,21 @@ def train_embed(args, data_dict, model_dict, loss, secret):
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                  drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                  sampler=val_sampler)
+
+    # Learning rate
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    train_len = len(train_dataset)
+    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
+    print(train_len // args.batch)
+    print(step_epoch)
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
+    optimizer = optimizer_adjust(optimizer)  # learning-rate initialization
+    # Adjust the parameters with an exponential moving average (EMA); ema must not be stored in args, otherwise model saving breaks
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+
     # Distributed initialization
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                       output_device=args.local_rank) if args.distributed else model
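step_epoch above rounds the number of optimizer steps per epoch down to a multiple of args.device_number, presumably so every device runs the same number of steps. A small worked example of that arithmetic; the numbers are made up for illustration:

# Worked example of the step_epoch arithmetic; the values are illustrative.
train_len = 10000        # len(train_dataset)
batch = 32               # args.batch
device_number = 4        # args.device_number

raw_steps = train_len // batch                              # 312 full batches per epoch
step_epoch = raw_steps // device_number * device_number     # 312 // 4 * 4 = 312 (already a multiple)

# If the batch count is not a multiple of device_number, the remainder is dropped:
step_epoch = 314 // device_number * device_number           # 78 * 4 = 312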
@@ -140,15 +144,13 @@ def train_embed(args, data_dict, model_dict, loss, secret):
         torch.cuda.empty_cache()
         # Validation
         if args.local_rank == 0:  # validate only once when distributed
-            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
-                                         len(data_dict['test']))
+            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, len(val_dataset))
         # Save
         if args.local_rank == 0:  # save only once when distributed
             model_dict['model'] = model.module if args.distributed else model
             model_dict['epoch_finished'] = epoch
             model_dict['optimizer_state_dict'] = optimizer.state_dict()
             model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
-            model_dict['class'] = data_dict['class']
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
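The fields written into model_dict above make up the per-epoch checkpoint. A hedged sketch of the save step that typically follows this block; the function name and save path are placeholders, not taken from the repository.

# Illustrative only: persist the checkpoint dict assembled above.
import torch

def save_checkpoint(model_dict, path='last.pt'):    # 'last.pt' is a placeholder path
    torch.save(model_dict, path)                     # pickles the module together with the metrics

Because model_dict['model'] holds the unwrapped module (model.module when distributed), a checkpoint saved this way can be reloaded without the DDP wrapper.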