123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # 制作翻转的图片,同时创建它们的标签,用于检测图片是否翻转的4分类任务
- import os
- import cv2
- import tqdm
- import random
- import argparse
- from scipy import ndimage
- # -------------------------------------------------------------------------------------------------------------------- #
# Settings
def _str2bool(value):
    """Convert a command-line token to a bool.

    argparse's ``type=bool`` applies ``bool()`` to the raw string, so any
    non-empty value -- including ``"False"`` -- is parsed as ``True``.
    This helper parses the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy: it was the only way to pass False before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
# Directory holding the upright source images (class 0).
parser.add_argument('--image_path', default=r'D:\dataset\classification\flip\image\000', type=str)
# Root directory under which the 090 / 180 / 270 output folders are created.
parser.add_argument('--save_path', default=r'D:\dataset\classification\flip\image', type=str)
# Directory that receives train.txt and val.txt.
parser.add_argument('--file_path', default=r'D:\dataset\classification\flip', type=str)
# Help text: "add colour transform".
parser.add_argument('--add0', default=True, type=_str2bool, help='|增加色彩变换|')
# Help text: "add tilt-angle transform".
parser.add_argument('--add1', default=True, type=_str2bool, help='|增加角度倾斜变换|')
# Train/validation ratio written as "train,val".
parser.add_argument('--divide', default=r'9,1', type=str)
args = parser.parse_args()
- # -------------------------------------------------------------------------------------------------------------------- #
- # 程序
def resize(image, max_h=1000):
    """Shrink ``image`` so its height does not exceed ``max_h``.

    The aspect ratio is preserved. An image whose height is already within
    the limit is returned unchanged (the same object, not a copy).

    Args:
        image: array whose first two axes are (height, width); any trailing
            channel axis is accepted, including none (grayscale).
        max_h: maximum allowed height in pixels.

    Returns:
        The original array, or a resized array of height ``max_h``.
    """
    # Only the first two axes are needed; unpacking exactly three values
    # would reject 2-D (grayscale) arrays.
    h, w = image.shape[:2]
    if h <= max_h:
        return image
    # Clamp to 1 so an extremely tall, narrow image never yields width 0.
    new_w = max(1, int(max_h / h * w))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (new_w, max_h))
def left(image):
    """Return ``image`` rotated 90 degrees counter-clockwise.

    Implemented as a transpose followed by a flip with flip code 0.
    """
    return cv2.flip(cv2.transpose(image), 0)
def right(image):
    """Return ``image`` rotated 90 degrees clockwise.

    Implemented as a transpose followed by a flip with flip code 1.
    """
    return cv2.flip(cv2.transpose(image), 1)
def flip(image):
    """Return ``image`` rotated by 180 degrees.

    A flip with code -1 mirrors both axes at once.
    """
    return cv2.flip(image, -1)
def rotate(image):
    """Tilt ``image`` by a random whole number of degrees in [-2, 2].

    One value is drawn from the ``random`` module per call; both bounds are
    inclusive, so an angle of 0 (no tilt) is possible. The original comment
    describes the rotation as counter-clockwise by a few degrees.
    """
    angle = random.randint(-2, 2)
    tilted = ndimage.rotate(image, angle)
    return tilted
if __name__ == '__main__':
    # Build the 4-class orientation dataset:
    #   label 0 = upright source image, 1 = rotated right (090),
    #   2 = rotated 180 (180), 3 = rotated left (270).
    # Each entry: (output sub-directory, file tag, class label, transform).
    orientations = (
        ('270', 'left', 3, left),
        ('090', 'right', 1, right),
        ('180', 'flip', 2, flip),
    )
    for folder, _tag, _label, _transform in orientations:
        os.makedirs(f'{args.save_path}/{folder}', exist_ok=True)

    # Optional augmentations applied to all four orientations: (file suffix, transform).
    augmentations = []
    if args.add0:
        # Colour transform: swap the first and third channels.
        augmentations.append(('bgr', lambda picture: cv2.cvtColor(picture, cv2.COLOR_RGB2BGR)))
    if args.add1:
        # Small random tilt.
        augmentations.append(('rotate', rotate))

    # The listing is taken once, before any augmented file is written into
    # image_path, so files created below are not re-processed in this run.
    # NOTE(review): a second run WILL pick up the *_bgr.jpg / *_rotate.jpg files
    # written next to the sources -- clean image_path before re-running.
    source_paths = [f'{args.image_path}/{name}' for name in os.listdir(args.image_path)]

    # Annotation lines per class label, kept in the order they are written out.
    samples = {0: [], 3: [], 1: [], 2: []}
    for i, image_path in enumerate(tqdm.tqdm(source_paths)):
        image = cv2.imread(image_path)
        if image is None:
            # imread yields None for anything it cannot decode (sub-directory,
            # non-image file, corrupt data); skip it instead of crashing.
            continue
        image = resize(image)
        index = str(i).rjust(3, '0')

        # (pixels, output path without extension, label) for every orientation.
        # splitext removes only the final extension, so dots elsewhere in the
        # path no longer truncate the output file name.
        variants = [(image, os.path.splitext(image_path)[0], 0)]
        samples[0].append(f'{image_path} 0\n')
        for folder, tag, label, transform in orientations:
            turned = transform(image)
            base = f'{args.save_path}/{folder}/{index}_{tag}'
            cv2.imwrite(f'{base}.jpg', turned)
            samples[label].append(f'{base}.jpg {label}\n')
            variants.append((turned, base, label))

        for suffix, augment in augmentations:
            for picture, base, label in variants:
                out_path = f'{base}_{suffix}.jpg'
                cv2.imwrite(out_path, augment(picture))
                samples[label].append(f'{out_path} {label}\n')

    # Split every class with the same train:val ratio.
    a, b = map(int, args.divide.split(','))
    for lines in samples.values():
        random.shuffle(lines)
    # Every class holds the same number of lines, so one cut point serves all.
    train_number = int(len(samples[0]) * a / (a + b))
    with open(args.file_path + '/train.txt', 'w', encoding='utf-8') as f:
        for lines in samples.values():
            f.writelines(lines[:train_number])
    with open(args.file_path + '/val.txt', 'w', encoding='utf-8') as f:
        # Validation takes what follows the training slice. Slicing from 0
        # again would copy training samples into the validation set.
        for lines in samples.values():
            f.writelines(lines[train_number:])
|