frcnn_training.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import math
  2. from functools import partial
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from torch.nn import functional as F
  7. def bbox_iou(bbox_a, bbox_b):
  8. if bbox_a.shape[1] != 4 or bbox_b.shape[1] != 4:
  9. print(bbox_a, bbox_b)
  10. raise IndexError
  11. tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
  12. br = np.minimum(bbox_a[:, None, 2:], bbox_b[:, 2:])
  13. area_i = np.prod(br - tl, axis=2) * (tl < br).all(axis=2)
  14. area_a = np.prod(bbox_a[:, 2:] - bbox_a[:, :2], axis=1)
  15. area_b = np.prod(bbox_b[:, 2:] - bbox_b[:, :2], axis=1)
  16. return area_i / (area_a[:, None] + area_b - area_i)
  17. def bbox2loc(src_bbox, dst_bbox):
  18. width = src_bbox[:, 2] - src_bbox[:, 0]
  19. height = src_bbox[:, 3] - src_bbox[:, 1]
  20. ctr_x = src_bbox[:, 0] + 0.5 * width
  21. ctr_y = src_bbox[:, 1] + 0.5 * height
  22. base_width = dst_bbox[:, 2] - dst_bbox[:, 0]
  23. base_height = dst_bbox[:, 3] - dst_bbox[:, 1]
  24. base_ctr_x = dst_bbox[:, 0] + 0.5 * base_width
  25. base_ctr_y = dst_bbox[:, 1] + 0.5 * base_height
  26. eps = np.finfo(height.dtype).eps
  27. width = np.maximum(width, eps)
  28. height = np.maximum(height, eps)
  29. dx = (base_ctr_x - ctr_x) / width
  30. dy = (base_ctr_y - ctr_y) / height
  31. dw = np.log(base_width / width)
  32. dh = np.log(base_height / height)
  33. loc = np.vstack((dx, dy, dw, dh)).transpose()
  34. return loc
  35. class AnchorTargetCreator(object):
  36. def __init__(self, n_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3, pos_ratio=0.5):
  37. self.n_sample = n_sample
  38. self.pos_iou_thresh = pos_iou_thresh
  39. self.neg_iou_thresh = neg_iou_thresh
  40. self.pos_ratio = pos_ratio
  41. def __call__(self, bbox, anchor):
  42. argmax_ious, label = self._create_label(anchor, bbox)
  43. if (label > 0).any():
  44. loc = bbox2loc(anchor, bbox[argmax_ious])
  45. return loc, label
  46. else:
  47. return np.zeros_like(anchor), label
  48. def _calc_ious(self, anchor, bbox):
  49. #----------------------------------------------#
  50. # anchor和bbox的iou
  51. # 获得的ious的shape为[num_anchors, num_gt]
  52. #----------------------------------------------#
  53. ious = bbox_iou(anchor, bbox)
  54. if len(bbox)==0:
  55. return np.zeros(len(anchor), np.int32), np.zeros(len(anchor)), np.zeros(len(bbox))
  56. #---------------------------------------------------------#
  57. # 获得每一个先验框最对应的真实框 [num_anchors, ]
  58. #---------------------------------------------------------#
  59. argmax_ious = ious.argmax(axis=1)
  60. #---------------------------------------------------------#
  61. # 找出每一个先验框最对应的真实框的iou [num_anchors, ]
  62. #---------------------------------------------------------#
  63. max_ious = np.max(ious, axis=1)
  64. #---------------------------------------------------------#
  65. # 获得每一个真实框最对应的先验框 [num_gt, ]
  66. #---------------------------------------------------------#
  67. gt_argmax_ious = ious.argmax(axis=0)
  68. #---------------------------------------------------------#
  69. # 保证每一个真实框都存在对应的先验框
  70. #---------------------------------------------------------#
  71. for i in range(len(gt_argmax_ious)):
  72. argmax_ious[gt_argmax_ious[i]] = i
  73. return argmax_ious, max_ious, gt_argmax_ious
  74. def _create_label(self, anchor, bbox):
  75. # ------------------------------------------ #
  76. # 1是正样本,0是负样本,-1忽略
  77. # 初始化的时候全部设置为-1
  78. # ------------------------------------------ #
  79. label = np.empty((len(anchor),), dtype=np.int32)
  80. label.fill(-1)
  81. # ------------------------------------------------------------------------ #
  82. # argmax_ious为每个先验框对应的最大的真实框的序号 [num_anchors, ]
  83. # max_ious为每个真实框对应的最大的真实框的iou [num_anchors, ]
  84. # gt_argmax_ious为每一个真实框对应的最大的先验框的序号 [num_gt, ]
  85. # ------------------------------------------------------------------------ #
  86. argmax_ious, max_ious, gt_argmax_ious = self._calc_ious(anchor, bbox)
  87. # ----------------------------------------------------- #
  88. # 如果小于门限值则设置为负样本
  89. # 如果大于门限值则设置为正样本
  90. # 每个真实框至少对应一个先验框
  91. # ----------------------------------------------------- #
  92. label[max_ious < self.neg_iou_thresh] = 0
  93. label[max_ious >= self.pos_iou_thresh] = 1
  94. if len(gt_argmax_ious)>0:
  95. label[gt_argmax_ious] = 1
  96. # ----------------------------------------------------- #
  97. # 判断正样本数量是否大于128,如果大于则限制在128
  98. # ----------------------------------------------------- #
  99. n_pos = int(self.pos_ratio * self.n_sample)
  100. pos_index = np.where(label == 1)[0]
  101. if len(pos_index) > n_pos:
  102. disable_index = np.random.choice(pos_index, size=(len(pos_index) - n_pos), replace=False)
  103. label[disable_index] = -1
  104. # ----------------------------------------------------- #
  105. # 平衡正负样本,保持总数量为256
  106. # ----------------------------------------------------- #
  107. n_neg = self.n_sample - np.sum(label == 1)
  108. neg_index = np.where(label == 0)[0]
  109. if len(neg_index) > n_neg:
  110. disable_index = np.random.choice(neg_index, size=(len(neg_index) - n_neg), replace=False)
  111. label[disable_index] = -1
  112. return argmax_ious, label
  113. class ProposalTargetCreator(object):
  114. def __init__(self, n_sample=128, pos_ratio=0.5, pos_iou_thresh=0.5, neg_iou_thresh_high=0.5, neg_iou_thresh_low=0):
  115. self.n_sample = n_sample
  116. self.pos_ratio = pos_ratio
  117. self.pos_roi_per_image = np.round(self.n_sample * self.pos_ratio)
  118. self.pos_iou_thresh = pos_iou_thresh
  119. self.neg_iou_thresh_high = neg_iou_thresh_high
  120. self.neg_iou_thresh_low = neg_iou_thresh_low
  121. def __call__(self, roi, bbox, label, loc_normalize_std=(0.1, 0.1, 0.2, 0.2)):
  122. roi = np.concatenate((roi.detach().cpu().numpy(), bbox), axis=0)
  123. # ----------------------------------------------------- #
  124. # 计算建议框和真实框的重合程度
  125. # ----------------------------------------------------- #
  126. iou = bbox_iou(roi, bbox)
  127. if len(bbox)==0:
  128. gt_assignment = np.zeros(len(roi), np.int32)
  129. max_iou = np.zeros(len(roi))
  130. gt_roi_label = np.zeros(len(roi))
  131. else:
  132. #---------------------------------------------------------#
  133. # 获得每一个建议框最对应的真实框 [num_roi, ]
  134. #---------------------------------------------------------#
  135. gt_assignment = iou.argmax(axis=1)
  136. #---------------------------------------------------------#
  137. # 获得每一个建议框最对应的真实框的iou [num_roi, ]
  138. #---------------------------------------------------------#
  139. max_iou = iou.max(axis=1)
  140. #---------------------------------------------------------#
  141. # 真实框的标签要+1因为有背景的存在
  142. #---------------------------------------------------------#
  143. gt_roi_label = label[gt_assignment] + 1
  144. #----------------------------------------------------------------#
  145. # 满足建议框和真实框重合程度大于neg_iou_thresh_high的作为负样本
  146. # 将正样本的数量限制在self.pos_roi_per_image以内
  147. #----------------------------------------------------------------#
  148. pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]
  149. pos_roi_per_this_image = int(min(self.pos_roi_per_image, pos_index.size))
  150. if pos_index.size > 0:
  151. pos_index = np.random.choice(pos_index, size=pos_roi_per_this_image, replace=False)
  152. #-----------------------------------------------------------------------------------------------------#
  153. # 满足建议框和真实框重合程度小于neg_iou_thresh_high大于neg_iou_thresh_low作为负样本
  154. # 将正样本的数量和负样本的数量的总和固定成self.n_sample
  155. #-----------------------------------------------------------------------------------------------------#
  156. neg_index = np.where((max_iou < self.neg_iou_thresh_high) & (max_iou >= self.neg_iou_thresh_low))[0]
  157. neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image
  158. neg_roi_per_this_image = int(min(neg_roi_per_this_image, neg_index.size))
  159. if neg_index.size > 0:
  160. neg_index = np.random.choice(neg_index, size=neg_roi_per_this_image, replace=False)
  161. #---------------------------------------------------------#
  162. # sample_roi [n_sample, ]
  163. # gt_roi_loc [n_sample, 4]
  164. # gt_roi_label [n_sample, ]
  165. #---------------------------------------------------------#
  166. keep_index = np.append(pos_index, neg_index)
  167. sample_roi = roi[keep_index]
  168. if len(bbox)==0:
  169. return sample_roi, np.zeros_like(sample_roi), gt_roi_label[keep_index]
  170. gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]])
  171. gt_roi_loc = (gt_roi_loc / np.array(loc_normalize_std, np.float32))
  172. gt_roi_label = gt_roi_label[keep_index]
  173. gt_roi_label[pos_roi_per_this_image:] = 0
  174. return sample_roi, gt_roi_loc, gt_roi_label
  175. class FasterRCNNTrainer(nn.Module):
  176. def __init__(self, model_train, optimizer):
  177. super(FasterRCNNTrainer, self).__init__()
  178. self.model_train = model_train
  179. self.optimizer = optimizer
  180. self.rpn_sigma = 1
  181. self.roi_sigma = 1
  182. self.anchor_target_creator = AnchorTargetCreator()
  183. self.proposal_target_creator = ProposalTargetCreator()
  184. self.loc_normalize_std = [0.1, 0.1, 0.2, 0.2]
  185. def _fast_rcnn_loc_loss(self, pred_loc, gt_loc, gt_label, sigma):
  186. pred_loc = pred_loc[gt_label > 0]
  187. gt_loc = gt_loc[gt_label > 0]
  188. sigma_squared = sigma ** 2
  189. regression_diff = (gt_loc - pred_loc)
  190. regression_diff = regression_diff.abs().float()
  191. regression_loss = torch.where(
  192. regression_diff < (1. / sigma_squared),
  193. 0.5 * sigma_squared * regression_diff ** 2,
  194. regression_diff - 0.5 / sigma_squared
  195. )
  196. regression_loss = regression_loss.sum()
  197. num_pos = (gt_label > 0).sum().float()
  198. regression_loss /= torch.max(num_pos, torch.ones_like(num_pos))
  199. return regression_loss
  200. def forward(self, imgs, bboxes, labels, scale):
  201. n = imgs.shape[0]
  202. img_size = imgs.shape[2:]
  203. #-------------------------------#
  204. # 获取公用特征层
  205. #-------------------------------#
  206. base_feature = self.model_train(imgs, mode = 'extractor')
  207. # -------------------------------------------------- #
  208. # 利用rpn网络获得调整参数、得分、建议框、先验框
  209. # -------------------------------------------------- #
  210. rpn_locs, rpn_scores, rois, roi_indices, anchor = self.model_train(x = [base_feature, img_size], scale = scale, mode = 'rpn')
  211. rpn_loc_loss_all, rpn_cls_loss_all, roi_loc_loss_all, roi_cls_loss_all = 0, 0, 0, 0
  212. sample_rois, sample_indexes, gt_roi_locs, gt_roi_labels = [], [], [], []
  213. for i in range(n):
  214. bbox = bboxes[i]
  215. label = labels[i]
  216. rpn_loc = rpn_locs[i]
  217. rpn_score = rpn_scores[i]
  218. roi = rois[i]
  219. # -------------------------------------------------- #
  220. # 利用真实框和先验框获得建议框网络应该有的预测结果
  221. # 给每个先验框都打上标签
  222. # gt_rpn_loc [num_anchors, 4]
  223. # gt_rpn_label [num_anchors, ]
  224. # -------------------------------------------------- #
  225. gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(bbox, anchor[0].cpu().numpy())
  226. gt_rpn_loc = torch.Tensor(gt_rpn_loc).type_as(rpn_locs)
  227. gt_rpn_label = torch.Tensor(gt_rpn_label).type_as(rpn_locs).long()
  228. # -------------------------------------------------- #
  229. # 分别计算建议框网络的回归损失和分类损失
  230. # -------------------------------------------------- #
  231. rpn_loc_loss = self._fast_rcnn_loc_loss(rpn_loc, gt_rpn_loc, gt_rpn_label, self.rpn_sigma)
  232. rpn_cls_loss = F.cross_entropy(rpn_score, gt_rpn_label, ignore_index=-1)
  233. rpn_loc_loss_all += rpn_loc_loss
  234. rpn_cls_loss_all += rpn_cls_loss
  235. # ------------------------------------------------------ #
  236. # 利用真实框和建议框获得classifier网络应该有的预测结果
  237. # 获得三个变量,分别是sample_roi, gt_roi_loc, gt_roi_label
  238. # sample_roi [n_sample, ]
  239. # gt_roi_loc [n_sample, 4]
  240. # gt_roi_label [n_sample, ]
  241. # ------------------------------------------------------ #
  242. sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(roi, bbox, label, self.loc_normalize_std)
  243. sample_rois.append(torch.Tensor(sample_roi).type_as(rpn_locs))
  244. sample_indexes.append(torch.ones(len(sample_roi)).type_as(rpn_locs) * roi_indices[i][0])
  245. gt_roi_locs.append(torch.Tensor(gt_roi_loc).type_as(rpn_locs))
  246. gt_roi_labels.append(torch.Tensor(gt_roi_label).type_as(rpn_locs).long())
  247. sample_rois = torch.stack(sample_rois, dim=0)
  248. sample_indexes = torch.stack(sample_indexes, dim=0)
  249. roi_cls_locs, roi_scores = self.model_train([base_feature, sample_rois, sample_indexes, img_size], mode = 'head')
  250. for i in range(n):
  251. # ------------------------------------------------------ #
  252. # 根据建议框的种类,取出对应的回归预测结果
  253. # ------------------------------------------------------ #
  254. n_sample = roi_cls_locs.size()[1]
  255. roi_cls_loc = roi_cls_locs[i]
  256. roi_score = roi_scores[i]
  257. gt_roi_loc = gt_roi_locs[i]
  258. gt_roi_label = gt_roi_labels[i]
  259. roi_cls_loc = roi_cls_loc.view(n_sample, -1, 4)
  260. roi_loc = roi_cls_loc[torch.arange(0, n_sample), gt_roi_label]
  261. # -------------------------------------------------- #
  262. # 分别计算Classifier网络的回归损失和分类损失
  263. # -------------------------------------------------- #
  264. roi_loc_loss = self._fast_rcnn_loc_loss(roi_loc, gt_roi_loc, gt_roi_label.data, self.roi_sigma)
  265. roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label)
  266. roi_loc_loss_all += roi_loc_loss
  267. roi_cls_loss_all += roi_cls_loss
  268. losses = [rpn_loc_loss_all/n, rpn_cls_loss_all/n, roi_loc_loss_all/n, roi_cls_loss_all/n]
  269. losses = losses + [sum(losses)]
  270. return losses
  271. def train_step(self, encoder, imgs, bboxes, labels, scale, fp16=False, scaler=None):
  272. self.optimizer.zero_grad()
  273. embed_loss = encoder.get_embeder_loss()
  274. if not fp16:
  275. losses = self.forward(imgs, bboxes, labels, scale)
  276. losses[-1] += embed_loss
  277. losses[-1].backward()
  278. self.optimizer.step()
  279. else:
  280. from torch.cuda.amp import autocast
  281. with autocast():
  282. losses = self.forward(imgs, bboxes, labels, scale)
  283. losses[-1] += embed_loss
  284. #----------------------#
  285. # 反向传播
  286. #----------------------#
  287. scaler.scale(losses[-1]).backward()
  288. scaler.step(self.optimizer)
  289. scaler.update()
  290. return losses, embed_loss
  291. def weights_init(net, init_type='normal', init_gain=0.02):
  292. def init_func(m):
  293. classname = m.__class__.__name__
  294. if hasattr(m, 'weight') and classname.find('Conv') != -1:
  295. if init_type == 'normal':
  296. torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
  297. elif init_type == 'xavier':
  298. torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
  299. elif init_type == 'kaiming':
  300. torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
  301. elif init_type == 'orthogonal':
  302. torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
  303. else:
  304. raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
  305. elif classname.find('BatchNorm2d') != -1:
  306. torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
  307. torch.nn.init.constant_(m.bias.data, 0.0)
  308. print('initialize network with %s type' % init_type)
  309. net.apply(init_func)
  310. def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
  311. def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
  312. if iters <= warmup_total_iters:
  313. # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
  314. lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
  315. elif iters >= total_iters - no_aug_iter:
  316. lr = min_lr
  317. else:
  318. lr = min_lr + 0.5 * (lr - min_lr) * (
  319. 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
  320. )
  321. return lr
  322. def step_lr(lr, decay_rate, step_size, iters):
  323. if step_size < 1:
  324. raise ValueError("step_size must above 1.")
  325. n = iters // step_size
  326. out_lr = lr * decay_rate ** n
  327. return out_lr
  328. if lr_decay_type == "cos":
  329. warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
  330. warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
  331. no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
  332. func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
  333. else:
  334. decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
  335. step_size = total_iters / step_num
  336. func = partial(step_lr, lr, decay_rate, step_size)
  337. return func
  338. def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
  339. lr = lr_scheduler_func(epoch)
  340. for param_group in optimizer.param_groups:
  341. param_group['lr'] = lr