1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- import os
- from torchvision import datasets
- '''
- 为数据集生成对应的txt文件
- '''
def gen_txt(txt_path, img_dir, extensions=('.jpg',)):
    """Write an image-list file for a one-folder-per-class image dataset.

    Each immediate subdirectory of ``img_dir`` is one class. Class names are
    sorted and labelled ``0, 1, 2, ...`` in that order, so label ``i`` matches
    line ``i`` of a sorted class list. For every matching image file, one line
    ``"<image path> <label>"`` is written to ``txt_path``.

    Args:
        txt_path: Output text file; overwritten if it already exists.
        img_dir: Dataset root whose subfolders are the classes.
        extensions: File-name suffix (or tuple of suffixes) to include,
            matched case-insensitively, so ``photo.JPG`` counts as ``.jpg``.

    Returns:
        list[str]: The class names, in label order.

    Raises:
        FileNotFoundError: If ``img_dir`` is not an existing directory.
    """
    if not os.path.isdir(img_dir):
        raise FileNotFoundError(f"image directory not found: {img_dir}")

    # Accept a single suffix string as well as a tuple of them.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Only the top level defines classes. Nested folders inside a class folder
    # must not become extra classes. Sorting makes labels deterministic and
    # consistent with a sorted class.txt.
    classes = sorted(
        entry for entry in os.listdir(img_dir)
        if os.path.isdir(os.path.join(img_dir, entry))
    )

    with open(txt_path, 'w', encoding='utf-8') as out:
        for label, class_name in enumerate(classes):
            class_dir = os.path.join(img_dir, class_name)
            print(class_dir)
            # Sorted so the generated list is reproducible across runs/machines.
            for file_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, file_name)
                if not file_name.lower().endswith(suffixes):
                    continue
                if not os.path.isfile(img_path):
                    continue
                # NOTE: space-separated format; a path that itself contains
                # spaces cannot be split unambiguously by a reader.
                out.write(f"{img_path} {label}\n")
    return classes
def write_class_list(classes, class_txt_path):
    """Write the class names, sorted, one per line to ``class_txt_path``.

    Sorting makes line ``i`` correspond to label ``i`` when labels were
    assigned over the sorted class names. The file is written as UTF-8 so
    non-ASCII class (folder) names are encoded the same way on every
    platform, instead of depending on the locale's default encoding.

    Args:
        classes: Iterable of class-name strings.
        class_txt_path: Output path; overwritten if it already exists.
    """
    with open(class_txt_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{name}\n" for name in sorted(classes))
def main(argv=None):
    """Command-line entry point: build ``<txt_name>.txt`` and ``class.txt``.

    Reads images from ``<txt_path>/<specific_data>/<class folder>/...`` and
    writes ``<txt_path>/<txt_name>.txt`` (image path and label per line) plus
    ``<txt_path>/class.txt`` (sorted class names).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate an image-list txt file and class.txt for a '
                    'one-folder-per-class image dataset.')
    parser.add_argument('--txt_path', default='./dataset/New_dataset', type=str,
                        help='dataset root; the generated txt files are written here')
    parser.add_argument('--specific_data', default='testtest', type=str,
                        help='image folder under --txt_path (one subfolder per class)')
    parser.add_argument('--txt_name', default='train', type=str,
                        help='base name of the generated image-list file (".txt" is appended)')
    args = parser.parse_args(argv)

    list_txt_path = os.path.join(args.txt_path, f"{args.txt_name}.txt")
    img_dir = os.path.join(args.txt_path, args.specific_data)
    class_txt_path = os.path.join(args.txt_path, "class.txt")

    classes = gen_txt(list_txt_path, img_dir)
    write_class_list(classes, class_txt_path)


if __name__ == '__main__':
    main()
|