utils_map.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901
  1. import glob
  2. import json
  3. import math
  4. import operator
  5. import os
  6. import shutil
  7. import sys
  8. import cv2
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. '''
  12. 0,0 ------> x (width)
  13. |
  14. | (Left,Top)
  15. | *_________
  16. | | |
  17. | |
  18. y |_________|
  19. (height) *
  20. (Right,Bottom)
  21. '''
  22. def log_average_miss_rate(precision, fp_cumsum, num_images):
  23. """
  24. log-average miss rate:
  25. Calculated by averaging miss rates at 9 evenly spaced FPPI points
  26. between 10e-2 and 10e0, in log-space.
  27. output:
  28. lamr | log-average miss rate
  29. mr | miss rate
  30. fppi | false positives per image
  31. references:
  32. [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
  33. State of the Art." Pattern Analysis and Machine Intelligence, IEEE
  34. Transactions on 34.4 (2012): 743 - 761.
  35. """
  36. if precision.size == 0:
  37. lamr = 0
  38. mr = 1
  39. fppi = 0
  40. return lamr, mr, fppi
  41. fppi = fp_cumsum / float(num_images)
  42. mr = (1 - precision)
  43. fppi_tmp = np.insert(fppi, 0, -1.0)
  44. mr_tmp = np.insert(mr, 0, 1.0)
  45. ref = np.logspace(-2.0, 0.0, num = 9)
  46. for i, ref_i in enumerate(ref):
  47. j = np.where(fppi_tmp <= ref_i)[-1][-1]
  48. ref[i] = mr_tmp[j]
  49. lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
  50. return lamr, mr, fppi
  51. """
  52. throw error and exit
  53. """
  54. def error(msg):
  55. print(msg)
  56. sys.exit(0)
  57. """
  58. check if the number is a float between 0.0 and 1.0
  59. """
  60. def is_float_between_0_and_1(value):
  61. try:
  62. val = float(value)
  63. if val > 0.0 and val < 1.0:
  64. return True
  65. else:
  66. return False
  67. except ValueError:
  68. return False
  69. """
  70. Calculate the AP given the recall and precision array
  71. 1st) We compute a version of the measured precision/recall curve with
  72. precision monotonically decreasing
  73. 2nd) We compute the AP as the area under this curve by numerical integration.
  74. """
  75. def voc_ap(rec, prec):
  76. """
  77. --- Official matlab code VOC2012---
  78. mrec=[0 ; rec ; 1];
  79. mpre=[0 ; prec ; 0];
  80. for i=numel(mpre)-1:-1:1
  81. mpre(i)=max(mpre(i),mpre(i+1));
  82. end
  83. i=find(mrec(2:end)~=mrec(1:end-1))+1;
  84. ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
  85. """
  86. rec.insert(0, 0.0) # insert 0.0 at begining of list
  87. rec.append(1.0) # insert 1.0 at end of list
  88. mrec = rec[:]
  89. prec.insert(0, 0.0) # insert 0.0 at begining of list
  90. prec.append(0.0) # insert 0.0 at end of list
  91. mpre = prec[:]
  92. """
  93. This part makes the precision monotonically decreasing
  94. (goes from the end to the beginning)
  95. matlab: for i=numel(mpre)-1:-1:1
  96. mpre(i)=max(mpre(i),mpre(i+1));
  97. """
  98. for i in range(len(mpre)-2, -1, -1):
  99. mpre[i] = max(mpre[i], mpre[i+1])
  100. """
  101. This part creates a list of indexes where the recall changes
  102. matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
  103. """
  104. i_list = []
  105. for i in range(1, len(mrec)):
  106. if mrec[i] != mrec[i-1]:
  107. i_list.append(i) # if it was matlab would be i + 1
  108. """
  109. The Average Precision (AP) is the area under the curve
  110. (numerical integration)
  111. matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
  112. """
  113. ap = 0.0
  114. for i in i_list:
  115. ap += ((mrec[i]-mrec[i-1])*mpre[i])
  116. return ap, mrec, mpre
  117. """
  118. Convert the lines of a file to a list
  119. """
  120. def file_lines_to_list(path):
  121. # open txt file lines to a list
  122. with open(path) as f:
  123. content = f.readlines()
  124. # remove whitespace characters like `\n` at the end of each line
  125. content = [x.strip() for x in content]
  126. return content
  127. """
  128. Draws text in image
  129. """
  130. def draw_text_in_image(img, text, pos, color, line_width):
  131. font = cv2.FONT_HERSHEY_PLAIN
  132. fontScale = 1
  133. lineType = 1
  134. bottomLeftCornerOfText = pos
  135. cv2.putText(img, text,
  136. bottomLeftCornerOfText,
  137. font,
  138. fontScale,
  139. color,
  140. lineType)
  141. text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
  142. return img, (line_width + text_width)
  143. """
  144. Plot - adjust axes
  145. """
  146. def adjust_axes(r, t, fig, axes):
  147. # get text width for re-scaling
  148. bb = t.get_window_extent(renderer=r)
  149. text_width_inches = bb.width / fig.dpi
  150. # get axis width in inches
  151. current_fig_width = fig.get_figwidth()
  152. new_fig_width = current_fig_width + text_width_inches
  153. propotion = new_fig_width / current_fig_width
  154. # get axis limit
  155. x_lim = axes.get_xlim()
  156. axes.set_xlim([x_lim[0], x_lim[1]*propotion])
  157. """
  158. Draw plot using Matplotlib
  159. """
  160. def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
  161. # sort the dictionary by decreasing value, into a list of tuples
  162. sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
  163. # unpacking the list of tuples into two lists
  164. sorted_keys, sorted_values = zip(*sorted_dic_by_value)
  165. #
  166. if true_p_bar != "":
  167. """
  168. Special case to draw in:
  169. - green -> TP: True Positives (object detected and matches ground-truth)
  170. - red -> FP: False Positives (object detected but does not match ground-truth)
  171. - orange -> FN: False Negatives (object not detected but present in the ground-truth)
  172. """
  173. fp_sorted = []
  174. tp_sorted = []
  175. for key in sorted_keys:
  176. fp_sorted.append(dictionary[key] - true_p_bar[key])
  177. tp_sorted.append(true_p_bar[key])
  178. plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
  179. plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
  180. # add legend
  181. plt.legend(loc='lower right')
  182. """
  183. Write number on side of bar
  184. """
  185. fig = plt.gcf() # gcf - get current figure
  186. axes = plt.gca()
  187. r = fig.canvas.get_renderer()
  188. for i, val in enumerate(sorted_values):
  189. fp_val = fp_sorted[i]
  190. tp_val = tp_sorted[i]
  191. fp_str_val = " " + str(fp_val)
  192. tp_str_val = fp_str_val + " " + str(tp_val)
  193. # trick to paint multicolor with offset:
  194. # first paint everything and then repaint the first number
  195. t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
  196. plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
  197. if i == (len(sorted_values)-1): # largest bar
  198. adjust_axes(r, t, fig, axes)
  199. else:
  200. plt.barh(range(n_classes), sorted_values, color=plot_color)
  201. """
  202. Write number on side of bar
  203. """
  204. fig = plt.gcf() # gcf - get current figure
  205. axes = plt.gca()
  206. r = fig.canvas.get_renderer()
  207. for i, val in enumerate(sorted_values):
  208. str_val = " " + str(val) # add a space before
  209. if val < 1.0:
  210. str_val = " {0:.2f}".format(val)
  211. t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
  212. # re-set axes to show number inside the figure
  213. if i == (len(sorted_values)-1): # largest bar
  214. adjust_axes(r, t, fig, axes)
  215. # set window title
  216. fig.canvas.set_window_title(window_title)
  217. # write classes in y axis
  218. tick_font_size = 12
  219. plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
  220. """
  221. Re-scale height accordingly
  222. """
  223. init_height = fig.get_figheight()
  224. # comput the matrix height in points and inches
  225. dpi = fig.dpi
  226. height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
  227. height_in = height_pt / dpi
  228. # compute the required figure height
  229. top_margin = 0.15 # in percentage of the figure height
  230. bottom_margin = 0.05 # in percentage of the figure height
  231. figure_height = height_in / (1 - top_margin - bottom_margin)
  232. # set new height
  233. if figure_height > init_height:
  234. fig.set_figheight(figure_height)
  235. # set plot title
  236. plt.title(plot_title, fontsize=14)
  237. # set axis titles
  238. # plt.xlabel('classes')
  239. plt.xlabel(x_label, fontsize='large')
  240. # adjust size of window
  241. fig.tight_layout()
  242. # save the plot
  243. fig.savefig(output_path)
  244. # show image
  245. if to_show:
  246. plt.show()
  247. # close the plot
  248. plt.close()
  249. def get_map(MINOVERLAP, draw_plot, path = './map_out'):
  250. GT_PATH = os.path.join(path, 'ground-truth')
  251. DR_PATH = os.path.join(path, 'detection-results')
  252. IMG_PATH = os.path.join(path, 'images-optional')
  253. TEMP_FILES_PATH = os.path.join(path, '.temp_files')
  254. RESULTS_FILES_PATH = os.path.join(path, 'results')
  255. show_animation = True
  256. if os.path.exists(IMG_PATH):
  257. for dirpath, dirnames, files in os.walk(IMG_PATH):
  258. if not files:
  259. show_animation = False
  260. else:
  261. show_animation = False
  262. if not os.path.exists(TEMP_FILES_PATH):
  263. os.makedirs(TEMP_FILES_PATH)
  264. if os.path.exists(RESULTS_FILES_PATH):
  265. shutil.rmtree(RESULTS_FILES_PATH)
  266. if draw_plot:
  267. os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
  268. os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
  269. os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
  270. os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
  271. if show_animation:
  272. os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
  273. ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
  274. if len(ground_truth_files_list) == 0:
  275. error("Error: No ground-truth files found!")
  276. ground_truth_files_list.sort()
  277. gt_counter_per_class = {}
  278. counter_images_per_class = {}
  279. for txt_file in ground_truth_files_list:
  280. file_id = txt_file.split(".txt", 1)[0]
  281. file_id = os.path.basename(os.path.normpath(file_id))
  282. temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
  283. if not os.path.exists(temp_path):
  284. error_msg = "Error. File not found: {}\n".format(temp_path)
  285. error(error_msg)
  286. lines_list = file_lines_to_list(txt_file)
  287. bounding_boxes = []
  288. is_difficult = False
  289. already_seen_classes = []
  290. for line in lines_list:
  291. try:
  292. if "difficult" in line:
  293. class_name, left, top, right, bottom, _difficult = line.split()
  294. is_difficult = True
  295. else:
  296. class_name, left, top, right, bottom = line.split()
  297. except:
  298. if "difficult" in line:
  299. line_split = line.split()
  300. _difficult = line_split[-1]
  301. bottom = line_split[-2]
  302. right = line_split[-3]
  303. top = line_split[-4]
  304. left = line_split[-5]
  305. class_name = ""
  306. for name in line_split[:-5]:
  307. class_name += name + " "
  308. class_name = class_name[:-1]
  309. is_difficult = True
  310. else:
  311. line_split = line.split()
  312. bottom = line_split[-1]
  313. right = line_split[-2]
  314. top = line_split[-3]
  315. left = line_split[-4]
  316. class_name = ""
  317. for name in line_split[:-4]:
  318. class_name += name + " "
  319. class_name = class_name[:-1]
  320. bbox = left + " " + top + " " + right + " " + bottom
  321. if is_difficult:
  322. bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
  323. is_difficult = False
  324. else:
  325. bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
  326. if class_name in gt_counter_per_class:
  327. gt_counter_per_class[class_name] += 1
  328. else:
  329. gt_counter_per_class[class_name] = 1
  330. if class_name not in already_seen_classes:
  331. if class_name in counter_images_per_class:
  332. counter_images_per_class[class_name] += 1
  333. else:
  334. counter_images_per_class[class_name] = 1
  335. already_seen_classes.append(class_name)
  336. with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
  337. json.dump(bounding_boxes, outfile)
  338. gt_classes = list(gt_counter_per_class.keys())
  339. gt_classes = sorted(gt_classes)
  340. n_classes = len(gt_classes)
  341. dr_files_list = glob.glob(DR_PATH + '/*.txt')
  342. dr_files_list.sort()
  343. for class_index, class_name in enumerate(gt_classes):
  344. bounding_boxes = []
  345. for txt_file in dr_files_list:
  346. file_id = txt_file.split(".txt",1)[0]
  347. file_id = os.path.basename(os.path.normpath(file_id))
  348. temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
  349. if class_index == 0:
  350. if not os.path.exists(temp_path):
  351. error_msg = "Error. File not found: {}\n".format(temp_path)
  352. error(error_msg)
  353. lines = file_lines_to_list(txt_file)
  354. for line in lines:
  355. try:
  356. tmp_class_name, confidence, left, top, right, bottom = line.split()
  357. except:
  358. line_split = line.split()
  359. bottom = line_split[-1]
  360. right = line_split[-2]
  361. top = line_split[-3]
  362. left = line_split[-4]
  363. confidence = line_split[-5]
  364. tmp_class_name = ""
  365. for name in line_split[:-5]:
  366. tmp_class_name += name + " "
  367. tmp_class_name = tmp_class_name[:-1]
  368. if tmp_class_name == class_name:
  369. bbox = left + " " + top + " " + right + " " +bottom
  370. bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
  371. bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
  372. with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
  373. json.dump(bounding_boxes, outfile)
  374. sum_AP = 0.0
  375. ap_dictionary = {}
  376. lamr_dictionary = {}
  377. with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
  378. results_file.write("# AP and precision/recall per class\n")
  379. count_true_positives = {}
  380. for class_index, class_name in enumerate(gt_classes):
  381. count_true_positives[class_name] = 0
  382. dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
  383. dr_data = json.load(open(dr_file))
  384. nd = len(dr_data)
  385. tp = [0] * nd
  386. fp = [0] * nd
  387. score = [0] * nd
  388. score05_idx = 0
  389. for idx, detection in enumerate(dr_data):
  390. file_id = detection["file_id"]
  391. score[idx] = float(detection["confidence"])
  392. if score[idx] > 0.5:
  393. score05_idx = idx
  394. if show_animation:
  395. ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
  396. if len(ground_truth_img) == 0:
  397. error("Error. Image not found with id: " + file_id)
  398. elif len(ground_truth_img) > 1:
  399. error("Error. Multiple image with id: " + file_id)
  400. else:
  401. img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
  402. img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
  403. if os.path.isfile(img_cumulative_path):
  404. img_cumulative = cv2.imread(img_cumulative_path)
  405. else:
  406. img_cumulative = img.copy()
  407. bottom_border = 60
  408. BLACK = [0, 0, 0]
  409. img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
  410. gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
  411. ground_truth_data = json.load(open(gt_file))
  412. ovmax = -1
  413. gt_match = -1
  414. bb = [float(x) for x in detection["bbox"].split()]
  415. for obj in ground_truth_data:
  416. if obj["class_name"] == class_name:
  417. bbgt = [ float(x) for x in obj["bbox"].split() ]
  418. bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
  419. iw = bi[2] - bi[0] + 1
  420. ih = bi[3] - bi[1] + 1
  421. if iw > 0 and ih > 0:
  422. ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
  423. + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
  424. ov = iw * ih / ua
  425. if ov > ovmax:
  426. ovmax = ov
  427. gt_match = obj
  428. if show_animation:
  429. status = "NO MATCH FOUND!"
  430. min_overlap = MINOVERLAP
  431. if ovmax >= min_overlap:
  432. if "difficult" not in gt_match:
  433. if not bool(gt_match["used"]):
  434. tp[idx] = 1
  435. gt_match["used"] = True
  436. count_true_positives[class_name] += 1
  437. with open(gt_file, 'w') as f:
  438. f.write(json.dumps(ground_truth_data))
  439. if show_animation:
  440. status = "MATCH!"
  441. else:
  442. fp[idx] = 1
  443. if show_animation:
  444. status = "REPEATED MATCH!"
  445. else:
  446. fp[idx] = 1
  447. if ovmax > 0:
  448. status = "INSUFFICIENT OVERLAP"
  449. """
  450. Draw image to show animation
  451. """
  452. if show_animation:
  453. height, widht = img.shape[:2]
  454. white = (255,255,255)
  455. light_blue = (255,200,100)
  456. green = (0,255,0)
  457. light_red = (30,30,255)
  458. margin = 10
  459. # 1nd line
  460. v_pos = int(height - margin - (bottom_border / 2.0))
  461. text = "Image: " + ground_truth_img[0] + " "
  462. img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
  463. text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
  464. img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
  465. if ovmax != -1:
  466. color = light_red
  467. if status == "INSUFFICIENT OVERLAP":
  468. text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
  469. else:
  470. text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
  471. color = green
  472. img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
  473. # 2nd line
  474. v_pos += int(bottom_border / 2.0)
  475. rank_pos = str(idx+1)
  476. text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
  477. img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
  478. color = light_red
  479. if status == "MATCH!":
  480. color = green
  481. text = "Result: " + status + " "
  482. img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
  483. font = cv2.FONT_HERSHEY_SIMPLEX
  484. if ovmax > 0:
  485. bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
  486. cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
  487. cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
  488. cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
  489. bb = [int(i) for i in bb]
  490. cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
  491. cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
  492. cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
  493. cv2.imshow("Animation", img)
  494. cv2.waitKey(20)
  495. output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
  496. cv2.imwrite(output_img_path, img)
  497. cv2.imwrite(img_cumulative_path, img_cumulative)
  498. cumsum = 0
  499. for idx, val in enumerate(fp):
  500. fp[idx] += cumsum
  501. cumsum += val
  502. cumsum = 0
  503. for idx, val in enumerate(tp):
  504. tp[idx] += cumsum
  505. cumsum += val
  506. rec = tp[:]
  507. for idx, val in enumerate(tp):
  508. rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
  509. prec = tp[:]
  510. for idx, val in enumerate(tp):
  511. prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
  512. ap, mrec, mprec = voc_ap(rec[:], prec[:])
  513. F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
  514. sum_AP += ap
  515. text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
  516. if len(prec)>0:
  517. F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
  518. Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
  519. Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
  520. else:
  521. F1_text = "0.00" + " = " + class_name + " F1 "
  522. Recall_text = "0.00%" + " = " + class_name + " Recall "
  523. Precision_text = "0.00%" + " = " + class_name + " Precision "
  524. rounded_prec = [ '%.2f' % elem for elem in prec ]
  525. rounded_rec = [ '%.2f' % elem for elem in rec ]
  526. results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
  527. if len(prec)>0:
  528. print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
  529. + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
  530. else:
  531. print(text + "\t||\tscore_threhold=0.5 : F1=0.00% ; Recall=0.00% ; Precision=0.00%")
  532. ap_dictionary[class_name] = ap
  533. n_images = counter_images_per_class[class_name]
  534. lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
  535. lamr_dictionary[class_name] = lamr
  536. if draw_plot:
  537. plt.plot(rec, prec, '-o')
  538. area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
  539. area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
  540. plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
  541. fig = plt.gcf()
  542. fig.canvas.set_window_title('AP ' + class_name)
  543. plt.title('class: ' + text)
  544. plt.xlabel('Recall')
  545. plt.ylabel('Precision')
  546. axes = plt.gca()
  547. axes.set_xlim([0.0,1.0])
  548. axes.set_ylim([0.0,1.05])
  549. fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
  550. plt.cla()
  551. plt.plot(score, F1, "-", color='orangered')
  552. plt.title('class: ' + F1_text + "\nscore_threhold=0.5")
  553. plt.xlabel('Score_Threhold')
  554. plt.ylabel('F1')
  555. axes = plt.gca()
  556. axes.set_xlim([0.0,1.0])
  557. axes.set_ylim([0.0,1.05])
  558. fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
  559. plt.cla()
  560. plt.plot(score, rec, "-H", color='gold')
  561. plt.title('class: ' + Recall_text + "\nscore_threhold=0.5")
  562. plt.xlabel('Score_Threhold')
  563. plt.ylabel('Recall')
  564. axes = plt.gca()
  565. axes.set_xlim([0.0,1.0])
  566. axes.set_ylim([0.0,1.05])
  567. fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
  568. plt.cla()
  569. plt.plot(score, prec, "-s", color='palevioletred')
  570. plt.title('class: ' + Precision_text + "\nscore_threhold=0.5")
  571. plt.xlabel('Score_Threhold')
  572. plt.ylabel('Precision')
  573. axes = plt.gca()
  574. axes.set_xlim([0.0,1.0])
  575. axes.set_ylim([0.0,1.05])
  576. fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
  577. plt.cla()
  578. if show_animation:
  579. cv2.destroyAllWindows()
  580. results_file.write("\n# mAP of all classes\n")
  581. mAP = sum_AP / n_classes
  582. text = "mAP = {0:.2f}%".format(mAP*100)
  583. results_file.write(text + "\n")
  584. print(text)
  585. shutil.rmtree(TEMP_FILES_PATH)
  586. """
  587. Count total of detection-results
  588. """
  589. det_counter_per_class = {}
  590. for txt_file in dr_files_list:
  591. lines_list = file_lines_to_list(txt_file)
  592. for line in lines_list:
  593. class_name = line.split()[0]
  594. if class_name in det_counter_per_class:
  595. det_counter_per_class[class_name] += 1
  596. else:
  597. det_counter_per_class[class_name] = 1
  598. dr_classes = list(det_counter_per_class.keys())
  599. """
  600. Write number of ground-truth objects per class to results.txt
  601. """
  602. with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
  603. results_file.write("\n# Number of ground-truth objects per class\n")
  604. for class_name in sorted(gt_counter_per_class):
  605. results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
  606. """
  607. Finish counting true positives
  608. """
  609. for class_name in dr_classes:
  610. if class_name not in gt_classes:
  611. count_true_positives[class_name] = 0
  612. """
  613. Write number of detected objects per class to results.txt
  614. """
  615. with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
  616. results_file.write("\n# Number of detected objects per class\n")
  617. for class_name in sorted(dr_classes):
  618. n_det = det_counter_per_class[class_name]
  619. text = class_name + ": " + str(n_det)
  620. text += " (tp:" + str(count_true_positives[class_name]) + ""
  621. text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
  622. results_file.write(text)
  623. """
  624. Plot the total number of occurences of each class in the ground-truth
  625. """
  626. if draw_plot:
  627. window_title = "ground-truth-info"
  628. plot_title = "ground-truth\n"
  629. plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
  630. x_label = "Number of objects per class"
  631. output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
  632. to_show = False
  633. plot_color = 'forestgreen'
  634. draw_plot_func(
  635. gt_counter_per_class,
  636. n_classes,
  637. window_title,
  638. plot_title,
  639. x_label,
  640. output_path,
  641. to_show,
  642. plot_color,
  643. '',
  644. )
  645. # """
  646. # Plot the total number of occurences of each class in the "detection-results" folder
  647. # """
  648. # if draw_plot:
  649. # window_title = "detection-results-info"
  650. # # Plot title
  651. # plot_title = "detection-results\n"
  652. # plot_title += "(" + str(len(dr_files_list)) + " files and "
  653. # count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
  654. # plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
  655. # # end Plot title
  656. # x_label = "Number of objects per class"
  657. # output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
  658. # to_show = False
  659. # plot_color = 'forestgreen'
  660. # true_p_bar = count_true_positives
  661. # draw_plot_func(
  662. # det_counter_per_class,
  663. # len(det_counter_per_class),
  664. # window_title,
  665. # plot_title,
  666. # x_label,
  667. # output_path,
  668. # to_show,
  669. # plot_color,
  670. # true_p_bar
  671. # )
  672. """
  673. Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
  674. """
  675. if draw_plot:
  676. window_title = "lamr"
  677. plot_title = "log-average miss rate"
  678. x_label = "log-average miss rate"
  679. output_path = RESULTS_FILES_PATH + "/lamr.png"
  680. to_show = False
  681. plot_color = 'royalblue'
  682. draw_plot_func(
  683. lamr_dictionary,
  684. n_classes,
  685. window_title,
  686. plot_title,
  687. x_label,
  688. output_path,
  689. to_show,
  690. plot_color,
  691. ""
  692. )
  693. """
  694. Draw mAP plot (Show AP's of all classes in decreasing order)
  695. """
  696. if draw_plot:
  697. window_title = "mAP"
  698. plot_title = "mAP = {0:.2f}%".format(mAP*100)
  699. x_label = "Average Precision"
  700. output_path = RESULTS_FILES_PATH + "/mAP.png"
  701. to_show = True
  702. plot_color = 'royalblue'
  703. draw_plot_func(
  704. ap_dictionary,
  705. n_classes,
  706. window_title,
  707. plot_title,
  708. x_label,
  709. output_path,
  710. to_show,
  711. plot_color,
  712. ""
  713. )
  714. def preprocess_gt(gt_path, class_names):
  715. image_ids = os.listdir(gt_path)
  716. results = {}
  717. images = []
  718. bboxes = []
  719. for i, image_id in enumerate(image_ids):
  720. lines_list = file_lines_to_list(os.path.join(gt_path, image_id))
  721. boxes_per_image = []
  722. image = {}
  723. image_id = os.path.splitext(image_id)[0]
  724. image['file_name'] = image_id + '.jpg'
  725. image['width'] = 1
  726. image['height'] = 1
  727. #-----------------------------------------------------------------#
  728. # 感谢 多学学英语吧 的提醒
  729. # 解决了'Results do not correspond to current coco set'问题
  730. #-----------------------------------------------------------------#
  731. image['id'] = str(image_id)
  732. for line in lines_list:
  733. difficult = 0
  734. if "difficult" in line:
  735. line_split = line.split()
  736. left, top, right, bottom, _difficult = line_split[-5:]
  737. class_name = ""
  738. for name in line_split[:-5]:
  739. class_name += name + " "
  740. class_name = class_name[:-1]
  741. difficult = 1
  742. else:
  743. line_split = line.split()
  744. left, top, right, bottom = line_split[-4:]
  745. class_name = ""
  746. for name in line_split[:-4]:
  747. class_name += name + " "
  748. class_name = class_name[:-1]
  749. left, top, right, bottom = float(left), float(top), float(right), float(bottom)
  750. cls_id = class_names.index(class_name) + 1
  751. bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
  752. boxes_per_image.append(bbox)
  753. images.append(image)
  754. bboxes.extend(boxes_per_image)
  755. results['images'] = images
  756. categories = []
  757. for i, cls in enumerate(class_names):
  758. category = {}
  759. category['supercategory'] = cls
  760. category['name'] = cls
  761. category['id'] = i + 1
  762. categories.append(category)
  763. results['categories'] = categories
  764. annotations = []
  765. for i, box in enumerate(bboxes):
  766. annotation = {}
  767. annotation['area'] = box[-1]
  768. annotation['category_id'] = box[-2]
  769. annotation['image_id'] = box[-3]
  770. annotation['iscrowd'] = box[-4]
  771. annotation['bbox'] = box[:4]
  772. annotation['id'] = i
  773. annotations.append(annotation)
  774. results['annotations'] = annotations
  775. return results
  776. def preprocess_dr(dr_path, class_names):
  777. image_ids = os.listdir(dr_path)
  778. results = []
  779. for image_id in image_ids:
  780. lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
  781. image_id = os.path.splitext(image_id)[0]
  782. for line in lines_list:
  783. line_split = line.split()
  784. confidence, left, top, right, bottom = line_split[-5:]
  785. class_name = ""
  786. for name in line_split[:-5]:
  787. class_name += name + " "
  788. class_name = class_name[:-1]
  789. left, top, right, bottom = float(left), float(top), float(right), float(bottom)
  790. result = {}
  791. result["image_id"] = str(image_id)
  792. result["category_id"] = class_names.index(class_name) + 1
  793. result["bbox"] = [left, top, right - left, bottom - top]
  794. result["score"] = float(confidence)
  795. results.append(result)
  796. return results
  797. def get_coco_map(class_names, path):
  798. from pycocotools.coco import COCO
  799. from pycocotools.cocoeval import COCOeval
  800. GT_PATH = os.path.join(path, 'ground-truth')
  801. DR_PATH = os.path.join(path, 'detection-results')
  802. COCO_PATH = os.path.join(path, 'coco_eval')
  803. if not os.path.exists(COCO_PATH):
  804. os.makedirs(COCO_PATH)
  805. GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
  806. DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
  807. with open(GT_JSON_PATH, "w") as f:
  808. results_gt = preprocess_gt(GT_PATH, class_names)
  809. json.dump(results_gt, f, indent=4)
  810. with open(DR_JSON_PATH, "w") as f:
  811. results_dr = preprocess_dr(DR_PATH, class_names)
  812. json.dump(results_dr, f, indent=4)
  813. cocoGt = COCO(GT_JSON_PATH)
  814. cocoDt = cocoGt.loadRes(DR_JSON_PATH)
  815. cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
  816. cocoEval.evaluate()
  817. cocoEval.accumulate()
  818. cocoEval.summarize()