Przeglądaj źródła

去除无关文件及其引用

liyan 1 rok temu
rodzic
commit
31510ba2a0
3 zmienionych plików z 0 dodań i 66 usunięć
  1. 0 64
      block/data_get.py
  2. 0 1
      train.py
  3. 0 1
      train_embed.py

+ 0 - 64
block/data_get.py

@@ -1,64 +0,0 @@
-
-# 数据格式定义部分
-# 数据需准备成以下格式
-# ├── 数据集路径:data_path
-#     └── image:存放所有图片
-#     └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
-#     └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别
-#     └── class.txt:所有的类别名称
-# class.txt内容如下:
-# 类别1
-# 类别2
-
-import numpy as np
-import os
-import argparse
-
-def data_get(args):
-    data_dict = data_prepare(args).load()
-    return data_dict
-
-
-class data_prepare:
-    def __init__(self, args):
-        self.args = args
-        self.data_path = os.path.join(args.data_path, args.dataset_name)
-        self.dataset_name = args.dataset_name
-
-    def load(self):
-        data_dict = {}
-        data_dict['train'] = self._load_label('train.txt')
-        data_dict['test'] = self._load_label('test.txt')
-        data_dict['class'] = self._load_class()
-        return data_dict
-
-    def _load_label(self, txt_name):
-        with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8')as f:
-            txt_list = [_.strip().split(' ') for _ in f.readlines()]  # 读取所有图片路径和类别号
-        data_list = [['', 0] for _ in range(len(txt_list))]  # [图片路径,类别独热编码]
-        for i, line in enumerate(txt_list):
-            image_path = line[0]
-            # print(image_path)
-            data_list[i][0] = image_path
-            data_list[i][1] = np.zeros(self.args.output_class, dtype=np.float32)
-            for j in line[1:]:
-                data_list[i][1][int(j)] = 1
-        return data_list
-
-    def _load_class(self):
-        with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8')as f:
-            txt_list = [_.strip() for _ in f.readlines()]
-        return txt_list
-
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser(description='Data loader for specific dataset')
-    parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
-    parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
-    parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
-    parser.add_argument('--input_size', default=640, type=int)
-    args = parser.parse_args()
-    data_dict = data_get(args)
-    print(len(data_dict['train']))

+ 0 - 1
train.py

@@ -18,7 +18,6 @@ import os
 import wandb
 import torch
 import argparse
-from block.data_get import data_get
 from block.loss_get import loss_get
 from block.model_get import model_get
 from block.train_get import train_get

+ 0 - 1
train_embed.py

@@ -20,7 +20,6 @@ import torch
 import argparse
 
 from block import secret_get
-from block.data_get import data_get
 from block.loss_get import loss_get
 from block.model_get import model_get
 from block.train_with_watermark import train_embed