1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- # 数据格式定义部分
- # 数据需准备成以下格式
- # ├── 数据集路径:data_path
- # └── image:存放所有图片
- # └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
- # └── test.txt:验证图片的绝对路径(或相对data_path下路径)和类别号,格式同train.txt(注意:代码实际读取的文件名是test.txt)
- # └── class.txt:所有的类别名称
- # class.txt内容如下:
- # 类别1
- # 类别2
- import numpy as np
- import os
def data_get(args):
    """Build a ``data_prepare`` loader from ``args`` and return what it loads.

    Args:
        args: Argument namespace forwarded unchanged to ``data_prepare``.

    Returns:
        The dictionary produced by ``data_prepare(args).load()``.
    """
    loader = data_prepare(args)
    return loader.load()
class data_prepare:
    """Load the sample lists and class names of one dataset.

    All files are read from ``<args.data_path>/<args.dataset_name>/``:

    * ``train.txt`` / ``test.txt`` -- one sample per line: an image path
      followed by zero or more whitespace-separated class indices.
    * ``class.txt`` -- one class name per line.
    """

    def __init__(self, args):
        """Store the arguments and resolve the dataset directory.

        Args:
            args: Namespace providing ``data_path``, ``dataset_name`` and
                ``output_class`` (length of each label vector).
        """
        self.args = args
        self.data_path = os.path.join(args.data_path, args.dataset_name)
        self.dataset_name = args.dataset_name

    def load(self):
        """Read both sample lists and the class names.

        Returns:
            dict with keys ``'train'`` and ``'test'`` (lists of
            ``[image_path, label_vector]``) and ``'class'`` (list of str).

        Raises:
            FileNotFoundError: If any of the three files is missing.
        """
        return {
            'train': self._load_label('train.txt'),
            'test': self._load_label('test.txt'),
            'class': self._load_class(),
        }

    def _load_label(self, txt_name):
        """Parse one sample-list file into paths and multi-hot label vectors.

        Blank lines are skipped, and fields may be separated by any run of
        whitespace, so doubled spaces or tabs no longer break parsing.

        Args:
            txt_name: File name inside the dataset directory.

        Returns:
            List of ``[image_path, label]`` where ``label`` is a float32
            array of length ``args.output_class`` with 1 at every listed
            class index; a sample without indices gets an all-zero vector.

        Raises:
            ValueError: If a class field is not an integer.
            IndexError: If a class index does not fit ``args.output_class``.
        """
        data_list = []
        with open(os.path.join(self.data_path, txt_name), encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:  # blank line: not a sample
                    continue
                label = np.zeros(self.args.output_class, dtype=np.float32)
                for class_index in fields[1:]:
                    label[int(class_index)] = 1
                data_list.append([fields[0], label])
        return data_list

    def _load_class(self):
        """Return the lines of ``class.txt`` with surrounding whitespace stripped."""
        with open(os.path.join(self.data_path, 'class.txt'), encoding='utf-8') as f:
            return [line.strip() for line in f]
if __name__ == '__main__':
    import argparse

    # Manual smoke test: parse the command-line options, load the dataset
    # and print how many training samples were read.
    cli = argparse.ArgumentParser(description='Data loader for specific dataset')
    option_table = (
        ('--data_path', {'default': '../dataset', 'type': str, 'help': 'Root path to datasets'}),
        ('--dataset_name', {'default': 'CIFAR-10', 'type': str, 'help': 'Specific dataset name'}),
        ('--output_class', {'default': 10, 'type': int, 'help': 'Number of output classes'}),
        ('--input_size', {'default': 640, 'type': int}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    loaded = data_get(cli.parse_args())
    print(len(loaded['train']))
|