浏览代码

修改模型训练代码,修改评估标准

liyan 1 年之前
父节点
当前提交
580f248232
共有 8 个文件被更改,包括 115 次插入和 61 次删除
  1. 5 7
      block/data_get.py
  2. 1 1
      block/loss_get.py
  3. 5 5
      block/model_get.py
  4. 29 18
      block/train_get.py
  5. 37 4
      block/val_get.py
  6. 1 1
      model/__init__.py
  7. 37 25
      run.py
  8. 0 0
      training_embedding.py

+ 5 - 7
block/data_get.py

@@ -1,4 +1,3 @@
-
 # 数据格式定义部分
 # 数据需准备成以下格式
 # ├── 数据集路径:data_path
@@ -6,13 +5,12 @@
 #     └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
 #     └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别
 #     └── class.txt:所有的类别名称
-# class.csv内容如下:
+# class.txt内容如下:
 # 类别1
 # 类别2
 
 import numpy as np
 import os
-import argparse
 
 def data_get(args):
     data_dict = data_prepare(args).load()
@@ -33,7 +31,7 @@ class data_prepare:
         return data_dict
 
     def _load_label(self, txt_name):
-        with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8')as f:
+        with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8') as f:
             txt_list = [_.strip().split(' ') for _ in f.readlines()]  # 读取所有图片路径和类别号
         data_list = [['', 0] for _ in range(len(txt_list))]  # [图片路径,类别独热编码]
         for i, line in enumerate(txt_list):
@@ -46,7 +44,7 @@ class data_prepare:
         return data_list
 
     def _load_class(self):
-        with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8')as f:
+        with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8') as f:
             txt_list = [_.strip() for _ in f.readlines()]
         return txt_list
 
@@ -55,10 +53,10 @@ if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Data loader for specific dataset')
-    parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
+    parser.add_argument('--data_path', default='../dataset', type=str, help='Root path to datasets')
     parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
     parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
     parser.add_argument('--input_size', default=640, type=int)
     args = parser.parse_args()
     data_dict = data_get(args)
-    print(len(data_dict['train']))
+    print(len(data_dict['train']))

+ 1 - 1
block/loss_get.py

@@ -2,6 +2,6 @@ import torch
 
 
 def loss_get(args):
-    choice_dict = {'bce': 'torch.nn.BCEWithLogitsLoss()'}
+    choice_dict = {'bce': 'torch.nn.BCEWithLogitsLoss()','cross': 'torch.nn.CrossEntropyLoss()'}
     loss = eval(choice_dict[args.loss])
     return loss

+ 5 - 5
block/model_get.py

@@ -3,7 +3,7 @@ import torch
 
 choice_dict = {
     'yolov7_cls': 'model_prepare(args).yolov7_cls()',
-    'timm_model': 'model_prepare(args).timm_model()',
+    # 'timm_model': 'model_prepare(args).timm_model()',
     'Alexnet': 'model_prepare(args).Alexnet()',
     'badnet': 'model_prepare(args).badnet()',
     'GoogleNet': 'model_prepare(args).GoogleNet()',
@@ -114,10 +114,10 @@ class model_prepare:
     def __init__(self, args):
         self.args = args
 
-    def timm_model(self):
-        from model.timm_model import timm_model
-        model = timm_model(self.args)
-        return model
+    # def timm_model(self):
+    #     from model.timm_model import timm_model
+    #     model = timm_model(self.args)
+    #     return model
 
     def yolov7_cls(self):
         from model.yolov7_cls import yolov7_cls

+ 29 - 18
block/train_get.py

@@ -32,7 +32,8 @@ def train_get(args, data_dict, model_dict, loss):
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                    drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                    sampler=train_sampler)
-    val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'])
+    # 验证集不对图像进行处理
+    val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'], train=False)
     val_sampler = None  # 分布式时数据合在主GPU上进行验证
     val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
@@ -106,8 +107,9 @@ def train_get(args, data_dict, model_dict, loss):
         torch.cuda.empty_cache()
         # 验证
         if args.local_rank == 0:  # 分布式时只验证一次
-            val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
-                                                                  len(data_dict['test']))
+            # val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
+            #                                                       len(data_dict['test']))
+            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, len(data_dict['test']))
         # 保存
         if args.local_rank == 0:  # 分布式时只保存一次
             model_dict['model'] = model.module if args.distributed else model
@@ -118,39 +120,48 @@ def train_get(args, data_dict, model_dict, loss):
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
-            model_dict['val_precision'] = precision
-            model_dict['val_recall'] = recall
-            model_dict['val_m_ap'] = m_ap
+            # model_dict['val_precision'] = precision
+            # model_dict['val_recall'] = recall
+            # model_dict['val_m_ap'] = m_ap
             torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
-            if m_ap > 0.5 and m_ap > model_dict['standard']:
-                model_dict['standard'] = m_ap
+            # if m_ap > 0.5 and m_ap > model_dict['standard']:
+            #     model_dict['standard'] = m_ap
+            #     save_path = args.save_path if not args.prune else args.prune_save
+            #     torch.save(model_dict, save_path)  # 保存最佳模型
+            #     print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
+            if accuracy > 0.5 and accuracy > model_dict['standard']:
+                model_dict['standard'] = accuracy
                 save_path = args.save_path if not args.prune else args.prune_save
                 torch.save(model_dict, save_path)  # 保存最佳模型
-                print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
+                print(f'| 保存最佳模型:{save_path} | accuracy:{(100 * accuracy):.4f}% |')
             # wandb
             if args.wandb:
                 wandb_log = {}
                 if epoch == 0:
                     wandb_log.update({f'image/train_image': wandb_image_list})
+                # wandb_log.update({'metric/train_loss': train_loss,
+                #                   'metric/val_loss': val_loss,
+                #                   'metric/val_m_ap': m_ap,
+                #                   'metric/val_accuracy': accuracy,
+                #                   'metric/val_precision': precision,
+                #                   'metric/val_recall': recall})
                 wandb_log.update({'metric/train_loss': train_loss,
                                   'metric/val_loss': val_loss,
-                                  'metric/val_m_ap': m_ap,
-                                  'metric/val_accuracy': accuracy,
-                                  'metric/val_precision': precision,
-                                  'metric/val_recall': recall})
+                                  'metric/val_accuracy': accuracy})
                 args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
 
 
 class torch_dataset(torch.utils.data.Dataset):
-    def __init__(self, args, tag, data, class_name):
+    def __init__(self, args, tag, data, class_name, train=True):
         self.tag = tag
         self.data = data
         self.class_name = class_name
+        self.train = train
         self.noise_probability = args.noise
         self.noise = albumentations.Compose([
             albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
-            albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
+            albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2),], )
         self.transform = albumentations.Compose([
             albumentations.LongestMaxSize(args.input_size),
             albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
@@ -162,11 +173,11 @@ class torch_dataset(torch.utils.data.Dataset):
         return len(self.data)
 
     def __getitem__(self, index):
-        # print(self.data[index][0])
         image = cv2.imread(self.data[index][0])  # 读取图片
-        if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
+        if self.tag == 'train' and torch.rand(1) < self.noise_probability and self.train:  # 使用数据加噪
             image = self.noise(image=image)['image']
-        image = self.transform(image=image)['image']  # 缩放和填充图片
+        if self.train:
+            image = self.transform(image=image)['image']  # 缩放和填充图片
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
         image = self._image_deal(image)  # 归一化、转换为tensor、调维度
         label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor

+ 37 - 4
block/val_get.py

@@ -3,6 +3,32 @@ import torch
 from block.metric_get import metric
 
 
+# def val_get(args, val_dataloader, model, loss, ema, data_len):
+#     tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
+#     tqdm_show = tqdm.tqdm(total=tqdm_len)
+#     with torch.no_grad():
+#         model = ema.ema if args.ema else model.eval()
+#         pred_all = []  # 记录所有预测
+#         true_all = []  # 记录所有标签
+#         for index, (image_batch, true_batch) in enumerate(val_dataloader):
+#             image_batch = image_batch.to(args.device, non_blocking=args.latch)
+#             pred_batch = model(image_batch).detach().cpu()
+#             loss_batch = loss(pred_batch, true_batch)
+#             pred_all.extend(pred_batch)
+#             true_all.extend(true_batch)
+#             tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # 添加显示
+#             tqdm_show.update(1)  # 更新进度条
+#         # tqdm
+#         tqdm_show.close()
+#         # 计算指标
+#         pred_all = torch.stack(pred_all, dim=0)
+#         true_all = torch.stack(true_all, dim=0)
+#         loss_all = loss(pred_all, true_all).item()
+#         accuracy, precision, recall, m_ap = metric(pred_all, true_all, args.class_threshold)
+#         print(f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{accuracy:.4f} |'
+#               f' val_precision:{precision:.4f} | val_recall:{recall:.4f} | val_m_ap:{m_ap:.4f} |')
+#     return loss_all, accuracy, precision, recall, m_ap
+
 def val_get(args, val_dataloader, model, loss, ema, data_len):
     tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
     tqdm_show = tqdm.tqdm(total=tqdm_len)
@@ -10,12 +36,19 @@ def val_get(args, val_dataloader, model, loss, ema, data_len):
         model = ema.ema if args.ema else model.eval()
         pred_all = []  # 记录所有预测
         true_all = []  # 记录所有标签
+        correct = 0
+        total = 0
         for index, (image_batch, true_batch) in enumerate(val_dataloader):
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             pred_batch = model(image_batch).detach().cpu()
             loss_batch = loss(pred_batch, true_batch)
             pred_all.extend(pred_batch)
             true_all.extend(true_batch)
+            # 计算准确率
+            _, predicted = torch.max(pred_batch.data, 1)
+            labels = torch.argmax(true_batch, dim=1)
+            total += true_batch.size(0)
+            correct += (predicted == labels).sum().item()
             tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # 添加显示
             tqdm_show.update(1)  # 更新进度条
         # tqdm
@@ -24,7 +57,7 @@ def val_get(args, val_dataloader, model, loss, ema, data_len):
         pred_all = torch.stack(pred_all, dim=0)
         true_all = torch.stack(true_all, dim=0)
         loss_all = loss(pred_all, true_all).item()
-        accuracy, precision, recall, m_ap = metric(pred_all, true_all, args.class_threshold)
-        print(f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{accuracy:.4f} |'
-              f' val_precision:{precision:.4f} | val_recall:{recall:.4f} | val_m_ap:{m_ap:.4f} |')
-    return loss_all, accuracy, precision, recall, m_ap
+        accuracy = correct / (total + 1e-5)
+        print(
+            f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{(100 * correct / (total + 1e-5)):.2f}% |')
+    return loss_all, accuracy

+ 1 - 1
model/__init__.py

@@ -1,3 +1,3 @@
-from .timm_model import timm_model
+# from .timm_model import timm_model
 from .yolov7_cls import yolov7_cls
 from .layer import cbs, elan, mp, sppcspc, linear_head

+ 37 - 25
run.py

@@ -4,7 +4,7 @@
 #     └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
 #     └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别
 #     └── class.txt:所有的类别名称
-# class.csv内容如下
+# class.txt
 # 类别1
 # 类别2
 # ...
@@ -15,7 +15,7 @@
 # n为GPU数量
 # -------------------------------------------------------------------------------------------------------------------- #
 import os
-# import wandb
+import wandb
 import torch
 import argparse
 from block.data_get import data_get
@@ -23,23 +23,26 @@ from block.loss_get import loss_get
 from block.model_get import model_get
 from block.train_get import train_get
 
+# 获取当前文件路径
+pwd = os.getcwd()
+
 # -------------------------------------------------------------------------------------------------------------------- #
 # 模型加载/创建的优先级为:加载已有模型>创建剪枝模型>创建timm库模型>创建自定义模型
 parser = argparse.ArgumentParser(description='|针对分类任务,添加水印机制,包含数据隐私、模型水印|')
-# parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
-# parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|')
-# parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
-# parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
+parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
+parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|')
+parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
+parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
 
 # new_added
-parser.add_argument('--data_path', default='./dataset', type=str, help='Root path to datasets')
+parser.add_argument('--data_path', default=f'{pwd}/dataset', type=str, help='Root path to datasets')
 parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--output_num', default=10, type=int)
 # parser.add_argument('--input_size', default=32, type=int)
 # 触发集标签定义,黑盒水印植入,这里需要调用它,用于处理部分数据的
-parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
-# 这里可以直接选择水印控制,看看如何选择调用进来
+parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 2)')
+# 设置数据集中添加水印的文件占比,这里可以直接选择水印控制,看看如何选择调用进来
 parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
 
 # 待修改
@@ -58,15 +61,16 @@ parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|
 
 # 模型处理的部分了
 parser.add_argument('--timm', default=False, type=bool, help='|是否使用timm库创建模型|')
-parser.add_argument('--model', default='mobilenetv2', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
+parser.add_argument('--model', default='Alexnet', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
 parser.add_argument('--model_type', default='s', type=str, help='|自定义模型型号|')
-parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
-parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
+# parser.add_argument('--save_path', default=f'{pwd}/weight/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
+# parser.add_argument('--save_path_last', default=f'{pwd}/weight/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
 
 # 训练控制
 parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
-parser.add_argument('--batch', default=100, type=int, help='|训练批量大小,分布式时为总批量|')
-parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
+parser.add_argument('--batch', default=600, type=int, help='|训练批量大小,分布式时为总批量|')
+# parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
+parser.add_argument('--loss', default='cross', type=str, help='|损失函数|')
 parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
 parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
 parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
@@ -77,19 +81,27 @@ parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
 parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
-parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
+parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
 parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
 parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
 parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
 parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
 args = parser.parse_args()
-args.device_number = max(torch.cuda.device_count(), 2)  # 使用的GPU数,可能为CPU
+args.device_number = max(torch.cuda.device_count(), 1)  # 使用的GPU数,可能为CPU
+args.save_path = f'{pwd}/weight/{args.model}/best.pt'
+args.save_path_last = f'{pwd}/weight/{args.model}/last.pt'
 
 # 创建模型对应的检查点目录
-checkpoint_dir = os.path.join('./checkpoints', args.model)
-if not os.path.exists(checkpoint_dir):
-    os.makedirs(checkpoint_dir)
-print(f"模型保存路径已创建: {args.model}")
+# checkpoint_dir = os.path.join('./checkpoints', args.model)
+# if not os.path.exists(checkpoint_dir):
+#     os.makedirs(checkpoint_dir)
+# print(f"{args.model}模型检查点保存路径已创建: {checkpoint_dir}")
+
+# 创建模型保存权重目录
+weight_dir = f'{pwd}/weight/{args.model}'
+if not os.path.exists(weight_dir):
+    os.makedirs(weight_dir)
+print(f"{args.model}模型权重保存路径已创建: {weight_dir}")
 
 # 为CPU设置随机种子
 torch.manual_seed(999)
@@ -126,11 +138,11 @@ if args.local_rank == 0:
         print(f'| 加载已有模型:{args.weight} |')
     elif args.prune:
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
-    elif args.timm:  # 创建timm库中模型args.timm
-        import timm
-
-        assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
-        print(f'| 创建timm库中模型:{args.model} |')
+    # elif args.timm:  # 创建timm库中模型args.timm
+    #     import timm
+    #
+    #     assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
+    #     print(f'| 创建timm库中模型:{args.model} |')
     else:  # 创建自定义模型args.model
         assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
         print(f'| 创建自定义模型:{args.model} | 型号:{args.model_type} |')

training_embedding copy.py → training_embedding.py