|
@@ -135,11 +135,11 @@ def train_get(args, model_dict, loss):
|
|
|
model_dict['train_loss'] = train_loss
|
|
|
model_dict['val_loss'] = val_loss
|
|
|
model_dict['val_accuracy'] = accuracy
|
|
|
- torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt') # 保存最后一次训练的模型
|
|
|
+ torch.save(model.state_dict(), args.save_path_last if not args.prune else 'prune_last.pt') # 保存最后一次训练的模型
|
|
|
if accuracy > 0.5 and accuracy > model_dict['standard']:
|
|
|
model_dict['standard'] = accuracy
|
|
|
save_path = args.save_path if not args.prune else args.prune_save
|
|
|
- torch.save(model_dict, save_path) # 保存最佳模型
|
|
|
+ torch.save(model.state_dict(), save_path) # 保存最佳模型
|
|
|
print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
|
|
|
# wandb
|
|
|
# if args.wandb:
|