predict.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #----------------------------------------------------#
  2. # 将单张图片预测、摄像头检测和FPS测试功能
  3. # 整合到了一个py文件中,通过指定mode进行模式的修改。
  4. #----------------------------------------------------#
  5. import time
  6. import cv2
  7. import numpy as np
  8. from PIL import Image
  9. import os
  10. from tqdm import tqdm
  11. from frcnn import FRCNN
  12. if __name__ == "__main__":
  13. frcnn = FRCNN()
  14. #----------------------------------------------------------------------------------------------------------#
  15. # mode用于指定测试的模式:
  16. # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
  17. # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
  18. # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
  19. # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
  20. #----------------------------------------------------------------------------------------------------------#
  21. mode = "dir_predict"
  22. #-------------------------------------------------------------------------#
  23. # crop 指定了是否在单张图片预测后对目标进行截取
  24. # count 指定了是否进行目标的计数
  25. # crop、count仅在mode='predict'时有效
  26. #-------------------------------------------------------------------------#
  27. crop = False
  28. count = False
  29. #----------------------------------------------------------------------------------------------------------#
  30. # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
  31. # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
  32. # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
  33. # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
  34. # video_fps 用于保存的视频的fps
  35. #
  36. # video_path、video_save_path和video_fps仅在mode='video'时有效
  37. # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
  38. #----------------------------------------------------------------------------------------------------------#
  39. video_path = 0
  40. video_save_path = ""
  41. video_fps = 25.0
  42. #----------------------------------------------------------------------------------------------------------#
  43. # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
  44. # fps_image_path 用于指定测试的fps图片
  45. #
  46. # test_interval和fps_image_path仅在mode='fps'有效
  47. #----------------------------------------------------------------------------------------------------------#
  48. test_interval = 100
  49. fps_image_path = "img/street.jpg"
  50. #-------------------------------------------------------------------------#
  51. # dir_origin_path 指定了用于检测的图片的文件夹路径
  52. # dir_save_path 指定了检测完图片的保存路径
  53. #
  54. # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
  55. #-------------------------------------------------------------------------#
  56. # dir_origin_path = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_wm_val/JPEGImages/"
  57. # dir_save_path = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_wm_val/JPEGImages_out/"
  58. dir_origin_path = "./img"
  59. dir_save_path = "./img_out2"
  60. if mode == "predict":
  61. '''
  62. 1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
  63. 具体流程可以参考get_dr_txt.py,在get_dr_txt.py即实现了遍历还实现了目标信息的保存。
  64. 2、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
  65. 3、如果想要获得预测框的坐标,可以进入frcnn.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
  66. 4、如果想要利用预测框截取下目标,可以进入frcnn.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
  67. 在原图上利用矩阵的方式进行截取。
  68. 5、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入frcnn.detect_image函数,在绘图部分对predicted_class进行判断,
  69. 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
  70. '''
  71. while True:
  72. img = input('Input image filename:')
  73. try:
  74. image = Image.open(img)
  75. except:
  76. print('Open Error! Try again!')
  77. else:
  78. r_image = frcnn.detect_image(image, crop = crop, count = count)
  79. r_image.show()
  80. elif mode == "video":
  81. capture = cv2.VideoCapture(video_path)
  82. if video_save_path != "":
  83. fourcc = cv2.VideoWriter_fourcc(*'XVID')
  84. size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
  85. out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
  86. fps = 0.0
  87. while(True):
  88. t1 = time.time()
  89. # 读取某一帧
  90. ref,frame = capture.read()
  91. # 格式转变,BGRtoRGB
  92. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  93. # 转变成Image
  94. frame = Image.fromarray(np.uint8(frame))
  95. # 进行检测
  96. frame = np.array(frcnn.detect_image(frame))
  97. # RGBtoBGR满足opencv显示格式
  98. frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  99. fps = ( fps + (1. / (time.time() - t1)) ) / 2
  100. print("fps = %.2f"%(fps))
  101. frame = cv2.putText(frame, "fps = %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
  102. cv2.imshow("video", frame)
  103. c = cv2.waitKey(1) & 0xff
  104. if video_save_path != "":
  105. out.write(frame)
  106. if c == 27:
  107. capture.release()
  108. break
  109. capture.release()
  110. out.release()
  111. cv2.destroyAllWindows()
  112. elif mode == "fps":
  113. img = Image.open(fps_image_path)
  114. tact_time = frcnn.get_FPS(img, test_interval)
  115. print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
  116. elif mode == "dir_predict":
  117. img_names = os.listdir(dir_origin_path)
  118. for img_name in tqdm(img_names):
  119. if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
  120. image_path = os.path.join(dir_origin_path, img_name)
  121. image = Image.open(image_path)
  122. r_image = frcnn.detect_image(image)
  123. if not os.path.exists(dir_save_path):
  124. os.makedirs(dir_save_path)
  125. r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality = 95, subsampling = 0)
  126. else:
  127. raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")