generate_txt.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. from torchvision import datasets
  3. '''
  4. 为数据集生成对应的txt文件
  5. '''
  6. def gen_txt(txt_path, img_dir):
  7. f = open(txt_path, 'w')
  8. classes = [] # 列表存储所有类名
  9. for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
  10. j = 0
  11. for sub_dir in s_dirs:
  12. classes.append(sub_dir) # 将类名添加到列表
  13. i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
  14. print(i_dir)
  15. img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
  16. for i in range(len(img_list)):
  17. if not img_list[i].endswith('jpg'): # 若不是png文件,跳过
  18. continue
  19. label = str(j)
  20. img_path = os.path.join(i_dir, img_list[i])
  21. line = img_path + ' ' + label + '\n'
  22. f.write(line)
  23. j+=1
  24. f.close()
  25. return classes
  26. def write_class_list(classes, class_txt_path):
  27. with open(class_txt_path, 'w') as f:
  28. for cls in sorted(classes):
  29. f.write(cls + '\n')
  30. if __name__ == '__main__':
  31. import argparse
  32. parser = argparse.ArgumentParser(description='')
  33. parser.add_argument('--txt_path', default='./dataset/New_dataset', type=str, help='path to new datasets')
  34. parser.add_argument('--specific_data', default='testtest', type=str, help='process the file_name')
  35. parser.add_argument('--txt_name', default='train', type=str, help='process the file_name')
  36. # parser.add_argument('--class_txt_path', default='./dataset/New_dataset', type=str, help='class.txt')
  37. args = parser.parse_args()
  38. train_txt_path = os.path.join(args.txt_path, f"{args.txt_name}.txt")
  39. train_dir = os.path.join(args.txt_path, args.specific_data)
  40. # valid_txt_path = os.path.join('/home/yhsun/classification-main/dataset/CIFAR-10', "test_png.txt")
  41. # valid_dir = os.path.join('/home/yhsun/classification-main/dataset/CIFAR-10', 'test_cifar10_PNG')
  42. class_txt_path = os.path.join(args.txt_path, "class.txt")
  43. classes = gen_txt(train_txt_path, train_dir)
  44. # gen_txt(valid_txt_path, valid_dir)
  45. write_class_list(classes, class_txt_path)