Просмотр исходного кода

修改模型保存,精简权重文件

liyan 1 год назад
Родитель
Сommit
320d92d446
1 измененных файлов с 2 добавлено и 2 удалено
  1. 2 2
      block/train_get.py

+ 2 - 2
block/train_get.py

@@ -135,11 +135,11 @@ def train_get(args, model_dict, loss):
             model_dict['train_loss'] = train_loss
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
             model_dict['val_accuracy'] = accuracy
-            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
+            torch.save(model.state_dict(), args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
             if accuracy > 0.5 and accuracy > model_dict['standard']:
             if accuracy > 0.5 and accuracy > model_dict['standard']:
                 model_dict['standard'] = accuracy
                 model_dict['standard'] = accuracy
                 save_path = args.save_path if not args.prune else args.prune_save
                 save_path = args.save_path if not args.prune else args.prune_save
-                torch.save(model_dict, save_path)  # 保存最佳模型
+                torch.save(model.state_dict(), save_path)  # 保存最佳模型
                 print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
                 print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
             # wandb
             # wandb
             # if args.wandb:
             # if args.wandb: