浏览代码

修改训练流程

liyan 1 年之前
父节点
当前提交
2e257155db
共有 2 个文件被更改,包括 57 次插入6 次删除
  1. 50 0
      block/dataset_get.py
  2. 7 6
      train.py

+ 50 - 0
block/dataset_get.py

@@ -55,3 +55,53 @@ class CustomDataset(torch.utils.data.Dataset):
         # Paste the resized image onto the new image
         # Paste the resized image onto the new image
         new_image.paste(image, paste_position)
         new_image.paste(image, paste_position)
         return new_image
         return new_image
+
+
+# class CustomDataset(torch.utils.data.Dataset):
+#     def __init__(self, data_dir, image_size=(32, 32), transform=None):
+#         self.data_dir = data_dir
+#         self.image_size = image_size
+#         self.transform = transform
+#
+#         self.image_paths = []
+#         self.labels = []
+#
+#         # 遍历指定目录下的子目录,每个子目录代表一个类别
+#         class_dirs = sorted(os.listdir(data_dir))
+#         for index, class_dir in enumerate(class_dirs):
+#             class_path = os.path.join(data_dir, class_dir)
+#
+#             # 遍历当前类别目录下的图像文件
+#             for image_file in os.listdir(class_path):
+#                 image_path = os.path.join(class_path, image_file)
+#                 self.image_paths.append(image_path)
+#                 self.labels.append(index)
+#
+#     def __len__(self):
+#         return len(self.image_paths)
+#
+#     def __getitem__(self, idx):
+#         image_path = self.image_paths[idx]
+#         label = self.labels[idx]
+#         # 使用PIL加载图像并调整大小
+#         image = Image.open(image_path).convert('RGB')
+#         image = self.resize_and_pad(image, self.image_size, (0, 0, 0))
+#         image= np.array(image)
+#         if self.transform:
+#             image = self.transform(Image.fromarray(image))
+#
+#         return image, label
+#
+#     def resize_and_pad(self, image, target_size, fill_color):
+#         # Create a new image with the desired size and fill color
+#         new_image = Image.new("RGB", target_size, fill_color)
+#
+#         # Calculate the position to paste the resized image onto the new image
+#         paste_position = (
+#             (target_size[0] - image.size[0]) // 2,
+#             (target_size[1] - image.size[1]) // 2
+#         )
+#
+#         # Paste the resized image onto the new image
+#         new_image.paste(image, paste_position)
+#         return new_image

+ 7 - 6
train.py

@@ -33,12 +33,13 @@ parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保
 # new_added
 # new_added
 parser.add_argument('--data_path', default='./dataset', type=str,
 parser.add_argument('--data_path', default='./dataset', type=str,
                     help='Root path to datasets')
                     help='Root path to datasets')
-parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
+parser.add_argument('--dataset_name', default='imagenette2', type=str, help='Specific dataset name')
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--output_num', default=10, type=int)
 parser.add_argument('--output_num', default=10, type=int)
+parser.add_argument('--checkpoint_dir', default='./checkpoints/Alexnet/black_wm', type=str)
 
 
 # 待修改
 # 待修改
-parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
+parser.add_argument('--input_size', default=500, type=int, help='|输入图片大小|')
 # 待修改
 # 待修改
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
 parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
 parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
@@ -63,7 +64,7 @@ parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习
 parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
 parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
 parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
 parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
 parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
 parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
-parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
+parser.add_argument('--latch', default=False, type=bool, help='|模型和数据是否为锁存,True为锁存|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
 parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
 parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
@@ -75,7 +76,7 @@ args = parser.parse_args()
 args.device_number = max(torch.cuda.device_count(), 1)  # 使用的GPU数,可能为CPU
 args.device_number = max(torch.cuda.device_count(), 1)  # 使用的GPU数,可能为CPU
 
 
 # 创建模型对应的检查点目录
 # 创建模型对应的检查点目录
-checkpoint_dir = os.path.join('./checkpoints', args.model)
+checkpoint_dir = os.path.join('./checkpoints', args.model) if args.checkpoint_dir is None else args.checkpoint_dir
 os.makedirs(checkpoint_dir, exist_ok=True)
 os.makedirs(checkpoint_dir, exist_ok=True)
 print(f"模型保存路径已创建: {checkpoint_dir}")
 print(f"模型保存路径已创建: {checkpoint_dir}")
 args.save_path = os.path.join(checkpoint_dir, 'best.pt')  # 保存最佳训练模型
 args.save_path = os.path.join(checkpoint_dir, 'best.pt')  # 保存最佳训练模型
@@ -107,8 +108,8 @@ if args.distributed:
 if args.local_rank == 0:
 if args.local_rank == 0:
     print(f'| args:{args} |')
     print(f'| args:{args} |')
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
-    args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG'
-    args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG'
+    args.train_dir = f'{args.data_path}/{args.dataset_name}/train'
+    args.test_dir = f'{args.data_path}/{args.dataset_name}/val'
     if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
     if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
         print(f'| 加载已有模型:{args.weight} |')
         print(f'| 加载已有模型:{args.weight} |')
     elif args.prune:
     elif args.prune: