make_flip_image.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # 制作翻转的图片,同时创建它们的标签,用于检测图片是否翻转的4分类任务
  2. import os
  3. import cv2
  4. import tqdm
  5. import random
  6. import argparse
  7. from scipy import ndimage
  8. # -------------------------------------------------------------------------------------------------------------------- #
  9. # 设置
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--image_path', default=r'D:\dataset\classification\flip\image\000', type=str)
  12. parser.add_argument('--save_path', default=r'D:\dataset\classification\flip\image', type=str)
  13. parser.add_argument('--file_path', default=r'D:\dataset\classification\flip', type=str)
  14. parser.add_argument('--add0', default=True, type=bool, help='|增加色彩变换|')
  15. parser.add_argument('--add1', default=True, type=bool, help='|增加角度倾斜变换|')
  16. parser.add_argument('--divide', default=r'9,1', type=str)
  17. args = parser.parse_args()
  18. # -------------------------------------------------------------------------------------------------------------------- #
  19. # 程序
  20. def resize(image, max_h=1000): # 用于缩小图片大小,max_h为最大高度
  21. h, w, _ = image.shape
  22. h1 = max_h
  23. w1 = int(h1 / h * w)
  24. if h > h1:
  25. image = cv2.resize(image, (w1, h1))
  26. return image
  27. def left(image): # 逆时针转90度
  28. image = cv2.transpose(image)
  29. image = cv2.flip(image, 0)
  30. return image
  31. def right(image): # 顺时针转90度
  32. image = cv2.transpose(image)
  33. image = cv2.flip(image, 1)
  34. return image
  35. def flip(image): # 顺时针转180度
  36. image = cv2.flip(image, -1)
  37. return image
  38. def rotate(image):
  39. image = ndimage.rotate(image, random.randint(-2, 2)) # 逆时针旋转几度
  40. return image
  41. if __name__ == '__main__':
  42. if not os.path.exists(args.save_path + '/270'):
  43. os.makedirs(args.save_path + '/270')
  44. if not os.path.exists(args.save_path + '/090'):
  45. os.makedirs(args.save_path + '/090')
  46. if not os.path.exists(args.save_path + '/180'):
  47. os.makedirs(args.save_path + '/180')
  48. path_list = os.listdir(args.image_path)
  49. path_list = [f'{args.image_path}/{_}' for _ in path_list]
  50. A_list = []
  51. B_list = []
  52. C_list = []
  53. D_list = []
  54. for i, image_path in enumerate(tqdm.tqdm(path_list)):
  55. image = cv2.imread(image_path)
  56. image = resize(image)
  57. image_left = left(image)
  58. image_right = right(image)
  59. image_flip = flip(image)
  60. index = str(i).rjust(3, '0')
  61. save_left = args.save_path + f'/270/{index}_left.jpg'
  62. save_right = args.save_path + f'/090/{index}_right.jpg'
  63. save_flip = args.save_path + f'/180/{index}_flip.jpg'
  64. cv2.imwrite(save_left, image_left)
  65. cv2.imwrite(save_right, image_right)
  66. cv2.imwrite(save_flip, image_flip)
  67. A_list.append(image_path + ' 0\n')
  68. B_list.append(save_left + ' 3\n')
  69. C_list.append(save_right + ' 1\n')
  70. D_list.append(save_flip + ' 2\n')
  71. # 色彩变换
  72. if args.add0:
  73. A_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  74. B_rgb = cv2.cvtColor(image_left, cv2.COLOR_RGB2BGR)
  75. C_rgb = cv2.cvtColor(image_right, cv2.COLOR_RGB2BGR)
  76. D_rgb = cv2.cvtColor(image_flip, cv2.COLOR_RGB2BGR)
  77. A_rgb_path = image_path.split('.')[0] + '_bgr.jpg'
  78. B_rgb_path = args.save_path + f'/270/{index}_left_bgr.jpg'
  79. C_rgb_path = args.save_path + f'/090/{index}_right_bgr.jpg'
  80. D_rgb_path = args.save_path + f'/180/{index}_flip_bgr.jpg'
  81. cv2.imwrite(A_rgb_path, A_rgb)
  82. cv2.imwrite(B_rgb_path, B_rgb)
  83. cv2.imwrite(C_rgb_path, C_rgb)
  84. cv2.imwrite(D_rgb_path, D_rgb)
  85. A_list.append(A_rgb_path + ' 0\n')
  86. B_list.append(B_rgb_path + ' 3\n')
  87. C_list.append(C_rgb_path + ' 1\n')
  88. D_list.append(D_rgb_path + ' 2\n')
  89. # 角度变换
  90. if args.add1:
  91. A_rotate = rotate(image)
  92. B_rotate = rotate(image_left)
  93. C_rotate = rotate(image_right)
  94. D_rotate = rotate(image_flip)
  95. A_rotate_path = image_path.split('.')[0] + '_rotate.jpg'
  96. B_rotate_path = args.save_path + f'/270/{index}_left_rotate.jpg'
  97. C_rotate_path = args.save_path + f'/090/{index}_right_rotate.jpg'
  98. D_rotate_path = args.save_path + f'/180/{index}_flip_rotate.jpg'
  99. cv2.imwrite(A_rotate_path, A_rotate)
  100. cv2.imwrite(B_rotate_path, B_rotate)
  101. cv2.imwrite(C_rotate_path, C_rotate)
  102. cv2.imwrite(D_rotate_path, D_rotate)
  103. A_list.append(A_rotate_path + ' 0\n')
  104. B_list.append(B_rotate_path + ' 3\n')
  105. C_list.append(C_rotate_path + ' 1\n')
  106. D_list.append(D_rotate_path + ' 2\n')
  107. a, b = list(map(int, args.divide.split(',')))
  108. data_len = len(A_list)
  109. random.shuffle(A_list)
  110. random.shuffle(B_list)
  111. random.shuffle(C_list)
  112. random.shuffle(D_list)
  113. train_number = int(data_len * a / (a + b))
  114. val_number = int(data_len * b / (a + b))
  115. with open(args.file_path + '/train.txt', 'w', encoding='utf-8') as f:
  116. f.writelines(A_list[0:train_number])
  117. f.writelines(B_list[0:train_number])
  118. f.writelines(C_list[0:train_number])
  119. f.writelines(D_list[0:train_number])
  120. with open(args.file_path + '/val.txt', 'w', encoding='utf-8') as f:
  121. f.writelines(A_list[0:val_number])
  122. f.writelines(B_list[0:val_number])
  123. f.writelines(C_list[0:val_number])
  124. f.writelines(D_list[0:val_number])