Преглед на файлове

修改模型训练流程

liyan преди 1 година
родител
ревизия
6d2950d785
променени са 7 файла, в които са добавени 132 реда и са изтрити 157 реда
  1. +4 -1
      block/loss_get.py
  2. +95 -45
      block/train_get.py
  3. +14 -12
      block/val_get.py
  4. +0 -23
      flask_request.py
  5. +0 -40
      flask_start.py
  6. +1 -3
      model/__init__.py
  7. +18 -33
      run.py

+ 4 - 1
block/loss_get.py

@@ -2,6 +2,9 @@ import torch
 
 
 
 
 def loss_get(args):
 def loss_get(args):
-    choice_dict = {'bce': 'torch.nn.BCEWithLogitsLoss()'}
+    choice_dict = {
+        'bce': 'torch.nn.BCEWithLogitsLoss()',
+        'cross':'torch.nn.CrossEntropyLoss()'
+    }
     loss = eval(choice_dict[args.loss])
     loss = eval(choice_dict[args.loss])
     return loss
     return loss

+ 95 - 45
block/train_get.py

@@ -1,9 +1,13 @@
+import os
+
 import cv2
 import cv2
 import tqdm
 import tqdm
 import wandb
 import wandb
 import torch
 import torch
 import numpy as np
 import numpy as np
-import albumentations
+# import albumentations
+from PIL import Image
+from torchvision import transforms
 from block.val_get import val_get
 from block.val_get import val_get
 from block.model_ema import model_ema
 from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust
 from block.lr_get import adam, lr_adjust
@@ -26,13 +30,26 @@ def train_get(args, data_dict, model_dict, loss):
     if args.ema:
     if args.ema:
         ema.updates = model_dict['ema_updates']
         ema.updates = model_dict['ema_updates']
     # 数据集
     # 数据集
-    train_dataset = torch_dataset(args, 'train', data_dict['train'], data_dict['class'])
+    print("加载训练集至内存中...")
+    train_transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(),  # 随机水平翻转
+        transforms.RandomCrop(32, padding=4),  # 随机裁剪并填充
+        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
+        transforms.ToTensor(),  # 将图像转换为PyTorch张量
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
+    ])
+    train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
     train_shuffle = False if args.distributed else True  # 分布式设置sampler后shuffle要为False
     train_shuffle = False if args.distributed else True  # 分布式设置sampler后shuffle要为False
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                    drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                    drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                    sampler=train_sampler)
                                                    sampler=train_sampler)
-    val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'])
+    print("加载验证集至内存中...")
+    val_transform = transforms.Compose([
+        transforms.ToTensor(),  # 将图像转换为PyTorch张量
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
+    ])
+    val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
     val_sampler = None  # 分布式时数据合在主GPU上进行验证
     val_sampler = None  # 分布式时数据合在主GPU上进行验证
     val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
     val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
@@ -106,8 +123,8 @@ def train_get(args, data_dict, model_dict, loss):
         torch.cuda.empty_cache()
         torch.cuda.empty_cache()
         # 验证
         # 验证
         if args.local_rank == 0:  # 分布式时只验证一次
         if args.local_rank == 0:  # 分布式时只验证一次
-            val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
-                                                                  len(data_dict['test']))
+            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
+                                         len(data_dict['test']))
         # 保存
         # 保存
         if args.local_rank == 0:  # 分布式时只保存一次
         if args.local_rank == 0:  # 分布式时只保存一次
             model_dict['model'] = model.module if args.distributed else model
             model_dict['model'] = model.module if args.distributed else model
@@ -118,15 +135,12 @@ def train_get(args, data_dict, model_dict, loss):
             model_dict['train_loss'] = train_loss
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
             model_dict['val_accuracy'] = accuracy
-            model_dict['val_precision'] = precision
-            model_dict['val_recall'] = recall
-            model_dict['val_m_ap'] = m_ap
             torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
             torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
-            if m_ap > 0.5 and m_ap > model_dict['standard']:
-                model_dict['standard'] = m_ap
+            if accuracy > 0.5 and accuracy > model_dict['standard']:
+                model_dict['standard'] = accuracy
                 save_path = args.save_path if not args.prune else args.prune_save
                 save_path = args.save_path if not args.prune else args.prune_save
                 torch.save(model_dict, save_path)  # 保存最佳模型
                 torch.save(model_dict, save_path)  # 保存最佳模型
-                print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
+                print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
             # wandb
             # wandb
             if args.wandb:
             if args.wandb:
                 wandb_log = {}
                 wandb_log = {}
@@ -134,44 +148,80 @@ def train_get(args, data_dict, model_dict, loss):
                     wandb_log.update({f'image/train_image': wandb_image_list})
                     wandb_log.update({f'image/train_image': wandb_image_list})
                 wandb_log.update({'metric/train_loss': train_loss,
                 wandb_log.update({'metric/train_loss': train_loss,
                                   'metric/val_loss': val_loss,
                                   'metric/val_loss': val_loss,
-                                  'metric/val_m_ap': m_ap,
-                                  'metric/val_accuracy': accuracy,
-                                  'metric/val_precision': precision,
-                                  'metric/val_recall': recall})
+                                  'metric/val_accuracy': accuracy
+                                  })
                 args.wandb_run.log(wandb_log)
                 args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
 
 
 
 
-class torch_dataset(torch.utils.data.Dataset):
-    def __init__(self, args, tag, data, class_name):
-        self.tag = tag
-        self.data = data
-        self.class_name = class_name
-        self.noise_probability = args.noise
-        self.noise = albumentations.Compose([
-            albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
-            albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
-        self.transform = albumentations.Compose([
-            albumentations.LongestMaxSize(args.input_size),
-            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
-                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
-        self.rgb_mean = (0.406, 0.456, 0.485)
-        self.rgb_std = (0.225, 0.224, 0.229)
+# class torch_dataset(torch.utils.data.Dataset):
+#     def __init__(self, args, tag, data, class_name):
+#         self.tag = tag
+#         self.data = data
+#         self.class_name = class_name
+#         self.noise_probability = args.noise
+#         self.noise = albumentations.Compose([
+#             albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
+#             albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
+#         self.transform = albumentations.Compose([
+#             albumentations.LongestMaxSize(args.input_size),
+#             albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+#                                        border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+#         self.rgb_mean = (0.406, 0.456, 0.485)
+#         self.rgb_std = (0.225, 0.224, 0.229)
+#
+#     def __len__(self):
+#         return len(self.data)
+#
+#     def __getitem__(self, index):
+#         # print(self.data[index][0])
+#         image = cv2.imread(self.data[index][0])  # 读取图片
+#         if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
+#             image = self.noise(image=image)['image']
+#         image = self.transform(image=image)['image']  # 缩放和填充图片
+#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+#         image = self._image_deal(image)  # 归一化、转换为tensor、调维度
+#         label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor
+#         return image, label
+#
+#     def _image_deal(self, image):  # 归一化、转换为tensor、调维度
+#         image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
+#         return image
+
+
+class CustomDataset(torch.utils.data.Dataset):
+    def __init__(self, data_dir, image_size=(32, 32), transform=None):
+        self.data_dir = data_dir
+        self.image_size = image_size
+        self.transform = transform
+
+        self.images = []
+        self.labels = []
+
+        # 遍历指定目录下的子目录,每个子目录代表一个类别
+        class_dirs = sorted(os.listdir(data_dir))
+        for index, class_dir in enumerate(class_dirs):
+            class_path = os.path.join(data_dir, class_dir)
+
+            # 遍历当前类别目录下的图像文件
+            for image_file in os.listdir(class_path):
+                image_path = os.path.join(class_path, image_file)
+
+                # 使用PIL加载图像并调整大小
+                image = Image.open(image_path).convert('RGB')
+                image = image.resize(image_size)
+
+                self.images.append(np.array(image))
+                self.labels.append(index)
 
 
     def __len__(self):
     def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        # print(self.data[index][0])
-        image = cv2.imread(self.data[index][0])  # 读取图片
-        if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
-            image = self.noise(image=image)['image']
-        image = self.transform(image=image)['image']  # 缩放和填充图片
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
-        image = self._image_deal(image)  # 归一化、转换为tensor、调维度
-        label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor
-        return image, label
+        return len(self.images)
 
 
-    def _image_deal(self, image):  # 归一化、转换为tensor、调维度
-        image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
-        return image
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        label = self.labels[idx]
+
+        if self.transform:
+            image = self.transform(Image.fromarray(image))
+
+        return image, label

+ 14 - 12
block/val_get.py

@@ -1,6 +1,5 @@
 import tqdm
 import tqdm
 import torch
 import torch
-from block.metric_get import metric
 
 
 
 
 def val_get(args, val_dataloader, model, loss, ema, data_len):
 def val_get(args, val_dataloader, model, loss, ema, data_len):
@@ -8,23 +7,26 @@ def val_get(args, val_dataloader, model, loss, ema, data_len):
     tqdm_show = tqdm.tqdm(total=tqdm_len)
     tqdm_show = tqdm.tqdm(total=tqdm_len)
     with torch.no_grad():
     with torch.no_grad():
         model = ema.ema if args.ema else model.eval()
         model = ema.ema if args.ema else model.eval()
-        pred_all = []  # 记录所有预测
-        true_all = []  # 记录所有标签
+        correct = 0
+        total = 0
+        loss_all = 0
+        epoch = 0
         for index, (image_batch, true_batch) in enumerate(val_dataloader):
         for index, (image_batch, true_batch) in enumerate(val_dataloader):
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             pred_batch = model(image_batch).detach().cpu()
             pred_batch = model(image_batch).detach().cpu()
             loss_batch = loss(pred_batch, true_batch)
             loss_batch = loss(pred_batch, true_batch)
-            pred_all.extend(pred_batch)
-            true_all.extend(true_batch)
+            # 获取指标项
+            _, predicted = torch.max(pred_batch, 1)
+            total += true_batch.size(0)
+            correct += (predicted == true_batch).sum().item()
+            loss_all += loss_batch.item()
+            epoch = epoch + 1
+            # 更新进度条数据
             tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # 添加显示
             tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # 添加显示
             tqdm_show.update(1)  # 更新进度条
             tqdm_show.update(1)  # 更新进度条
         # tqdm
         # tqdm
         tqdm_show.close()
         tqdm_show.close()
         # 计算指标
         # 计算指标
-        pred_all = torch.stack(pred_all, dim=0)
-        true_all = torch.stack(true_all, dim=0)
-        loss_all = loss(pred_all, true_all).item()
-        accuracy, precision, recall, m_ap = metric(pred_all, true_all, args.class_threshold)
-        print(f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{accuracy:.4f} |'
-              f' val_precision:{precision:.4f} | val_recall:{recall:.4f} | val_m_ap:{m_ap:.4f} |')
-    return loss_all, accuracy, precision, recall, m_ap
+        accuracy = correct / total
+        print(f'\n| 验证 | val_loss:{loss_all/epoch:.4f} | val_accuracy:{accuracy:.4f} |')
+    return loss_all, accuracy

+ 0 - 23
flask_request.py

@@ -1,23 +0,0 @@
-# 启用flask_start的服务后,将数据以post的方式调用服务得到结果
-import json
-import base64
-import requests
-
-
-def image_encode(image_path):
-    with open(image_path, 'rb')as f:
-        image_byte = f.read()
-    image_base64 = base64.b64encode(image_byte)
-    image = image_base64.decode()
-    return image
-
-
-if __name__ == '__main__':
-    url = 'http://0.0.0.0:9999/test/'  # 根据flask_start中的设置: http://host:port/name/
-    image_path = 'demo.jpg'
-    image = image_encode(image_path)
-    request_dict = {'image': image}
-    request = json.dumps(request_dict)
-    response = requests.post(url, data=request)
-    result = response.json()
-    print(result)

+ 0 - 40
flask_start.py

@@ -1,40 +0,0 @@
-# pip install flask -i https://pypi.tuna.tsinghua.edu.cn/simple
-# 用flask将程序包装成一个服务,并在服务器上启动
-import cv2
-import json
-import flask
-import base64
-import argparse
-import numpy as np
-
-# -------------------------------------------------------------------------------------------------------------------- #
-# 设置
-parser = argparse.ArgumentParser('|在服务器上启动flask服务|')
-# ...
-args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
-app = flask.Flask(__name__)  # 创建一个服务框架
-
-
-# -------------------------------------------------------------------------------------------------------------------- #
-# 程序
-def image_decode(image):
-    image_base64 = image.encode()  # base64
-    image_byte = base64.b64decode(image_base64)  # base64->字节类型
-    array = np.frombuffer(image_byte, dtype=np.uint8)  # 字节类型->一行数组
-    image = cv2.imdecode(array, cv2.IMREAD_COLOR)  # 一行数组->BGR图片
-    return image
-
-
-@app.route('/test/', methods=['POST'])  # 每当调用服务时会执行一次flask_app函数
-def flask_app():
-    request_json = flask.request.get_data()
-    request_dict = json.loads(request_json)
-    image = image_decode(request_dict['image'])
-    # ...
-    result = image.shape
-    return result
-
-
-if __name__ == '__main__':
-    print('| 使用flask启动服务 |')
-    app.run(host='0.0.0.0', port=9999, debug=False)  # 启动服务

+ 1 - 3
model/__init__.py

@@ -1,3 +1 @@
-from .timm_model import timm_model
-from .yolov7_cls import yolov7_cls
-from .layer import cbs, elan, mp, sppcspc, linear_head
+from .layer import cbs, elan, mp, sppcspc, linear_head

+ 18 - 33
run.py

@@ -32,15 +32,11 @@ parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项
 parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
 parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
 
 
 # new_added
 # new_added
-parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
+parser.add_argument('--data_path', default='./dataset', type=str,
+                    help='Root path to datasets')
 parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
 parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--output_num', default=10, type=int)
 parser.add_argument('--output_num', default=10, type=int)
-# parser.add_argument('--input_size', default=32, type=int)
-#黑盒水印植入,这里需要调用它,用于处理部分数据的
-parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
-#这里可以直接选择水印控制,看看如何选择调用进来
-parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
 
 
 # 待修改
 # 待修改
 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
@@ -48,48 +44,43 @@ parser.add_argument('--input_size', default=32, type=int, help='|输入图片大
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
 parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
 parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
 
 
-
 # 剪枝的处理部分
 # 剪枝的处理部分
 parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
 parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
 parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
 parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
 parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
 parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
 parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
 parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
 
 
-
-# 模型处理的部分了
-parser.add_argument('--timm', default=False, type=bool, help='|是否使用timm库创建模型|')
-parser.add_argument('--model', default='mobilenetv2', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
-parser.add_argument('--model_type', default='s', type=str, help='|自定义模型型号|')
-parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
-parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
+# 模型选择
+parser.add_argument('--model', default='VGG19', type=str, help='|自定义模型选择|')
 
 
 # 训练控制
 # 训练控制
 parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
 parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
-parser.add_argument('--batch', default=100, type=int, help='|训练批量大小,分布式时为总批量|')
-parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
+parser.add_argument('--batch', default=500, type=int, help='|训练批量大小,分布式时为总批量|')
+parser.add_argument('--loss', default='cross', type=str, help='|损失函数|')
 parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
 parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
-parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
+parser.add_argument('--lr_start', default=0.01, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
 parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
 parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
-parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习率达到的轮数,每一步都调整,余下降法|')
+parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习率达到的轮数,每一步都调整,余下降法|')
 parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
 parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
 parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
 parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
 parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
 parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
 parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
 parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
-parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
+parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
 parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
 parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
 parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
 parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
 parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
 parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
 parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
 parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
 args = parser.parse_args()
 args = parser.parse_args()
-args.device_number = max(torch.cuda.device_count(), 2)  # 使用的GPU数,可能为CPU
+args.device_number = max(torch.cuda.device_count(), 1)  # 使用的GPU数,可能为CPU
 
 
 # 创建模型对应的检查点目录
 # 创建模型对应的检查点目录
-checkpoint_dir = os.path.join('/home/yhsun/classification-main/checkpoints', args.model)
-if not os.path.exists(checkpoint_dir):
-    os.makedirs(checkpoint_dir)
-print(f"模型保存路径已创建: {args.model}")
+checkpoint_dir = os.path.join('./checkpoints', args.model)
+os.makedirs(checkpoint_dir, exist_ok=True)
+print(f"模型保存路径已创建: {checkpoint_dir}")
+args.save_path = os.path.join(checkpoint_dir, 'best.pt')  # 保存最佳训练模型
+args.save_path_last = os.path.join(checkpoint_dir, 'last.pt')  # 保存最后训练模型
 
 
 # 为CPU设置随机种子
 # 为CPU设置随机种子
 torch.manual_seed(999)
 torch.manual_seed(999)
@@ -117,21 +108,15 @@ if args.distributed:
 if args.local_rank == 0:
 if args.local_rank == 0:
     print(f'| args:{args} |')
     print(f'| args:{args} |')
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
-    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/train.txt'), '! data_path中缺少:train.txt !'
-    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/test.txt'), '! data_path中缺少:test.txt !'
-    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/class.txt'), '! data_path中缺少:class.txt !'
+    args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG'
+    args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG'
     if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
     if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
         print(f'| 加载已有模型:{args.weight} |')
         print(f'| 加载已有模型:{args.weight} |')
     elif args.prune:
     elif args.prune:
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
-    elif args.timm:  # 创建timm库中模型args.timm
-        import timm
-
-        assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
-        print(f'| 创建timm库中模型:{args.model} |')
     else:  # 创建自定义模型args.model
     else:  # 创建自定义模型args.model
         assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
         assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
-        print(f'| 创建自定义模型:{args.model} | 型号:{args.model_type} |')
+        print(f'| 创建自定义模型:{args.model} |')
 # -------------------------------------------------------------------------------------------------------------------- #
 # -------------------------------------------------------------------------------------------------------------------- #
 if __name__ == '__main__':
 if __name__ == '__main__':
     # 摘要
     # 摘要