|
@@ -1,64 +0,0 @@
|
|
-
|
|
|
|
-# 数据格式定义部分
|
|
|
|
-# 数据需准备成以下格式
|
|
|
|
-# ├── 数据集路径:data_path
|
|
|
|
-# └── image:存放所有图片
|
|
|
|
-# └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
|
|
|
|
-# └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别
|
|
|
|
-# └── class.txt:所有的类别名称
|
|
|
|
-# class.csv内容如下:
|
|
|
|
-# 类别1
|
|
|
|
-# 类别2
|
|
|
|
-
|
|
|
|
-import numpy as np
|
|
|
|
-import os
|
|
|
|
-import argparse
|
|
|
|
-
|
|
|
|
-def data_get(args):
|
|
|
|
- data_dict = data_prepare(args).load()
|
|
|
|
- return data_dict
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-class data_prepare:
|
|
|
|
- def __init__(self, args):
|
|
|
|
- self.args = args
|
|
|
|
- self.data_path = os.path.join(args.data_path, args.dataset_name)
|
|
|
|
- self.dataset_name = args.dataset_name
|
|
|
|
-
|
|
|
|
- def load(self):
|
|
|
|
- data_dict = {}
|
|
|
|
- data_dict['train'] = self._load_label('train.txt')
|
|
|
|
- data_dict['test'] = self._load_label('test.txt')
|
|
|
|
- data_dict['class'] = self._load_class()
|
|
|
|
- return data_dict
|
|
|
|
-
|
|
|
|
- def _load_label(self, txt_name):
|
|
|
|
- with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8')as f:
|
|
|
|
- txt_list = [_.strip().split(' ') for _ in f.readlines()] # 读取所有图片路径和类别号
|
|
|
|
- data_list = [['', 0] for _ in range(len(txt_list))] # [图片路径,类别独热编码]
|
|
|
|
- for i, line in enumerate(txt_list):
|
|
|
|
- image_path = line[0]
|
|
|
|
- # print(image_path)
|
|
|
|
- data_list[i][0] = image_path
|
|
|
|
- data_list[i][1] = np.zeros(self.args.output_class, dtype=np.float32)
|
|
|
|
- for j in line[1:]:
|
|
|
|
- data_list[i][1][int(j)] = 1
|
|
|
|
- return data_list
|
|
|
|
-
|
|
|
|
- def _load_class(self):
|
|
|
|
- with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8')as f:
|
|
|
|
- txt_list = [_.strip() for _ in f.readlines()]
|
|
|
|
- return txt_list
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-if __name__ == '__main__':
|
|
|
|
- import argparse
|
|
|
|
-
|
|
|
|
- parser = argparse.ArgumentParser(description='Data loader for specific dataset')
|
|
|
|
- parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
|
|
|
|
- parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
|
|
|
|
- parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
|
|
|
|
- parser.add_argument('--input_size', default=640, type=int)
|
|
|
|
- args = parser.parse_args()
|
|
|
|
- data_dict = data_get(args)
|
|
|
|
- print(len(data_dict['train']))
|
|
|