# Data-format definition.
# Each dataset must be laid out as follows:
# <data_path>/<dataset_name>/
#     ├── image/      folder holding all images
#     ├── train.txt   one training image per line: the image path (absolute, or
#     │               relative to the dataset directory) followed by its
#     │               space-separated class ids, e.g. "image/mask/0.jpg 0 2" means
#     │               the image belongs to classes 0 and 2; an image with no class
#     │               has no ids after its path
#     ├── test.txt    evaluation images, same format as train.txt
#     └── class.txt   all class names, one per line, e.g.
#                         class1
#                         class2
# NOTE: image paths are returned exactly as written in the list files; this
# module does not resolve relative paths against the dataset directory.
import numpy as np
import os


def data_get(args):
    """Load the train/test label lists and class names of one dataset.

    Args:
        args: namespace providing ``data_path``, ``dataset_name`` and
            ``output_class`` (the length of each multi-hot label vector).

    Returns:
        dict with keys ``'train'`` and ``'test'`` (each a list of
        ``[image_path, label]`` where ``label`` is a float32 multi-hot
        vector of length ``args.output_class``) and ``'class'`` (list of
        class names read from class.txt).
    """
    return data_prepare(args).load()


class data_prepare:
    """Read the list files stored in ``<args.data_path>/<args.dataset_name>``."""

    def __init__(self, args):
        self.args = args
        # Directory that holds train.txt / test.txt / class.txt.
        self.data_path = os.path.join(args.data_path, args.dataset_name)
        self.dataset_name = args.dataset_name

    def load(self):
        """Return ``{'train': [...], 'test': [...], 'class': [...]}``."""
        return {
            'train': self._load_label('train.txt'),
            'test': self._load_label('test.txt'),
            'class': self._load_class(),
        }

    def _read_lines(self, txt_name):
        """Return the stripped, non-empty lines of ``txt_name`` in the dataset directory."""
        # 'utf-8-sig' decodes plain UTF-8 unchanged but also drops a leading BOM
        # (written by some Windows editors), which would otherwise be glued onto
        # the first image path or class name.
        with open(os.path.join(self.data_path, txt_name), encoding='utf-8-sig') as f:
            stripped = (line.strip() for line in f)
            # Blank lines (e.g. a trailing empty line) must not become entries.
            return [text for text in stripped if text]

    def _load_label(self, txt_name):
        """Parse a list file into ``[image_path, multi_hot_label]`` pairs.

        Each label is a float32 vector of length ``args.output_class`` with a 1
        at every class id listed after the path. A class id outside that range
        raises IndexError from numpy.
        """
        data_list = []
        for line in self._read_lines(txt_name):
            # split() with no argument tolerates repeated spaces / tabs between fields.
            image_path, *class_ids = line.split()
            label = np.zeros(self.args.output_class, dtype=np.float32)
            for class_id in class_ids:
                label[int(class_id)] = 1
            data_list.append([image_path, label])
        return data_list

    def _load_class(self):
        """Return the class names listed in class.txt, one per non-empty line."""
        return self._read_lines('class.txt')


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Data loader for specific dataset')
    parser.add_argument('--data_path', default='../dataset', type=str, help='Root path to datasets')
    parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
    parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
    parser.add_argument('--input_size', default=640, type=int)
    args = parser.parse_args()
    data_dict = data_get(args)
    print(len(data_dict['train']))