# callbacks.py (9.8 KB)
  1. import os
  2. import matplotlib
  3. import torch
  4. matplotlib.use('Agg')
  5. from matplotlib import pyplot as plt
  6. import scipy.signal
  7. import shutil
  8. import numpy as np
  9. from PIL import Image
  10. from torch.utils.tensorboard import SummaryWriter
  11. from tqdm import tqdm
  12. from .utils import cvtColor, resize_image, preprocess_input, get_new_img_size
  13. from .utils_bbox import DecodeBox
  14. from .utils_map import get_coco_map, get_map
  class LossHistory():
      """Record per-epoch training/validation losses.

      Each call to :meth:`append_loss` appends the values to
      ``epoch_loss.txt`` / ``epoch_val_loss.txt`` inside ``log_dir``, logs them
      as TensorBoard scalars and re-renders ``epoch_loss.png``.

      Attributes:
          log_dir: Directory receiving the text logs, the PNG plot and the
              TensorBoard event files.
          losses: Training losses in the order they were appended.
          val_loss: Validation losses in the order they were appended.
          writer: ``SummaryWriter`` opened on ``log_dir``.
      """

      def __init__(self, log_dir, model, input_shape):
          """Create the log directory and open a TensorBoard writer on it.

          Args:
              log_dir: Output directory; created when it does not exist yet.
              model: Not used by the active code (only by the disabled graph
                  export below); kept so existing callers keep working.
              input_shape: Not used by the active code, see ``model``.
          """
          self.log_dir = log_dir
          self.losses = []
          self.val_loss = []

          # exist_ok=True: re-using an existing log directory must not raise,
          # which matches the tolerant check already done in append_loss().
          os.makedirs(self.log_dir, exist_ok=True)
          self.writer = SummaryWriter(self.log_dir)
          # Graph export is currently disabled:
          # try:
          #     dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
          #     self.writer.add_graph(model, dummy_input)
          # except:
          #     pass

      def append_loss(self, epoch, loss, val_loss):
          """Record the losses of one epoch and refresh the loss plot.

          Args:
              epoch: Step value attached to the TensorBoard scalars.
              loss: Training loss of the epoch.
              val_loss: Validation loss of the epoch.
          """
          # The directory may have been removed since __init__; recreate it.
          os.makedirs(self.log_dir, exist_ok=True)

          self.losses.append(loss)
          self.val_loss.append(val_loss)

          with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
              f.write(str(loss))
              f.write("\n")
          with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
              f.write(str(val_loss))
              f.write("\n")

          self.writer.add_scalar('loss', loss, epoch)
          self.writer.add_scalar('val_loss', val_loss, epoch)
          self.loss_plot()

      def loss_plot(self):
          """Render all recorded losses to ``<log_dir>/epoch_loss.png``.

          The raw curves are always drawn. Savitzky-Golay smoothed curves
          (polynomial order 3, window 5 below 25 recorded epochs, 15 from then
          on) are added on a best-effort basis.
          """
          iters = range(len(self.losses))

          plt.figure()
          try:
              plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
              plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')

              num = 5 if len(self.losses) < 25 else 15
              try:
                  plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
                  plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
              except Exception:
                  # Smoothing is optional: savgol_filter raises ValueError
                  # while there are fewer points than the window length, and
                  # the raw curves alone are still worth saving. Catching
                  # Exception (not a bare except) lets KeyboardInterrupt and
                  # SystemExit propagate.
                  pass

              plt.grid(True)
              plt.xlabel('Epoch')
              plt.ylabel('Loss')
              plt.legend(loc="upper right")

              plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
          finally:
              # Always release the figure, even when plotting/saving failed,
              # so figures do not pile up over a long training run.
              plt.cla()
              plt.close("all")
  class EvalCallback():
      """Periodically evaluate the detector on validation images and log mAP.

      On every ``period``-th epoch, :meth:`on_epoch_end` writes one
      detection-results file and one ground-truth file per validation image
      below ``map_out_path``, computes the mAP from them, appends the value to
      ``<log_dir>/epoch_map.txt``, redraws ``<log_dir>/epoch_map.png`` and
      finally deletes ``map_out_path``.
      """

      def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
              map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
          """Store the evaluation settings and start the mAP history.

          Args:
              net: Model called with a batched image tensor; its first three
                  return values are passed to ``DecodeBox.forward``.
              input_shape: Stored only; not read by the code in this class.
              class_names: Sequence mapping a class index to its name.
              num_classes: Number of classes (used to size ``self.std``).
              val_lines: Annotation lines of the form
                  ``"<image path> l,t,r,b,cls l,t,r,b,cls ..."``.
              log_dir: Existing directory receiving ``epoch_map.txt`` and
                  ``epoch_map.png``.
              cuda: When true, tensors are moved to the GPU.
              map_out_path: Scratch directory for the per-image text files;
                  removed after every evaluation.
              max_boxes: Maximum number of detections kept per image
                  (highest confidence first).
              confidence: Confidence threshold passed to the box decoder.
              nms_iou: NMS IoU threshold passed to the box decoder.
              letterbox_image: Stored only; not read by the code in this class.
              MINOVERLAP: Overlap threshold passed to ``get_map`` and shown in
                  the plot's y-axis label.
              eval_flag: When false, :meth:`on_epoch_end` does nothing.
              period: Evaluate only on epochs divisible by this value.
          """
          super(EvalCallback, self).__init__()

          self.net = net
          self.input_shape = input_shape
          self.class_names = class_names
          self.num_classes = num_classes
          self.val_lines = val_lines
          self.log_dir = log_dir
          self.cuda = cuda
          self.map_out_path = map_out_path
          self.max_boxes = max_boxes
          self.confidence = confidence
          self.nms_iou = nms_iou
          self.letterbox_image = letterbox_image
          self.MINOVERLAP = MINOVERLAP
          self.eval_flag = eval_flag
          self.period = period

          # Shape (1, 4 * (num_classes + 1)); presumably the box-regression
          # scaling factors expected by DecodeBox -- confirm in utils_bbox.
          self.std = torch.Tensor([0.1, 0.1, 0.2, 0.2]).repeat(self.num_classes + 1)[None]
          if self.cuda:
              self.std = self.std.cuda()
          self.bbox_util = DecodeBox(self.std, self.num_classes)

          # The history starts with a mAP of 0 at epoch 0.
          self.maps = [0]
          self.epoches = [0]
          if self.eval_flag:
              with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                  f.write(str(0))
                  f.write("\n")

      def _detect(self, image):
          """Run the network on one image and return its decoded detections.

          Returns:
              ``results[0]`` of ``DecodeBox.forward``. As indexed by
              :meth:`get_map_txt`, columns 0-3 hold the box
              (top, left, bottom, right), column 4 the confidence and
              column 5 the class index. Empty when nothing was detected.
          """
          # Height and width of the input image.
          image_shape = np.array(np.shape(image)[0:2])
          # Network input size for this image (the original comment states
          # the shorter side is resized to 600).
          input_shape = get_new_img_size(image_shape[0], image_shape[1])
          # Convert to RGB so that e.g. grayscale images do not break
          # prediction; only RGB input is supported.
          image = cvtColor(image)
          image_data = resize_image(image, [input_shape[1], input_shape[0]])
          # HWC -> CHW, then add the batch dimension.
          image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

          with torch.no_grad():
              images = torch.from_numpy(image_data)
              if self.cuda:
                  images = images.cuda()

              roi_cls_locs, roi_scores, rois, _ = self.net(images)
              # Decode the proposals with the classifier output to obtain
              # the predicted boxes.
              results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                               nms_iou = self.nms_iou, confidence = self.confidence)
          return results[0]

      #---------------------------------------------------#
      #   Detect objects in one image
      #---------------------------------------------------#
      def get_map_txt(self, image_id, image, class_names, map_out_path):
          """Write the detections of one image to a results text file.

          The file ``<map_out_path>/detection-results/<image_id>.txt`` gets
          one ``"<class> <score> <left> <top> <right> <bottom>"`` line per
          kept detection. It is created (empty) even when nothing was
          detected.

          Args:
              image_id: File stem used for the results file.
              image: Image to run detection on.
              class_names: Detections whose class name is not in this
                  collection are skipped.
              map_out_path: Directory containing ``detection-results``.
          """
          detections = self._detect(image)

          lines = []
          if len(detections) > 0:
              top_label = np.array(detections[:, 5], dtype = 'int32')
              top_conf = detections[:, 4]
              top_boxes = detections[:, :4]

              # Keep at most max_boxes detections, highest confidence first.
              keep = np.argsort(top_conf)[::-1][:self.max_boxes]
              for box, conf, label in zip(top_boxes[keep], top_conf[keep], top_label[keep]):
                  predicted_class = self.class_names[int(label)]
                  if predicted_class not in class_names:
                      continue

                  top, left, bottom, right = box
                  score = str(conf)
                  lines.append("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

          # The context manager closes the file on every path; previously the
          # handle leaked whenever an image produced no detections.
          with open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w") as f:
              f.writelines(lines)
          return

      def on_epoch_end(self, epoch):
          """Evaluate mAP on the validation set if this epoch is due.

          Does nothing unless ``eval_flag`` is set and ``epoch`` is a multiple
          of ``period``.

          Args:
              epoch: Epoch number recorded on the x axis of the mAP curve.
          """
          if epoch % self.period != 0 or not self.eval_flag:
              return

          # Creating the two sub-directories also creates map_out_path.
          os.makedirs(os.path.join(self.map_out_path, "ground-truth"), exist_ok=True)
          os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)

          print("Get map.")
          for annotation_line in tqdm(self.val_lines):
              line = annotation_line.split()
              # NOTE(review): everything after the first '.' of the file name
              # is dropped, so "a.1.jpg" and "a.2.jpg" would share one id --
              # confirm the validation file names never contain extra dots.
              image_id = os.path.basename(line[0]).split('.')[0]

              # Parse the ground-truth boxes: left,top,right,bottom,class_index.
              gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

              # Write the detection-results file; the context manager closes
              # the image file once prediction is finished.
              with Image.open(line[0]) as image:
                  self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

              # Write the ground-truth file.
              with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                  for box in gt_boxes:
                      left, top, right, bottom, obj = box
                      obj_name = self.class_names[obj]
                      new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

          print("Calculate Map.")
          try:
              # Element 1 of the returned value is used as the mAP --
              # presumably the COCO AP@0.5 entry; confirm in utils_map.
              temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
          except Exception:
              # Fall back to the local mAP implementation when the COCO-style
              # evaluation fails. Exception (not a bare except) keeps
              # KeyboardInterrupt/SystemExit working during evaluation.
              temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
          self.maps.append(temp_map)
          self.epoches.append(epoch)

          with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
              f.write(str(temp_map))
              f.write("\n")

          plt.figure()
          try:
              plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

              plt.grid(True)
              plt.xlabel('Epoch')
              plt.ylabel('Map %s'%str(self.MINOVERLAP))
              plt.title('A Map Curve')
              plt.legend(loc="upper right")

              plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
          finally:
              # Release the figure even when plotting/saving failed.
              plt.cla()
              plt.close("all")

          print("Get map done.")
          shutil.rmtree(self.map_out_path)