|
@@ -1,18 +1,20 @@
|
|
import os
|
|
import os
|
|
|
|
+import shutil
|
|
|
|
+
|
|
import torch
|
|
import torch
|
|
import argparse
|
|
import argparse
|
|
from model.layer import deploy
|
|
from model.layer import deploy
|
|
|
|
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
|
|
parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
|
|
-parser.add_argument('--weight', default='best.pt', type=str, help='|模型位置|')
|
|
|
|
|
|
+parser.add_argument('--weight', default='checkpoints/Alexnet/wm_embed/last.pt', type=str, help='|模型位置|')
|
|
parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
|
|
parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
|
|
parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
|
|
parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
|
|
parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
|
|
parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
|
|
-parser.add_argument('--sim', default=True, type=bool, help='|使用onnxsim压缩简化模型|')
|
|
|
|
|
|
+parser.add_argument('--sim', default=False, type=bool, help='|使用onnxsim压缩简化模型|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
|
|
parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
|
|
parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
|
|
-parser.add_argument('--save_path', default='best.onnx', type=str, help='|移动存储位置|')
|
|
|
|
|
|
+parser.add_argument('--save_path', default='checkpoints/Alexnet/wm_embed', type=str, help='|移动存储位置|')
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
args.weight = args.weight.split('.')[0] + '.pt'
|
|
args.weight = args.weight.split('.')[0] + '.pt'
|
|
args.save_name = args.weight.split('.')[0] + '.onnx'
|
|
args.save_name = args.weight.split('.')[0] + '.onnx'
|
|
@@ -47,6 +49,6 @@ def export_onnx():
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
export_onnx()
|
|
export_onnx()
|
|
# 移动生成的 ONNX 文件到指定文件夹
|
|
# 移动生成的 ONNX 文件到指定文件夹
|
|
- destination_folder = args.save_path
|
|
|
|
- shutil.move(args.save_name, os.path.join(destination_folder, args.save_name))
|
|
|
|
- print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')
|
|
|
|
|
|
+ # destination_folder = args.save_path
|
|
|
|
+ # shutil.move(args.save_name, destination_folder)
|
|
|
|
+ # print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')
|