123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import os
- import matplotlib
- import torch
- matplotlib.use('Agg')
- from matplotlib import pyplot as plt
- import scipy.signal
- import shutil
- import numpy as np
- from PIL import Image
- from torch.utils.tensorboard import SummaryWriter
- from tqdm import tqdm
- from .utils import cvtColor, resize_image, preprocess_input, get_new_img_size
- from .utils_bbox import DecodeBox
- from .utils_map import get_coco_map, get_map
class LossHistory():
    """Track per-epoch train/validation losses and export them.

    Each call to ``append_loss`` appends the values to ``epoch_loss.txt`` and
    ``epoch_val_loss.txt`` in ``log_dir``, logs them as TensorBoard scalars
    through a ``SummaryWriter`` and redraws ``epoch_loss.png``.

    Args:
        log_dir: Output directory; created if it does not exist yet.
        model: Currently unused. Kept for interface compatibility; it was only
            needed by a disabled TensorBoard graph export.
        input_shape: Currently unused, for the same reason.
    """

    def __init__(self, log_dir, model, input_shape):
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: re-using an existing directory must not raise FileExistsError.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses and refresh the text logs, TensorBoard and plot.

        Args:
            epoch: Step value used for the TensorBoard scalars.
            loss: Training loss for the epoch.
            val_loss: Validation loss for the epoch.
        """
        # The directory may have been removed since __init__; recreate it (race-free).
        os.makedirs(self.log_dir, exist_ok=True)
        self.losses.append(loss)
        self.val_loss.append(val_loss)
        self._append_line("epoch_loss.txt", loss)
        self._append_line("epoch_val_loss.txt", val_loss)
        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def _append_line(self, filename, value):
        """Append ``str(value)`` followed by a newline to ``log_dir/filename``."""
        with open(os.path.join(self.log_dir, filename), 'a') as f:
            f.write(str(value) + "\n")

    def loss_plot(self):
        """Redraw ``epoch_loss.png`` from the recorded loss histories.

        The raw curves are always drawn; dashed Savitzky-Golay smoothed curves
        are added once there are enough epochs for the filter window.
        """
        iters = range(len(self.losses))
        fig = plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

            smooth_train = self._smooth(self.losses)
            if smooth_train is not None:
                plt.plot(iters, smooth_train, 'green', linestyle='--', linewidth=2,
                         label='smooth train loss')
            smooth_val = self._smooth(self.val_loss)
            if smooth_val is not None:
                plt.plot(iters, smooth_val, '#8B4513', linestyle='--', linewidth=2,
                         label='smooth val loss')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")
            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Close only our own figure, even if saving failed, so figures never pile up.
            plt.close(fig)

    @staticmethod
    def _smooth(values):
        """Return a Savitzky-Golay smoothed copy of *values*, or None if not possible.

        A cubic fit is used over a window of 5 points when there are fewer than
        25 values, otherwise 15 points. ``scipy.signal.savgol_filter`` in its
        default ``'interp'`` mode rejects series shorter than the window, so
        such series return None instead of raising.
        """
        window = 5 if len(values) < 25 else 15
        if len(values) < window:
            return None
        try:
            return scipy.signal.savgol_filter(values, window, 3)
        except (ValueError, TypeError):
            # Best effort: a plotting aid must never interrupt training.
            return None
class EvalCallback():
    """Periodically compute detection mAP on a validation set and log it.

    Every ``period`` epochs (when ``eval_flag`` is true), ``on_epoch_end`` runs
    the network on every entry of ``val_lines``. It writes per-image detection
    and ground-truth text files under ``map_out_path``, computes mAP, appends
    it to ``log_dir/epoch_map.txt`` and redraws ``log_dir/epoch_map.png``. The
    temporary ``map_out_path`` directory is removed afterwards.

    Each ``val_lines`` entry is parsed as
    ``"<image path> x1,y1,x2,y2,cls x1,y1,x2,y2,cls ..."``: a whitespace-separated
    image path followed by comma-separated integer boxes unpacked as
    ``left, top, right, bottom, class_index``.

    Args:
        net: Detector; called as ``net(images)`` and expected to return
            ``(roi_cls_locs, roi_scores, rois, _)``.
        input_shape: Stored only; the per-image network size is recomputed
            in ``get_map_txt``.
        class_names: Class names indexed by class id.
        num_classes: Number of classes (sizes the decoding std vector).
        val_lines: Annotation lines (format above).
        log_dir: Directory for ``epoch_map.txt`` / ``epoch_map.png``.
        cuda: Move tensors to the GPU when true.
        map_out_path: Temporary working directory; deleted after each evaluation.
        max_boxes: Maximum detections written per image (highest scores first).
        confidence: Score threshold passed to the box decoder.
        nms_iou: NMS IoU threshold passed to the box decoder.
        letterbox_image: Stored only; not read by this class.
        MINOVERLAP: IoU threshold for the fallback ``get_map`` evaluation.
        eval_flag: Disable evaluation entirely when false.
        period: Evaluate only when ``epoch % period == 0``.
    """

    def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5,
                 letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
        super().__init__()

        self.net = net
        self.input_shape = input_shape
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        # Shape (1, 4 * (num_classes + 1)): the (0.1, 0.1, 0.2, 0.2) group repeated
        # once per class plus one extra (presumably background — confirm in DecodeBox).
        self.std = torch.Tensor([0.1, 0.1, 0.2, 0.2]).repeat(self.num_classes + 1)[None]
        if self.cuda:
            self.std = self.std.cuda()
        self.bbox_util = DecodeBox(self.std, self.num_classes)
        # The mAP curve starts at (epoch 0, mAP 0).
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            # log_dir may not have been created yet by another callback.
            os.makedirs(self.log_dir, exist_ok=True)
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def _detect(self, image):
        """Run the network on *image* and return ``(boxes, scores, labels)``.

        At most ``max_boxes`` detections are returned, highest score first;
        each box row is unpacked by the caller as ``(top, left, bottom, right)``.
        Three empty tuples are returned when nothing is detected.
        """
        # (height, width) of the original image, taken before colour conversion.
        image_shape = np.array(np.shape(image)[0:2])
        input_shape = get_new_img_size(image_shape[0], image_shape[1])
        # Convert to RGB so grayscale (or other-mode) images don't break prediction;
        # only RGB input is supported.
        image = cvtColor(image)
        # Resize to the network input size (original note: short side resized to 600).
        image_data = resize_image(image, [input_shape[1], input_shape[0]])
        # Move the channel axis first (transpose 2,0,1) and add a leading batch dimension.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        # NOTE(review): assumes the caller has already put self.net in eval mode.
        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            roi_cls_locs, roi_scores, rois, _ = self.net(images)
            # Decode the proposals with the classifier's regression output into final boxes.
            results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                             nms_iou=self.nms_iou, confidence=self.confidence)

        if len(results[0]) == 0:
            return (), (), ()
        detections = results[0]
        # Columns as used here: [:4] box, [4] score, [5] class index.
        labels = np.array(detections[:, 5], dtype='int32')
        scores = detections[:, 4]
        boxes = detections[:, :4]
        keep = np.argsort(scores)[::-1][:self.max_boxes]
        return boxes[keep], scores[keep], labels[keep]

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Detect objects in one image and write its detection-results file.

        Writes ``map_out_path/detection-results/<image_id>.txt`` with one line
        ``<class> <score> <left> <top> <right> <bottom>`` per kept detection.
        The file is always created (empty when nothing is detected) and is
        always closed, even on the no-detection path.

        Args:
            image_id: File stem for the output file.
            image: Image to run detection on (converted to RGB internally).
            class_names: Only detections whose class name is in this
                collection are written.
            map_out_path: Root of the temporary mAP working directory.
        """
        boxes, scores, labels = self._detect(image)
        out_path = os.path.join(map_out_path, "detection-results", image_id + ".txt")
        with open(out_path, "w") as f:
            for box, conf, label in zip(boxes, scores, labels):
                predicted_class = self.class_names[int(label)]
                if predicted_class not in class_names:
                    continue
                top, left, bottom, right = box
                score = str(conf)
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)),
                                                 str(int(right)), str(int(bottom))))
        return

    @staticmethod
    def _parse_annotation(annotation_line):
        """Split an annotation line into ``(image_path, image_id, gt_boxes)``.

        ``image_id`` is the file name without its final extension, so
        ``a/frame.001.jpg`` gives ``frame.001`` (splitting on the first dot
        would make distinct images collide). ``gt_boxes`` is a list of int
        tuples ``(left, top, right, bottom, class_index)``.
        """
        tokens = annotation_line.split()
        image_path = tokens[0]
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        gt_boxes = [tuple(map(int, box.split(','))) for box in tokens[1:]]
        return image_path, image_id, gt_boxes

    def _write_ground_truth(self, path, gt_boxes):
        """Write one ``<class> <left> <top> <right> <bottom>`` line per GT box to *path*."""
        with open(path, "w") as new_f:
            for left, top, right, bottom, obj in gt_boxes:
                obj_name = self.class_names[obj]
                new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

    def _compute_map(self):
        """Compute mAP from the files under ``map_out_path``.

        Uses element 1 of ``get_coco_map``'s result; if that raises any ordinary
        exception, falls back to ``get_map`` at IoU ``MINOVERLAP``.
        """
        try:
            return get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
        except Exception:
            return get_map(self.MINOVERLAP, False, path=self.map_out_path)

    def _plot_map(self):
        """Redraw ``log_dir/epoch_map.png`` from ``self.epoches`` / ``self.maps``."""
        fig = plt.figure()
        try:
            # NOTE(review): the label says 'train map' but the values come from val_lines.
            plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Map %s' % str(self.MINOVERLAP))
            plt.title('A Map Curve')
            plt.legend(loc="upper right")
            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        finally:
            plt.close(fig)

    def on_epoch_end(self, epoch):
        """Evaluate and log mAP when ``eval_flag`` is set and ``epoch % period == 0``.

        Appends the result to ``self.maps``/``self.epoches`` and to
        ``log_dir/epoch_map.txt``, and redraws ``log_dir/epoch_map.png``. The
        temporary ``map_out_path`` directory is deleted afterwards, including
        when evaluation fails part-way (the exception still propagates).
        """
        if epoch % self.period != 0 or not self.eval_flag:
            return

        gt_dir = os.path.join(self.map_out_path, "ground-truth")
        os.makedirs(gt_dir, exist_ok=True)
        os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)
        try:
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                image_path, image_id, gt_boxes = self._parse_annotation(annotation_line)
                # Close each image file as soon as its detections are written.
                with Image.open(image_path) as image:
                    self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
                self._write_ground_truth(os.path.join(gt_dir, image_id + ".txt"), gt_boxes)

            print("Calculate Map.")
            temp_map = self._compute_map()
            self.maps.append(temp_map)
            self.epoches.append(epoch)
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(temp_map))
                f.write("\n")
            self._plot_map()
            print("Get map done.")
        finally:
            shutil.rmtree(self.map_out_path, ignore_errors=True)
|