Browse Source

修改onnx文件导出逻辑

liyan 10 months ago
parent
commit
c129292a0f
1 changed files with 8 additions and 6 deletions
  1. 8 6
      export_onnx.py

+ 8 - 6
export_onnx.py

@@ -1,18 +1,20 @@
 import os
 import os
+import shutil
+
 import torch
 import torch
 import argparse
 import argparse
 from model.layer import deploy
 from model.layer import deploy
 
 
 # -------------------------------------------------------------------------------------------------------------------- #
 # -------------------------------------------------------------------------------------------------------------------- #
 parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
 parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
-parser.add_argument('--weight', default='best.pt', type=str, help='|模型位置|')
+parser.add_argument('--weight', default='checkpoints/Alexnet/wm_embed/last.pt', type=str, help='|模型位置|')
 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
 parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
 parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
 parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
 parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
-parser.add_argument('--sim', default=True, type=bool, help='|使用onnxsim压缩简化模型|')
+parser.add_argument('--sim', default=False, type=bool, help='|使用onnxsim压缩简化模型|')
 parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
 parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
 parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
 parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
-parser.add_argument('--save_path', default='best.onnx', type=str, help='|移动存储位置|')
+parser.add_argument('--save_path', default='checkpoints/Alexnet/wm_embed', type=str, help='|移动存储位置|')
 args = parser.parse_args()
 args = parser.parse_args()
 args.weight = args.weight.split('.')[0] + '.pt'
 args.weight = args.weight.split('.')[0] + '.pt'
 args.save_name = args.weight.split('.')[0] + '.onnx'
 args.save_name = args.weight.split('.')[0] + '.onnx'
@@ -47,6 +49,6 @@ def export_onnx():
 if __name__ == '__main__':
 if __name__ == '__main__':
     export_onnx()
     export_onnx()
     # 移动生成的 ONNX 文件到指定文件夹
     # 移动生成的 ONNX 文件到指定文件夹
-    destination_folder = args.save_path
-    shutil.move(args.save_name, os.path.join(destination_folder, args.save_name))
-    print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')
+    # destination_folder = args.save_path
+    # shutil.move(args.save_name, destination_folder)
+    # print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')