# anchors.py (2.8 KB)
  1. import numpy as np
#--------------------------------------------#
#   Generate the base anchor (prior) boxes
#--------------------------------------------#
  5. def generate_anchor_base(base_size = 16, ratios = [0.5, 1, 2], anchor_scales = [8, 16, 32]):
  6. anchor_base = np.zeros((len(ratios) * len(anchor_scales), 4), dtype = np.float32)
  7. for i in range(len(ratios)):
  8. for j in range(len(anchor_scales)):
  9. h = base_size * anchor_scales[j] * np.sqrt(ratios[i])
  10. w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i])
  11. index = i * len(anchor_scales) + j
  12. anchor_base[index, 0] = - h / 2.
  13. anchor_base[index, 1] = - w / 2.
  14. anchor_base[index, 2] = h / 2.
  15. anchor_base[index, 3] = w / 2.
  16. return anchor_base
#--------------------------------------------#
#   Expand the base anchors so that every
#   feature-map point gets its own copy
#--------------------------------------------#
  20. def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width):
  21. #---------------------------------#
  22. # 计算网格中心点
  23. #---------------------------------#
  24. shift_x = np.arange(0, width * feat_stride, feat_stride)
  25. shift_y = np.arange(0, height * feat_stride, feat_stride)
  26. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  27. shift = np.stack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel(),), axis=1)
  28. #---------------------------------#
  29. # 每个网格点上的9个先验框
  30. #---------------------------------#
  31. A = anchor_base.shape[0]
  32. K = shift.shape[0]
  33. anchor = anchor_base.reshape((1, A, 4)) + shift.reshape((K, 1, 4))
  34. #---------------------------------#
  35. # 所有的先验框
  36. #---------------------------------#
  37. anchor = anchor.reshape((K * A, 4)).astype(np.float32)
  38. return anchor
  39. if __name__ == "__main__":
  40. import matplotlib.pyplot as plt
  41. nine_anchors = generate_anchor_base()
  42. print(nine_anchors)
  43. height, width, feat_stride = 38,38,16
  44. anchors_all = _enumerate_shifted_anchor(nine_anchors, feat_stride, height, width)
  45. print(np.shape(anchors_all))
  46. fig = plt.figure()
  47. ax = fig.add_subplot(111)
  48. plt.ylim(-300,900)
  49. plt.xlim(-300,900)
  50. shift_x = np.arange(0, width * feat_stride, feat_stride)
  51. shift_y = np.arange(0, height * feat_stride, feat_stride)
  52. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  53. plt.scatter(shift_x,shift_y)
  54. box_widths = anchors_all[:,2]-anchors_all[:,0]
  55. box_heights = anchors_all[:,3]-anchors_all[:,1]
  56. for i in [108, 109, 110, 111, 112, 113, 114, 115, 116]:
  57. rect = plt.Rectangle([anchors_all[i, 0],anchors_all[i, 1]],box_widths[i],box_heights[i],color="r",fill=False)
  58. ax.add_patch(rect)
  59. plt.show()