generate_txt.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. from torchvision import datasets
  3. '''
  4. Generate the corresponding .txt image-list file for a dataset.
  5. Expected dataset layout: <dataset name> -> train/test -> <class name> -> images
  6. '''
  # --- image-list generation ---
def gen_txt(txt_path, img_dir, extensions=('.jpg',)):
    """Write an image-list file for one dataset split and return its class names.

    Expected layout: ``img_dir/<class name>/<image file>``. Every immediate
    subdirectory of *img_dir* is one class. Classes are sorted by name, and the
    class at position ``k`` in that order gets label ``k``. Each matching image
    produces one line ``"<image path> <label>"`` in *txt_path*.

    Args:
        txt_path: Output text file; overwritten if it already exists.
        img_dir: Split directory (e.g. ``<dataset>/train``) holding class folders.
        extensions: File suffix, or tuple of suffixes, to include. Matching is
            case-insensitive. Defaults to ``('.jpg',)``; pass e.g. ``('.png',)``
            for a PNG dataset.

    Returns:
        The sorted list of class names; index ``k`` is the name of label ``k``.

    Raises:
        FileNotFoundError: If *img_dir* does not exist.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Only immediate subdirectories are classes; nested folders inside a class
    # folder must not become extra classes. Sorting makes labels deterministic
    # and keeps label k aligned with line k of the sorted class list written
    # by write_class_list (os.listdir order is arbitrary).
    classes = sorted(
        name for name in os.listdir(img_dir)
        if os.path.isdir(os.path.join(img_dir, name))
    )

    with open(txt_path, 'w', encoding='utf-8') as f:
        for label, class_name in enumerate(classes):
            class_dir = os.path.join(img_dir, class_name)
            print(class_dir)
            for file_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, file_name)
                # Skip non-image files and any folders inside the class folder.
                if not file_name.lower().endswith(suffixes):
                    continue
                if not os.path.isfile(img_path):
                    continue
                # NOTE: space-separated format; a path containing spaces will be
                # ambiguous to a reader that splits each line on whitespace.
                f.write(f'{img_path} {label}\n')
    return classes
  # --- class-name list output ---
def write_class_list(classes, class_txt_path):
    """Write *classes* to *class_txt_path*, one name per line, sorted by name.

    Sorting makes the file independent of the order of *classes*. A consumer
    that maps line ``k`` to label ``k`` is only correct if the labels were also
    assigned in sorted-name order — NOTE(review): confirm with the list writer.

    Args:
        classes: Iterable of class-name strings.
        class_txt_path: Output file path; overwritten if it exists.
    """
    # Explicit UTF-8 so non-ASCII class names are written the same on every
    # platform instead of depending on the locale's default encoding.
    with open(class_txt_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{name}\n' for name in sorted(classes))
  # --- command-line entry point ---
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate an image-list txt file and a class.txt for one dataset split.')
    parser.add_argument('--txt_path', default='./dataset/New_dataset', type=str,
                        help='dataset root: contains the split folder and receives the output txt files')
    parser.add_argument('--specific_data', default='testtest', type=str,
                        help='name of the split folder under --txt_path to scan (e.g. train or test)')
    parser.add_argument('--txt_name', default='train', type=str,
                        help='output list file name, without the .txt extension')
    parser.add_argument('--class_txt_path', default=None, type=str,
                        help='where to write the class list (default: <txt_path>/class.txt)')
    args = parser.parse_args()

    list_txt_path = os.path.join(args.txt_path, f"{args.txt_name}.txt")
    split_dir = os.path.join(args.txt_path, args.specific_data)
    class_txt_path = (args.class_txt_path if args.class_txt_path is not None
                      else os.path.join(args.txt_path, "class.txt"))

    classes = gen_txt(list_txt_path, split_dir)
    write_class_list(classes, class_txt_path)