# predict_onnx.py — batched image-classification inference with an exported ONNX model
  1. import os
  2. import cv2
  3. import time
  4. import argparse
  5. import onnxruntime
  6. import numpy as np
  7. import albumentations
  8. # -------------------------------------------------------------------------------------------------------------------- #
  9. parser = argparse.ArgumentParser(description='|onnx模型推理|')
  10. parser.add_argument('--model_path', default='best.onnx', type=str, help='|onnx模型位置|')
  11. parser.add_argument('--data_path', default='image', type=str, help='|图片文件夹位置|')
  12. parser.add_argument('--input_size', default=320, type=int, help='|模型输入图片大小,要与导出的模型对应|')
  13. parser.add_argument('--batch', default=1, type=int, help='|输入图片批量,要与导出的模型对应|')
  14. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  15. parser.add_argument('--float16', default=True, type=bool, help='|推理数据类型,要与导出的模型对应,False时为float32|')
  16. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  17. # -------------------------------------------------------------------------------------------------------------------- #
  18. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  19. assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  20. # -------------------------------------------------------------------------------------------------------------------- #
  21. def predict_onnx(args):
  22. # 加载模型
  23. provider = 'CUDAExecutionProvider' if args.device.lower() in ['gpu', 'cuda'] else 'CPUExecutionProvider'
  24. model = onnxruntime.InferenceSession(args.model_path, providers=[provider]) # 加载模型和框架
  25. input_name = model.get_inputs()[0].name # 获取输入名称
  26. output_name = model.get_outputs()[0].name # 获取输出名称
  27. print(f'| 模型加载成功:{args.model_path} |')
  28. # 加载数据
  29. start_time = time.time()
  30. transform = albumentations.Compose([
  31. albumentations.LongestMaxSize(args.input_size),
  32. albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
  33. border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
  34. image_dir = sorted(os.listdir(args.data_path))
  35. image_all = np.zeros((len(image_dir), args.input_size, args.input_size, 3)).astype(
  36. np.float16 if args.float16 else np.float32)
  37. for i in range(len(image_dir)):
  38. image = cv2.imread(args.data_path + '/' + image_dir[i])
  39. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
  40. image = transform(image=image)['image'] # 缩放和填充图片(归一化、减均值、除以方差、调维度等在模型中完成)
  41. image_all[i] = image
  42. end_time = time.time()
  43. print('| 数据加载成功:{} 每张耗时:{:.4f} |'.format(len(image_all), (end_time - start_time) / len(image_all)))
  44. # 推理
  45. start_time = time.time()
  46. result = []
  47. n = len(image_all) // args.batch
  48. if n > 0: # 如果图片数量>=批量(分批预测)
  49. for i in range(n):
  50. batch = image_all[i * args.batch:(i + 1) * args.batch]
  51. pred_batch = model.run([output_name], {input_name: batch})
  52. result.extend(pred_batch[0].tolist())
  53. if len(image_all) % args.batch > 0: # 如果图片数量没有刚好满足批量
  54. batch = image_all[(i + 1) * args.batch:]
  55. pred_batch = model.run([output_name], {input_name: batch})
  56. result.extend(pred_batch[0].tolist())
  57. else: # 如果图片数量<批量(直接预测)
  58. batch = image_all
  59. pred_batch = model.run([output_name], {input_name: batch})
  60. result.extend(pred_batch[0].tolist())
  61. for i in range(len(result)):
  62. result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
  63. print(f'| {image_dir[i]}:{result[i]} |')
  64. end_time = time.time()
  65. print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_all), args.batch, (end_time - start_time) / len(image_all)))
  66. if __name__ == '__main__':
  67. predict_onnx(args)