data_get.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
# 数据格式定义部分
# 数据需准备成以下格式(各文件位于 数据集路径 下)
# 数据集路径:data_path/dataset_name
# ├── image:存放所有图片
# ├── train.txt:每行为训练图片的绝对路径(或相对数据集路径的路径)和类别号,以空格分隔;例如 image/mask/0.jpg 0 2 表示该图片类别为0和2;无类别的图片只写路径
# ├── test.txt:测试(验证)图片的路径和类别号,格式同 train.txt
# └── class.txt:所有的类别名称,每行一个
# class.txt内容如下:
# 类别1
# 类别2
  11. import numpy as np
  12. import os
  13. def data_get(args):
  14. data_dict = data_prepare(args).load()
  15. return data_dict
  16. class data_prepare:
  17. def __init__(self, args):
  18. self.args = args
  19. self.data_path = os.path.join(args.data_path, args.dataset_name)
  20. self.dataset_name = args.dataset_name
  21. def load(self):
  22. data_dict = {}
  23. data_dict['train'] = self._load_label('train.txt')
  24. data_dict['test'] = self._load_label('test.txt')
  25. data_dict['class'] = self._load_class()
  26. return data_dict
  27. def _load_label(self, txt_name):
  28. with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8') as f:
  29. txt_list = [_.strip().split(' ') for _ in f.readlines()] # 读取所有图片路径和类别号
  30. data_list = [['', 0] for _ in range(len(txt_list))] # [图片路径,类别独热编码]
  31. for i, line in enumerate(txt_list):
  32. image_path = line[0]
  33. # print(image_path)
  34. data_list[i][0] = image_path
  35. data_list[i][1] = np.zeros(self.args.output_class, dtype=np.float32)
  36. for j in line[1:]:
  37. data_list[i][1][int(j)] = 1
  38. return data_list
  39. def _load_class(self):
  40. with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8') as f:
  41. txt_list = [_.strip() for _ in f.readlines()]
  42. return txt_list
  43. if __name__ == '__main__':
  44. import argparse
  45. parser = argparse.ArgumentParser(description='Data loader for specific dataset')
  46. parser.add_argument('--data_path', default='../dataset', type=str, help='Root path to datasets')
  47. parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
  48. parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
  49. parser.add_argument('--input_size', default=640, type=int)
  50. args = parser.parse_args()
  51. data_dict = data_get(args)
  52. print(len(data_dict['train']))