12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- import numpy as np
#--------------------------------------------#
#   Generate the base (prior) anchor boxes
#--------------------------------------------#
def generate_anchor_base(base_size=16, ratios=(0.5, 1, 2), anchor_scales=(8, 16, 32)):
    """Build the base anchors, centred on the origin.

    One anchor is produced for every (ratio, scale) pair, in ratio-major
    order: row ``i * len(anchor_scales) + j`` uses ``ratios[i]`` and
    ``anchor_scales[j]``.  For each pair::

        h = base_size * scale * sqrt(ratio)
        w = base_size * scale * sqrt(1 / ratio)

    so ``h / w == ratio`` and ``h * w == (base_size * scale) ** 2``.

    Args:
        base_size: Length multiplied by every scale.
        ratios: Sequence of aspect ratios (height / width as computed above).
            Any sequence works; a tuple default avoids a shared mutable default.
        anchor_scales: Sequence of scale multipliers.

    Returns:
        ``np.float32`` array of shape ``(len(ratios) * len(anchor_scales), 4)``
        whose rows are ``[-h/2, -w/2, h/2, w/2]``.

    NOTE(review): columns 0/2 hold the *height* extent (a y-first layout),
    whereas ``_enumerate_shifted_anchor`` adds x-shifts to columns 0/2 and the
    ``__main__`` demo in this file measures width as column 2 - column 0.
    With the default symmetric ratios (0.5, 1, 2) the anchor *set* is the same
    either way, but for asymmetric ratios the effective aspect ratio ends up
    inverted.  Left as-is so anchor ordering (and any trained weights that
    depend on it) stays stable -- confirm before changing.
    """
    ratios = np.asarray(ratios, dtype=np.float64)
    scales = np.asarray(anchor_scales, dtype=np.float64)
    # Outer product over (ratio, scale): rows follow ratios, columns follow
    # scales, so ravel() reproduces the ratio-major row order described above.
    sides = base_size * scales[np.newaxis, :]
    h = (sides * np.sqrt(ratios[:, np.newaxis])).ravel()
    w = (sides * np.sqrt(1. / ratios[:, np.newaxis])).ravel()
    anchor_base = np.stack((-h / 2., -w / 2., h / 2., w / 2.), axis=1)
    return anchor_base.astype(np.float32)
- #--------------------------------------------#
- # 对基础先验框进行拓展对应到所有特征点上
- #--------------------------------------------#
- def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width):
- #---------------------------------#
- # 计算网格中心点
- #---------------------------------#
- shift_x = np.arange(0, width * feat_stride, feat_stride)
- shift_y = np.arange(0, height * feat_stride, feat_stride)
- shift_x, shift_y = np.meshgrid(shift_x, shift_y)
- shift = np.stack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel(),), axis=1)
- #---------------------------------#
- # 每个网格点上的9个先验框
- #---------------------------------#
- A = anchor_base.shape[0]
- K = shift.shape[0]
- anchor = anchor_base.reshape((1, A, 4)) + shift.reshape((K, 1, 4))
- #---------------------------------#
- # 所有的先验框
- #---------------------------------#
- anchor = anchor.reshape((K * A, 4)).astype(np.float32)
- return anchor
-
if __name__ == "__main__":
    # Visual sanity check: plot the grid points and the anchors of one point.
    import matplotlib.pyplot as plt

    nine_anchors = generate_anchor_base()
    print(nine_anchors)

    height, width, feat_stride = 38, 38, 16
    anchors_all = _enumerate_shifted_anchor(nine_anchors, feat_stride, height, width)
    print(np.shape(anchors_all))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.ylim(-300, 900)
    plt.xlim(-300, 900)

    # Scatter the grid points, built the same way _enumerate_shifted_anchor
    # builds them (multiples of feat_stride starting at 0).
    shift_x = np.arange(width) * feat_stride
    shift_y = np.arange(height) * feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    plt.scatter(shift_x, shift_y)

    # Columns are read here as (x1, y1, x2, y2), matching the shift order
    # used by _enumerate_shifted_anchor.
    box_widths = anchors_all[:, 2] - anchors_all[:, 0]
    box_heights = anchors_all[:, 3] - anchors_all[:, 1]

    # Anchors are stored point-major, so grid point p owns rows
    # [p * A, (p + 1) * A).  Point 12 is (x = 12 * feat_stride, y = 0);
    # with A = 9 base anchors that is rows 108..116.
    num_base = nine_anchors.shape[0]
    point = 12
    for i in range(point * num_base, (point + 1) * num_base):
        rect = plt.Rectangle([anchors_all[i, 0], anchors_all[i, 1]],
                             box_widths[i], box_heights[i],
                             color="r", fill=False)
        ax.add_patch(rect)
    plt.show()
|