data_get.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
# 数据格式定义部分
# 数据需准备成以下格式
# 数据集路径: data_path/dataset_name
# ├── image: 存放所有图片
# ├── train.txt: 每行为训练图片的绝对路径(或相对data_path下路径)和类别号, 以空格分隔,
# │              如 "image/mask/0.jpg 0 2" 表示该图片类别为0和2; 空类别图片只写路径, 无类别号
# ├── test.txt: 测试(验证)图片的路径和类别号, 格式同train.txt
# └── class.txt: 所有的类别名称, 每行一个
# class.txt内容如下:
# 类别1
# 类别2
  11. import numpy as np
  12. import os
  13. import argparse
  14. def data_get(args):
  15. data_dict = data_prepare(args).load()
  16. return data_dict
  17. class data_prepare:
  18. def __init__(self, args):
  19. self.args = args
  20. self.data_path = os.path.join(args.data_path, args.dataset_name)
  21. self.dataset_name = args.dataset_name
  22. def load(self):
  23. data_dict = {}
  24. data_dict['train'] = self._load_label('train.txt')
  25. data_dict['test'] = self._load_label('test.txt')
  26. data_dict['class'] = self._load_class()
  27. return data_dict
  28. def _load_label(self, txt_name):
  29. with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8')as f:
  30. txt_list = [_.strip().split(' ') for _ in f.readlines()] # 读取所有图片路径和类别号
  31. data_list = [['', 0] for _ in range(len(txt_list))] # [图片路径,类别独热编码]
  32. for i, line in enumerate(txt_list):
  33. image_path = line[0]
  34. # print(image_path)
  35. data_list[i][0] = image_path
  36. data_list[i][1] = np.zeros(self.args.output_class, dtype=np.float32)
  37. for j in line[1:]:
  38. data_list[i][1][int(j)] = 1
  39. return data_list
  40. def _load_class(self):
  41. with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8')as f:
  42. txt_list = [_.strip() for _ in f.readlines()]
  43. return txt_list
  44. if __name__ == '__main__':
  45. import argparse
  46. parser = argparse.ArgumentParser(description='Data loader for specific dataset')
  47. parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
  48. parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
  49. parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
  50. parser.add_argument('--input_size', default=640, type=int)
  51. args = parser.parse_args()
  52. data_dict = data_get(args)
  53. print(len(data_dict['train']))