Browse Source

修改模型保存,精简权重文件

liyan 1 year ago
parent
commit
320d92d446
1 changed files with 2 additions and 2 deletions
  1. 2 2
      block/train_get.py

+ 2 - 2
block/train_get.py

@@ -135,11 +135,11 @@ def train_get(args, model_dict, loss):
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
-            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
+            torch.save(model.state_dict(), args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
             if accuracy > 0.5 and accuracy > model_dict['standard']:
                 model_dict['standard'] = accuracy
                 save_path = args.save_path if not args.prune else args.prune_save
-                torch.save(model_dict, save_path)  # 保存最佳模型
+                torch.save(model.state_dict(), save_path)  # 保存最佳模型
                 print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
             # wandb
             # if args.wandb: