ソースを参照

修改模型预测脚本

liyan 1 年間 前
コミット
71d99df948
2 ファイル変更6 行追加3 行削除
  1. 6 3
      predict_pt.py
  2. 0 0
      result/.keep

+ 6 - 3
predict_pt.py

@@ -8,10 +8,13 @@ import numpy as np
 import albumentations
 from model.layer import deploy
 
+
+# 获取当前文件路径
+pwd = os.getcwd()
 # -------------------------------------------------------------------------------------------------------------------- #
 parser = argparse.ArgumentParser(description='|pt模型推理|')
-parser.add_argument('--model_path', default=r'D:\桌面\ObjectDetection-main\last.pt', type=str, help='|pt模型位置|')
-parser.add_argument('--image_path', default=r'D:\桌面\ObjectDetection-main\datasets\coco_wm\images\test2017_wm', type=str, help='|图片文件夹位置|')
+parser.add_argument('--model_path', default=f'{pwd}/last.pt', type=str, help='|pt模型位置|')
+parser.add_argument('--image_path', default=r'./datasets/coco_wm/images/test2017', type=str, help='|图片文件夹位置|')
 parser.add_argument('--input_size', default=640, type=int, help='|模型输入图片大小|')
 parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
 parser.add_argument('--confidence_threshold', default=0.35, type=float, help='|置信筛选度阈值(>阈值留下)|')
@@ -68,7 +71,7 @@ def draw(image, frame, cls, name):  # 输入(x_min,y_min,w,h)真实坐标
         b = (int(frame[i][0] + frame[i][2]), int(frame[i][1] + frame[i][3]))
         cv2.rectangle(image, a, b, color=(0, 255, 0), thickness=2)
         cv2.putText(image, 'class:' + str(cls[i]), a, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
-    cv2.imwrite('save_' + name, image)
+    cv2.imwrite('./result/save_' + name, image)
     print(f'| {name}: save_{name} |')
 
 

+ 0 - 0
result/.keep