"""Generate the index txt files for an image-classification dataset.

Given a directory that contains one sub-folder per class, this script writes:

* ``<txt_name>.txt`` -- one line per image: ``<image path> <label>``
* ``class.txt``      -- one class name per line, in sorted order, so that the
  0-based line number of a class equals the label used in the index file.
"""
import os

try:
    # Not used by this script. Kept (optionally) so the module still exposes
    # `datasets`, without making torchvision a hard requirement just to
    # generate txt files.
    from torchvision import datasets  # noqa: F401
except ImportError:
    datasets = None


def gen_txt(txt_path, img_dir, extensions=('jpg',)):
    """Write an ``<image path> <label>`` index file for a class-per-folder dataset.

    Only the *immediate* sub-directories of ``img_dir`` are treated as
    classes. They are sorted by name and labelled ``0..N-1`` in that order,
    which is the same order ``write_class_list`` writes them in.

    Args:
        txt_path: Path of the index file to create (overwritten if present).
        img_dir: Directory containing one sub-directory per class.
        extensions: File-name suffix, or iterable of suffixes, that a file
            must end with to be listed. Matching is case-sensitive. The
            default ``('jpg',)`` keeps the historical behaviour.

    Returns:
        The sorted list of class (sub-directory) names. A class keeps its
        label even when it contains no matching image. If ``img_dir`` does
        not exist the list is empty and an empty index file is written.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)  # str.endswith() accepts a tuple, not a list

    # Take only the first os.walk() entry: deeper levels are not classes.
    # os.walk() ignores listing errors by default, so a missing img_dir
    # simply yields nothing and we fall back to "no classes".
    root, sub_dirs, _ = next(os.walk(img_dir, topdown=True), (img_dir, [], []))

    # Directory listing order is arbitrary; sort so labels are deterministic
    # and line up with the sorted class list written by write_class_list().
    classes = sorted(sub_dirs)

    # NOTE(review): the encoding is left at the platform default so that
    # readers elsewhere that open the file the same way keep working.
    with open(txt_path, 'w') as f:
        for label, sub_dir in enumerate(classes):
            i_dir = os.path.join(root, sub_dir)  # relative if img_dir is relative
            print(i_dir)
            for img_name in sorted(os.listdir(i_dir)):
                if not img_name.endswith(suffixes):
                    continue  # not an image we were asked to index
                img_path = os.path.join(i_dir, img_name)
                f.write(f"{img_path} {label}\n")

    return classes


def write_class_list(classes, class_txt_path):
    """Write the class names to ``class_txt_path``, one per line, sorted.

    Args:
        classes: Iterable of class-name strings (e.g. the return value of
            ``gen_txt``).
        class_txt_path: Path of the file to create (overwritten if present).
    """
    with open(class_txt_path, 'w') as f:
        for cls in sorted(classes):
            f.write(cls + '\n')


def main():
    """Parse the command line and generate the index and class-list files."""
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--txt_path', default='./dataset/New_dataset', type=str, help='path to new datasets')
    parser.add_argument('--specific_data', default='testtest', type=str, help='process the file_name')
    parser.add_argument('--txt_name', default='train', type=str, help='process the file_name')
    args = parser.parse_args()

    # Both output files live in --txt_path; the images are read from
    # <txt_path>/<specific_data>.
    train_txt_path = os.path.join(args.txt_path, f"{args.txt_name}.txt")
    train_dir = os.path.join(args.txt_path, args.specific_data)
    class_txt_path = os.path.join(args.txt_path, "class.txt")

    classes = gen_txt(train_txt_path, train_dir)
    write_class_list(classes, class_txt_path)


if __name__ == '__main__':
    main()