utils_map.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  1. import glob
  2. import json
  3. import math
  4. import operator
  5. import os
  6. import shutil
  7. import sys
  8. try:
  9. from pycocotools.coco import COCO
  10. from pycocotools.cocoeval import COCOeval
  11. except:
  12. pass
  13. import cv2
  14. import matplotlib
  15. matplotlib.use('Agg')
  16. from matplotlib import pyplot as plt
  17. import numpy as np
  18. '''
  19. 0,0 ------> x (width)
  20. |
  21. | (Left,Top)
  22. | *_________
  23. | | |
  24. | |
  25. y |_________|
  26. (height) *
  27. (Right,Bottom)
  28. '''
  29. def log_average_miss_rate(precision, fp_cumsum, num_images):
  30. """
  31. log-average miss rate:
  32. Calculated by averaging miss rates at 9 evenly spaced FPPI points
  33. between 10e-2 and 10e0, in log-space.
  34. output:
  35. lamr | log-average miss rate
  36. mr | miss rate
  37. fppi | false positives per image
  38. references:
  39. [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
  40. State of the Art." Pattern Analysis and Machine Intelligence, IEEE
  41. Transactions on 34.4 (2012): 743 - 761.
  42. """
  43. if precision.size == 0:
  44. lamr = 0
  45. mr = 1
  46. fppi = 0
  47. return lamr, mr, fppi
  48. fppi = fp_cumsum / float(num_images)
  49. mr = (1 - precision)
  50. fppi_tmp = np.insert(fppi, 0, -1.0)
  51. mr_tmp = np.insert(mr, 0, 1.0)
  52. ref = np.logspace(-2.0, 0.0, num = 9)
  53. for i, ref_i in enumerate(ref):
  54. j = np.where(fppi_tmp <= ref_i)[-1][-1]
  55. ref[i] = mr_tmp[j]
  56. lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
  57. return lamr, mr, fppi
  58. """
  59. throw error and exit
  60. """
  61. def error(msg):
  62. print(msg)
  63. sys.exit(0)
  64. """
  65. check if the number is a float between 0.0 and 1.0
  66. """
  67. def is_float_between_0_and_1(value):
  68. try:
  69. val = float(value)
  70. if val > 0.0 and val < 1.0:
  71. return True
  72. else:
  73. return False
  74. except ValueError:
  75. return False
  76. """
  77. Calculate the AP given the recall and precision array
  78. 1st) We compute a version of the measured precision/recall curve with
  79. precision monotonically decreasing
  80. 2nd) We compute the AP as the area under this curve by numerical integration.
  81. """
  82. def voc_ap(rec, prec):
  83. """
  84. --- Official matlab code VOC2012---
  85. mrec=[0 ; rec ; 1];
  86. mpre=[0 ; prec ; 0];
  87. for i=numel(mpre)-1:-1:1
  88. mpre(i)=max(mpre(i),mpre(i+1));
  89. end
  90. i=find(mrec(2:end)~=mrec(1:end-1))+1;
  91. ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
  92. """
  93. rec.insert(0, 0.0) # insert 0.0 at begining of list
  94. rec.append(1.0) # insert 1.0 at end of list
  95. mrec = rec[:]
  96. prec.insert(0, 0.0) # insert 0.0 at begining of list
  97. prec.append(0.0) # insert 0.0 at end of list
  98. mpre = prec[:]
  99. """
  100. This part makes the precision monotonically decreasing
  101. (goes from the end to the beginning)
  102. matlab: for i=numel(mpre)-1:-1:1
  103. mpre(i)=max(mpre(i),mpre(i+1));
  104. """
  105. for i in range(len(mpre)-2, -1, -1):
  106. mpre[i] = max(mpre[i], mpre[i+1])
  107. """
  108. This part creates a list of indexes where the recall changes
  109. matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
  110. """
  111. i_list = []
  112. for i in range(1, len(mrec)):
  113. if mrec[i] != mrec[i-1]:
  114. i_list.append(i) # if it was matlab would be i + 1
  115. """
  116. The Average Precision (AP) is the area under the curve
  117. (numerical integration)
  118. matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
  119. """
  120. ap = 0.0
  121. for i in i_list:
  122. ap += ((mrec[i]-mrec[i-1])*mpre[i])
  123. return ap, mrec, mpre
  124. """
  125. Convert the lines of a file to a list
  126. """
  127. def file_lines_to_list(path):
  128. # open txt file lines to a list
  129. with open(path) as f:
  130. content = f.readlines()
  131. # remove whitespace characters like `\n` at the end of each line
  132. content = [x.strip() for x in content]
  133. return content
  134. """
  135. Draws text in image
  136. """
  137. def draw_text_in_image(img, text, pos, color, line_width):
  138. font = cv2.FONT_HERSHEY_PLAIN
  139. fontScale = 1
  140. lineType = 1
  141. bottomLeftCornerOfText = pos
  142. cv2.putText(img, text,
  143. bottomLeftCornerOfText,
  144. font,
  145. fontScale,
  146. color,
  147. lineType)
  148. text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
  149. return img, (line_width + text_width)
  150. """
  151. Plot - adjust axes
  152. """
  153. def adjust_axes(r, t, fig, axes):
  154. # get text width for re-scaling
  155. bb = t.get_window_extent(renderer=r)
  156. text_width_inches = bb.width / fig.dpi
  157. # get axis width in inches
  158. current_fig_width = fig.get_figwidth()
  159. new_fig_width = current_fig_width + text_width_inches
  160. propotion = new_fig_width / current_fig_width
  161. # get axis limit
  162. x_lim = axes.get_xlim()
  163. axes.set_xlim([x_lim[0], x_lim[1]*propotion])
  164. """
  165. Draw plot using Matplotlib
  166. """
  167. def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
  168. # sort the dictionary by decreasing value, into a list of tuples
  169. sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
  170. # unpacking the list of tuples into two lists
  171. sorted_keys, sorted_values = zip(*sorted_dic_by_value)
  172. #
  173. if true_p_bar != "":
  174. """
  175. Special case to draw in:
  176. - green -> TP: True Positives (object detected and matches ground-truth)
  177. - red -> FP: False Positives (object detected but does not match ground-truth)
  178. - orange -> FN: False Negatives (object not detected but present in the ground-truth)
  179. """
  180. fp_sorted = []
  181. tp_sorted = []
  182. for key in sorted_keys:
  183. fp_sorted.append(dictionary[key] - true_p_bar[key])
  184. tp_sorted.append(true_p_bar[key])
  185. plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
  186. plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
  187. # add legend
  188. plt.legend(loc='lower right')
  189. """
  190. Write number on side of bar
  191. """
  192. fig = plt.gcf() # gcf - get current figure
  193. axes = plt.gca()
  194. r = fig.canvas.get_renderer()
  195. for i, val in enumerate(sorted_values):
  196. fp_val = fp_sorted[i]
  197. tp_val = tp_sorted[i]
  198. fp_str_val = " " + str(fp_val)
  199. tp_str_val = fp_str_val + " " + str(tp_val)
  200. # trick to paint multicolor with offset:
  201. # first paint everything and then repaint the first number
  202. t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
  203. plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
  204. if i == (len(sorted_values)-1): # largest bar
  205. adjust_axes(r, t, fig, axes)
  206. else:
  207. plt.barh(range(n_classes), sorted_values, color=plot_color)
  208. """
  209. Write number on side of bar
  210. """
  211. fig = plt.gcf() # gcf - get current figure
  212. axes = plt.gca()
  213. r = fig.canvas.get_renderer()
  214. for i, val in enumerate(sorted_values):
  215. str_val = " " + str(val) # add a space before
  216. if val < 1.0:
  217. str_val = " {0:.2f}".format(val)
  218. t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
  219. # re-set axes to show number inside the figure
  220. if i == (len(sorted_values)-1): # largest bar
  221. adjust_axes(r, t, fig, axes)
  222. # set window title
  223. fig.canvas.manager.set_window_title(window_title)
  224. # write classes in y axis
  225. tick_font_size = 12
  226. plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
  227. """
  228. Re-scale height accordingly
  229. """
  230. init_height = fig.get_figheight()
  231. # comput the matrix height in points and inches
  232. dpi = fig.dpi
  233. height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
  234. height_in = height_pt / dpi
  235. # compute the required figure height
  236. top_margin = 0.15 # in percentage of the figure height
  237. bottom_margin = 0.05 # in percentage of the figure height
  238. figure_height = height_in / (1 - top_margin - bottom_margin)
  239. # set new height
  240. if figure_height > init_height:
  241. fig.set_figheight(figure_height)
  242. # set plot title
  243. plt.title(plot_title, fontsize=14)
  244. # set axis titles
  245. # plt.xlabel('classes')
  246. plt.xlabel(x_label, fontsize='large')
  247. # adjust size of window
  248. fig.tight_layout()
  249. # save the plot
  250. fig.savefig(output_path)
  251. # show image
  252. if to_show:
  253. plt.show()
  254. # close the plot
  255. plt.close()
  256. def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
  257. GT_PATH = os.path.join(path, 'ground-truth')
  258. DR_PATH = os.path.join(path, 'detection-results')
  259. IMG_PATH = os.path.join(path, 'images-optional')
  260. TEMP_FILES_PATH = os.path.join(path, '.temp_files')
  261. RESULTS_FILES_PATH = os.path.join(path, 'results')
  262. show_animation = True
  263. if os.path.exists(IMG_PATH):
  264. for dirpath, dirnames, files in os.walk(IMG_PATH):
  265. if not files:
  266. show_animation = False
  267. else:
  268. show_animation = False
  269. if not os.path.exists(TEMP_FILES_PATH):
  270. os.makedirs(TEMP_FILES_PATH)
  271. if os.path.exists(RESULTS_FILES_PATH):
  272. shutil.rmtree(RESULTS_FILES_PATH)
  273. else:
  274. os.makedirs(RESULTS_FILES_PATH)
  275. if draw_plot:
  276. try:
  277. matplotlib.use('TkAgg')
  278. except:
  279. pass
  280. os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
  281. os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
  282. os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
  283. os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
  284. if show_animation:
  285. os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
  286. ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
  287. if len(ground_truth_files_list) == 0:
  288. error("Error: No ground-truth files found!")
  289. ground_truth_files_list.sort()
  290. gt_counter_per_class = {}
  291. counter_images_per_class = {}
  292. for txt_file in ground_truth_files_list:
  293. file_id = txt_file.split(".txt", 1)[0]
  294. file_id = os.path.basename(os.path.normpath(file_id))
  295. temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
  296. if not os.path.exists(temp_path):
  297. error_msg = "Error. File not found: {}\n".format(temp_path)
  298. error(error_msg)
  299. lines_list = file_lines_to_list(txt_file)
  300. bounding_boxes = []
  301. is_difficult = False
  302. already_seen_classes = []
  303. for line in lines_list:
  304. try:
  305. if "difficult" in line:
  306. class_name, left, top, right, bottom, _difficult = line.split()
  307. is_difficult = True
  308. else:
  309. class_name, left, top, right, bottom = line.split()
  310. except:
  311. if "difficult" in line:
  312. line_split = line.split()
  313. _difficult = line_split[-1]
  314. bottom = line_split[-2]
  315. right = line_split[-3]
  316. top = line_split[-4]
  317. left = line_split[-5]
  318. class_name = ""
  319. for name in line_split[:-5]:
  320. class_name += name + " "
  321. class_name = class_name[:-1]
  322. is_difficult = True
  323. else:
  324. line_split = line.split()
  325. bottom = line_split[-1]
  326. right = line_split[-2]
  327. top = line_split[-3]
  328. left = line_split[-4]
  329. class_name = ""
  330. for name in line_split[:-4]:
  331. class_name += name + " "
  332. class_name = class_name[:-1]
  333. bbox = left + " " + top + " " + right + " " + bottom
  334. if is_difficult:
  335. bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
  336. is_difficult = False
  337. else:
  338. bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
  339. if class_name in gt_counter_per_class:
  340. gt_counter_per_class[class_name] += 1
  341. else:
  342. gt_counter_per_class[class_name] = 1
  343. if class_name not in already_seen_classes:
  344. if class_name in counter_images_per_class:
  345. counter_images_per_class[class_name] += 1
  346. else:
  347. counter_images_per_class[class_name] = 1
  348. already_seen_classes.append(class_name)
  349. with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
  350. json.dump(bounding_boxes, outfile)
  351. gt_classes = list(gt_counter_per_class.keys())
  352. gt_classes = sorted(gt_classes)
  353. n_classes = len(gt_classes)
  354. dr_files_list = glob.glob(DR_PATH + '/*.txt')
  355. dr_files_list.sort()
  356. for class_index, class_name in enumerate(gt_classes):
  357. bounding_boxes = []
  358. for txt_file in dr_files_list:
  359. file_id = txt_file.split(".txt",1)[0]
  360. file_id = os.path.basename(os.path.normpath(file_id))
  361. temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
  362. if class_index == 0:
  363. if not os.path.exists(temp_path):
  364. error_msg = "Error. File not found: {}\n".format(temp_path)
  365. error(error_msg)
  366. lines = file_lines_to_list(txt_file)
  367. for line in lines:
  368. try:
  369. tmp_class_name, confidence, left, top, right, bottom = line.split()
  370. except:
  371. line_split = line.split()
  372. bottom = line_split[-1]
  373. right = line_split[-2]
  374. top = line_split[-3]
  375. left = line_split[-4]
  376. confidence = line_split[-5]
  377. tmp_class_name = ""
  378. for name in line_split[:-5]:
  379. tmp_class_name += name + " "
  380. tmp_class_name = tmp_class_name[:-1]
  381. if tmp_class_name == class_name:
  382. bbox = left + " " + top + " " + right + " " +bottom
  383. bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
  384. bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
  385. with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
  386. json.dump(bounding_boxes, outfile)
  387. sum_AP = 0.0
  388. ap_dictionary = {}
  389. lamr_dictionary = {}
  390. with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
  391. results_file.write("# AP and precision/recall per class\n")
  392. count_true_positives = {}
  393. for class_index, class_name in enumerate(gt_classes):
  394. count_true_positives[class_name] = 0
  395. dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
  396. dr_data = json.load(open(dr_file))
  397. nd = len(dr_data)
  398. tp = [0] * nd
  399. fp = [0] * nd
  400. score = [0] * nd
  401. score_threhold_idx = 0
  402. for idx, detection in enumerate(dr_data):
  403. file_id = detection["file_id"]
  404. score[idx] = float(detection["confidence"])
  405. if score[idx] >= score_threhold:
  406. score_threhold_idx = idx
  407. if show_animation:
  408. ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
  409. if len(ground_truth_img) == 0:
  410. error("Error. Image not found with id: " + file_id)
  411. elif len(ground_truth_img) > 1:
  412. error("Error. Multiple image with id: " + file_id)
  413. else:
  414. img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
  415. img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
  416. if os.path.isfile(img_cumulative_path):
  417. img_cumulative = cv2.imread(img_cumulative_path)
  418. else:
  419. img_cumulative = img.copy()
  420. bottom_border = 60
  421. BLACK = [0, 0, 0]
  422. img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
  423. gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
  424. ground_truth_data = json.load(open(gt_file))
  425. ovmax = -1
  426. gt_match = -1
  427. bb = [float(x) for x in detection["bbox"].split()]
  428. for obj in ground_truth_data:
  429. if obj["class_name"] == class_name:
  430. bbgt = [ float(x) for x in obj["bbox"].split() ]
  431. bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
  432. iw = bi[2] - bi[0] + 1
  433. ih = bi[3] - bi[1] + 1
  434. if iw > 0 and ih > 0:
  435. ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
  436. + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
  437. ov = iw * ih / ua
  438. if ov > ovmax:
  439. ovmax = ov
  440. gt_match = obj
  441. if show_animation:
  442. status = "NO MATCH FOUND!"
  443. min_overlap = MINOVERLAP
  444. if ovmax >= min_overlap:
  445. if "difficult" not in gt_match:
  446. if not bool(gt_match["used"]):
  447. tp[idx] = 1
  448. gt_match["used"] = True
  449. count_true_positives[class_name] += 1
  450. with open(gt_file, 'w') as f:
  451. f.write(json.dumps(ground_truth_data))
  452. if show_animation:
  453. status = "MATCH!"
  454. else:
  455. fp[idx] = 1
  456. if show_animation:
  457. status = "REPEATED MATCH!"
  458. else:
  459. fp[idx] = 1
  460. if ovmax > 0:
  461. status = "INSUFFICIENT OVERLAP"
  462. """
  463. Draw image to show animation
  464. """
  465. if show_animation:
  466. height, widht = img.shape[:2]
  467. white = (255,255,255)
  468. light_blue = (255,200,100)
  469. green = (0,255,0)
  470. light_red = (30,30,255)
  471. margin = 10
  472. # 1nd line
  473. v_pos = int(height - margin - (bottom_border / 2.0))
  474. text = "Image: " + ground_truth_img[0] + " "
  475. img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
  476. text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
  477. img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
  478. if ovmax != -1:
  479. color = light_red
  480. if status == "INSUFFICIENT OVERLAP":
  481. text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
  482. else:
  483. text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
  484. color = green
  485. img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
  486. # 2nd line
  487. v_pos += int(bottom_border / 2.0)
  488. rank_pos = str(idx+1)
  489. text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
  490. img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
  491. color = light_red
  492. if status == "MATCH!":
  493. color = green
  494. text = "Result: " + status + " "
  495. img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
  496. font = cv2.FONT_HERSHEY_SIMPLEX
  497. if ovmax > 0:
  498. bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
  499. cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
  500. cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
  501. cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
  502. bb = [int(i) for i in bb]
  503. cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
  504. cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
  505. cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
  506. cv2.imshow("Animation", img)
  507. cv2.waitKey(20)
  508. output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
  509. cv2.imwrite(output_img_path, img)
  510. cv2.imwrite(img_cumulative_path, img_cumulative)
  511. cumsum = 0
  512. for idx, val in enumerate(fp):
  513. fp[idx] += cumsum
  514. cumsum += val
  515. cumsum = 0
  516. for idx, val in enumerate(tp):
  517. tp[idx] += cumsum
  518. cumsum += val
  519. rec = tp[:]
  520. for idx, val in enumerate(tp):
  521. rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
  522. prec = tp[:]
  523. for idx, val in enumerate(tp):
  524. prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
  525. ap, mrec, mprec = voc_ap(rec[:], prec[:])
  526. F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
  527. sum_AP += ap
  528. text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
  529. if len(prec)>0:
  530. F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
  531. Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
  532. Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
  533. else:
  534. F1_text = "0.00" + " = " + class_name + " F1 "
  535. Recall_text = "0.00%" + " = " + class_name + " Recall "
  536. Precision_text = "0.00%" + " = " + class_name + " Precision "
  537. rounded_prec = [ '%.2f' % elem for elem in prec ]
  538. rounded_rec = [ '%.2f' % elem for elem in rec ]
  539. results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
  540. if len(prec)>0:
  541. print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
  542. + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
  543. else:
  544. print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
  545. ap_dictionary[class_name] = ap
  546. n_images = counter_images_per_class[class_name]
  547. lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
  548. lamr_dictionary[class_name] = lamr
  549. if draw_plot:
  550. plt.plot(rec, prec, '-o')
  551. area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
  552. area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
  553. plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
  554. fig = plt.gcf()
  555. fig.canvas.manager.set_window_title('AP ' + class_name)
  556. plt.title('class: ' + text)
  557. plt.xlabel('Recall')
  558. plt.ylabel('Precision')
  559. axes = plt.gca()
  560. axes.set_xlim([0.0,1.0])
  561. axes.set_ylim([0.0,1.05])
  562. fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
  563. plt.cla()
  564. plt.plot(score, F1, "-", color='orangered')
  565. plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
  566. plt.xlabel('Score_Threhold')
  567. plt.ylabel('F1')
  568. axes = plt.gca()
  569. axes.set_xlim([0.0,1.0])
  570. axes.set_ylim([0.0,1.05])
  571. fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
  572. plt.cla()
  573. plt.plot(score, rec, "-H", color='gold')
  574. plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
  575. plt.xlabel('Score_Threhold')
  576. plt.ylabel('Recall')
  577. axes = plt.gca()
  578. axes.set_xlim([0.0,1.0])
  579. axes.set_ylim([0.0,1.05])
  580. fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
  581. plt.cla()
  582. plt.plot(score, prec, "-s", color='palevioletred')
  583. plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
  584. plt.xlabel('Score_Threhold')
  585. plt.ylabel('Precision')
  586. axes = plt.gca()
  587. axes.set_xlim([0.0,1.0])
  588. axes.set_ylim([0.0,1.05])
  589. fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
  590. plt.cla()
  591. if show_animation:
  592. cv2.destroyAllWindows()
  593. if n_classes == 0:
  594. print("未检测到任何种类,请检查标签信息与get_map.py中的classes_path是否修改。")
  595. return 0
  596. results_file.write("\n# mAP of all classes\n")
  597. mAP = sum_AP / n_classes
  598. text = "mAP = {0:.2f}%".format(mAP*100)
  599. results_file.write(text + "\n")
  600. print(text)
  601. shutil.rmtree(TEMP_FILES_PATH)
  602. """
  603. Count total of detection-results
  604. """
  605. det_counter_per_class = {}
  606. for txt_file in dr_files_list:
  607. lines_list = file_lines_to_list(txt_file)
  608. for line in lines_list:
  609. class_name = line.split()[0]
  610. if class_name in det_counter_per_class:
  611. det_counter_per_class[class_name] += 1
  612. else:
  613. det_counter_per_class[class_name] = 1
  614. dr_classes = list(det_counter_per_class.keys())
  615. """
  616. Write number of ground-truth objects per class to results.txt
  617. """
  618. with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
  619. results_file.write("\n# Number of ground-truth objects per class\n")
  620. for class_name in sorted(gt_counter_per_class):
  621. results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
  622. """
  623. Finish counting true positives
  624. """
  625. for class_name in dr_classes:
  626. if class_name not in gt_classes:
  627. count_true_positives[class_name] = 0
  628. """
  629. Write number of detected objects per class to results.txt
  630. """
  631. with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
  632. results_file.write("\n# Number of detected objects per class\n")
  633. for class_name in sorted(dr_classes):
  634. n_det = det_counter_per_class[class_name]
  635. text = class_name + ": " + str(n_det)
  636. text += " (tp:" + str(count_true_positives[class_name]) + ""
  637. text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
  638. results_file.write(text)
  639. """
  640. Plot the total number of occurences of each class in the ground-truth
  641. """
  642. if draw_plot:
  643. window_title = "ground-truth-info"
  644. plot_title = "ground-truth\n"
  645. plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
  646. x_label = "Number of objects per class"
  647. output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
  648. to_show = False
  649. plot_color = 'forestgreen'
  650. draw_plot_func(
  651. gt_counter_per_class,
  652. n_classes,
  653. window_title,
  654. plot_title,
  655. x_label,
  656. output_path,
  657. to_show,
  658. plot_color,
  659. '',
  660. )
  661. # """
  662. # Plot the total number of occurences of each class in the "detection-results" folder
  663. # """
  664. # if draw_plot:
  665. # window_title = "detection-results-info"
  666. # # Plot title
  667. # plot_title = "detection-results\n"
  668. # plot_title += "(" + str(len(dr_files_list)) + " files and "
  669. # count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
  670. # plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
  671. # # end Plot title
  672. # x_label = "Number of objects per class"
  673. # output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
  674. # to_show = False
  675. # plot_color = 'forestgreen'
  676. # true_p_bar = count_true_positives
  677. # draw_plot_func(
  678. # det_counter_per_class,
  679. # len(det_counter_per_class),
  680. # window_title,
  681. # plot_title,
  682. # x_label,
  683. # output_path,
  684. # to_show,
  685. # plot_color,
  686. # true_p_bar
  687. # )
  688. """
  689. Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
  690. """
  691. if draw_plot:
  692. window_title = "lamr"
  693. plot_title = "log-average miss rate"
  694. x_label = "log-average miss rate"
  695. output_path = RESULTS_FILES_PATH + "/lamr.png"
  696. to_show = False
  697. plot_color = 'royalblue'
  698. draw_plot_func(
  699. lamr_dictionary,
  700. n_classes,
  701. window_title,
  702. plot_title,
  703. x_label,
  704. output_path,
  705. to_show,
  706. plot_color,
  707. ""
  708. )
  709. """
  710. Draw mAP plot (Show AP's of all classes in decreasing order)
  711. """
  712. if draw_plot:
  713. window_title = "mAP"
  714. plot_title = "mAP = {0:.2f}%".format(mAP*100)
  715. x_label = "Average Precision"
  716. output_path = RESULTS_FILES_PATH + "/mAP.png"
  717. to_show = True
  718. plot_color = 'royalblue'
  719. draw_plot_func(
  720. ap_dictionary,
  721. n_classes,
  722. window_title,
  723. plot_title,
  724. x_label,
  725. output_path,
  726. to_show,
  727. plot_color,
  728. ""
  729. )
  730. return mAP
  731. def preprocess_gt(gt_path, class_names):
  732. image_ids = os.listdir(gt_path)
  733. results = {}
  734. images = []
  735. bboxes = []
  736. for i, image_id in enumerate(image_ids):
  737. lines_list = file_lines_to_list(os.path.join(gt_path, image_id))
  738. boxes_per_image = []
  739. image = {}
  740. image_id = os.path.splitext(image_id)[0]
  741. image['file_name'] = image_id + '.jpg'
  742. image['width'] = 1
  743. image['height'] = 1
  744. #-----------------------------------------------------------------#
  745. # 感谢 多学学英语吧 的提醒
  746. # 解决了'Results do not correspond to current coco set'问题
  747. #-----------------------------------------------------------------#
  748. image['id'] = str(image_id)
  749. for line in lines_list:
  750. difficult = 0
  751. if "difficult" in line:
  752. line_split = line.split()
  753. left, top, right, bottom, _difficult = line_split[-5:]
  754. class_name = ""
  755. for name in line_split[:-5]:
  756. class_name += name + " "
  757. class_name = class_name[:-1]
  758. difficult = 1
  759. else:
  760. line_split = line.split()
  761. left, top, right, bottom = line_split[-4:]
  762. class_name = ""
  763. for name in line_split[:-4]:
  764. class_name += name + " "
  765. class_name = class_name[:-1]
  766. left, top, right, bottom = float(left), float(top), float(right), float(bottom)
  767. if class_name not in class_names:
  768. continue
  769. cls_id = class_names.index(class_name) + 1
  770. bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
  771. boxes_per_image.append(bbox)
  772. images.append(image)
  773. bboxes.extend(boxes_per_image)
  774. results['images'] = images
  775. categories = []
  776. for i, cls in enumerate(class_names):
  777. category = {}
  778. category['supercategory'] = cls
  779. category['name'] = cls
  780. category['id'] = i + 1
  781. categories.append(category)
  782. results['categories'] = categories
  783. annotations = []
  784. for i, box in enumerate(bboxes):
  785. annotation = {}
  786. annotation['area'] = box[-1]
  787. annotation['category_id'] = box[-2]
  788. annotation['image_id'] = box[-3]
  789. annotation['iscrowd'] = box[-4]
  790. annotation['bbox'] = box[:4]
  791. annotation['id'] = i
  792. annotations.append(annotation)
  793. results['annotations'] = annotations
  794. return results
  795. def preprocess_dr(dr_path, class_names):
  796. image_ids = os.listdir(dr_path)
  797. results = []
  798. for image_id in image_ids:
  799. lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
  800. image_id = os.path.splitext(image_id)[0]
  801. for line in lines_list:
  802. line_split = line.split()
  803. confidence, left, top, right, bottom = line_split[-5:]
  804. class_name = ""
  805. for name in line_split[:-5]:
  806. class_name += name + " "
  807. class_name = class_name[:-1]
  808. left, top, right, bottom = float(left), float(top), float(right), float(bottom)
  809. result = {}
  810. result["image_id"] = str(image_id)
  811. if class_name not in class_names:
  812. continue
  813. result["category_id"] = class_names.index(class_name) + 1
  814. result["bbox"] = [left, top, right - left, bottom - top]
  815. result["score"] = float(confidence)
  816. results.append(result)
  817. return results
  818. def get_coco_map(class_names, path):
  819. GT_PATH = os.path.join(path, 'ground-truth')
  820. DR_PATH = os.path.join(path, 'detection-results')
  821. COCO_PATH = os.path.join(path, 'coco_eval')
  822. if not os.path.exists(COCO_PATH):
  823. os.makedirs(COCO_PATH)
  824. GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
  825. DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
  826. with open(GT_JSON_PATH, "w") as f:
  827. results_gt = preprocess_gt(GT_PATH, class_names)
  828. json.dump(results_gt, f, indent=4)
  829. with open(DR_JSON_PATH, "w") as f:
  830. results_dr = preprocess_dr(DR_PATH, class_names)
  831. json.dump(results_dr, f, indent=4)
  832. if len(results_dr) == 0:
  833. print("未检测到任何目标。")
  834. return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  835. cocoGt = COCO(GT_JSON_PATH)
  836. cocoDt = cocoGt.loadRes(DR_JSON_PATH)
  837. cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
  838. cocoEval.evaluate()
  839. cocoEval.accumulate()
  840. cocoEval.summarize()
  841. return cocoEval.stats