|
@@ -139,7 +139,7 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
model_dict['train_loss'] = train_loss
|
|
|
model_dict['val_loss'] = val_loss
|
|
|
model_dict['val_m_ap'] = m_ap
|
|
|
- torch.save(model_dict, 'last.pt' if not args.prune else 'prune_last.pt') # 保存最后一次训练的模型
|
|
|
+ torch.save(model_dict, args.weight if not args.prune else args.prune_save) # 保存最后一次训练的模型
|
|
|
if m_ap > 0.1 and m_ap > model_dict['standard']:
|
|
|
model_dict['standard'] = m_ap
|
|
|
save_path = args.save_path if not args.prune else args.prune_save
|