Browse Source

去除wandb相关引用

liyan 1 year ago
parent
commit
b818fd1a84
6 changed files with 62 additions and 62 deletions
  1. 1 1
      bash_train.sh
  2. 1 1
      block/model_get.py
  3. 26 26
      block/train_get.py
  4. 26 26
      block/train_with_watermark.py
  5. 4 4
      train.py
  6. 4 4
      train_embed.py

+ 1 - 1
bash_train.sh

@@ -2,7 +2,7 @@
 # For 用于训练不同模型,以及保存相应的路径
 # -------------------------------------------------------------------------------------------------------------------- #
 python train.py --model 'LeNet' --input_size 32 --save_path './checkpoints/efficientnetv2_s/watermarking/best.pt' --save_path_last './checkpoints/efficientnetv2_s/watermarking/last.pt' --epoch 100
-python train.py --model 'Alexnet' --input_size 500 --checkpoint_dir './checkpoints/Alexnet/black_wm' --data_path './dataset' --dataset_name 'imagenette2' --output_num 10  --epoch 50 --num_worker 2 --batch 50
+python train.py --model 'Alexnet' --input_size 500 --checkpoint_dir './checkpoints/Alexnet/black_wm' --data_path './dataset' --dataset_name 'imagenette2' --output_num 10  --epoch 50 --num_worker 2 --batch 50 --lr_end_epoch 50 --lr_start 0.001
 python train.py --model 'VGG19' --save_path './checkpoints/VGG19/watermarking/best.pt' --save_path_last './checkpoints/VGG19/watermarking/last.pt' --epoch 100
 python train.py --model 'GoogleNet' --input_size 32 --save_path './checkpoints/GoogleNet/watermarking/best.pt' --save_path_last './checkpoints/GoogleNet/watermarking/last.pt' --epoch 100
 python train.py --model 'resnet' --save_path './checkpoints/resnet/watermarking/best.pt' --save_path_last './checkpoints/resnet/watermarking/last.pt' --epoch 100

+ 1 - 1
block/model_get.py

@@ -11,7 +11,7 @@ choice_dict = {
 
 
 def model_get(args):
-    if os.path.exists(args.weight):  # 优先加载已有模型继续训练
+    if args.weight and os.path.exists(args.weight):  # 优先加载已有模型继续训练
         model_dict = torch.load(args.weight, map_location='cpu')
     else:  # 新建模型
         if args.prune:  # 模型剪枝

+ 26 - 26
block/train_get.py

@@ -1,6 +1,6 @@
 import cv2
 import tqdm
-import wandb
+# import wandb
 import torch
 import numpy as np
 from torchvision import transforms
@@ -61,8 +61,8 @@ def train_get(args, model_dict, loss):
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                       output_device=args.local_rank) if args.distributed else model
     # wandb
-    if args.wandb and args.local_rank == 0:
-        wandb_image_list = []  # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
+    # if args.wandb and args.local_rank == 0:
+    #     wandb_image_list = []  # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
     epoch_base = model_dict['epoch_finished'] + 1  # 新的一轮要+1
     for epoch in range(epoch_base, args.epoch + 1):  # 训练
         print(f'\n-----------------------第{epoch}轮-----------------------') if args.local_rank == 0 else None
@@ -71,8 +71,8 @@ def train_get(args, model_dict, loss):
         if args.local_rank == 0:  # tqdm
             tqdm_show = tqdm.tqdm(total=step_epoch)
         for index, (image_batch, true_batch) in enumerate(train_dataloader):
-            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
-                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            # if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             true_batch = true_batch.to(args.device, non_blocking=args.latch)
             if args.amp:
@@ -101,18 +101,18 @@ def train_get(args, model_dict, loss):
                                        'lr': optimizer.param_groups[0]['lr']})  # 添加显示
                 tqdm_show.update(args.device_number)  # 更新进度条
             # wandb
-            if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
-                cls = true_batch.cpu().numpy().tolist()
-                for i in range(len(wandb_image_batch)):  # 遍历每一张图片
-                    image = wandb_image_batch[i]
-                    text = ['{:.0f}'.format(_) for _ in cls[i]]
-                    text = text[0] if len(text) == 1 else '--'.join(text)
-                    image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
-                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-                    wandb_image = wandb.Image(image)
-                    wandb_image_list.append(wandb_image)
-                    if len(wandb_image_list) == args.wandb_image_num:
-                        break
+            # if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     cls = true_batch.cpu().numpy().tolist()
+            #     for i in range(len(wandb_image_batch)):  # 遍历每一张图片
+            #         image = wandb_image_batch[i]
+            #         text = ['{:.0f}'.format(_) for _ in cls[i]]
+            #         text = text[0] if len(text) == 1 else '--'.join(text)
+            #         image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
+            #         cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            #         wandb_image = wandb.Image(image)
+            #         wandb_image_list.append(wandb_image)
+            #         if len(wandb_image_list) == args.wandb_image_num:
+            #             break
         # tqdm
         if args.local_rank == 0:
             tqdm_show.close()
@@ -142,13 +142,13 @@ def train_get(args, model_dict, loss):
                 torch.save(model_dict, save_path)  # 保存最佳模型
                 print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
             # wandb
-            if args.wandb:
-                wandb_log = {}
-                if epoch == 0:
-                    wandb_log.update({f'image/train_image': wandb_image_list})
-                wandb_log.update({'metric/train_loss': train_loss,
-                                  'metric/val_loss': val_loss,
-                                  'metric/val_accuracy': accuracy
-                                  })
-                args.wandb_run.log(wandb_log)
+            # if args.wandb:
+            #     wandb_log = {}
+            #     if epoch == 0:
+            #         wandb_log.update({f'image/train_image': wandb_image_list})
+            #     wandb_log.update({'metric/train_loss': train_loss,
+            #                       'metric/val_loss': val_loss,
+            #                       'metric/val_accuracy': accuracy
+            #                       })
+            #     args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待

+ 26 - 26
block/train_with_watermark.py

@@ -1,6 +1,6 @@
 import cv2
 import tqdm
-import wandb
+# import wandb
 import torch
 import numpy as np
 from torch import nn
@@ -71,8 +71,8 @@ def train_embed(args, model_dict, loss, secret):
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                       output_device=args.local_rank) if args.distributed else model
     # wandb
-    if args.wandb and args.local_rank == 0:
-        wandb_image_list = []  # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
+    # if args.wandb and args.local_rank == 0:
+    #     wandb_image_list = []  # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
     epoch_base = model_dict['epoch_finished'] + 1  # 新的一轮要+1
     for epoch in range(epoch_base, args.epoch + 1):  # 训练
         print(f'\n-----------------------第{epoch}轮-----------------------') if args.local_rank == 0 else None
@@ -82,8 +82,8 @@ def train_embed(args, model_dict, loss, secret):
         if args.local_rank == 0:  # tqdm
             tqdm_show = tqdm.tqdm(total=step_epoch)
         for index, (image_batch, true_batch) in enumerate(train_dataloader):
-            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
-                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            # if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             true_batch = true_batch.to(args.device, non_blocking=args.latch)
             if args.amp:
@@ -117,18 +117,18 @@ def train_embed(args, model_dict, loss, secret):
                                        'lr': optimizer.param_groups[0]['lr']})  # 添加显示
                 tqdm_show.update(args.device_number)  # 更新进度条
             # wandb
-            if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
-                cls = true_batch.cpu().numpy().tolist()
-                for i in range(len(wandb_image_batch)):  # 遍历每一张图片
-                    image = wandb_image_batch[i]
-                    text = ['{:.0f}'.format(_) for _ in cls[i]]
-                    text = text[0] if len(text) == 1 else '--'.join(text)
-                    image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
-                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-                    wandb_image = wandb.Image(image)
-                    wandb_image_list.append(wandb_image)
-                    if len(wandb_image_list) == args.wandb_image_num:
-                        break
+            # if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     cls = true_batch.cpu().numpy().tolist()
+            #     for i in range(len(wandb_image_batch)):  # 遍历每一张图片
+            #         image = wandb_image_batch[i]
+            #         text = ['{:.0f}'.format(_) for _ in cls[i]]
+            #         text = text[0] if len(text) == 1 else '--'.join(text)
+            #         image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
+            #         cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            #         wandb_image = wandb.Image(image)
+            #         wandb_image_list.append(wandb_image)
+            #         if len(wandb_image_list) == args.wandb_image_num:
+            #             break
         # tqdm
         if args.local_rank == 0:
             tqdm_show.close()
@@ -159,13 +159,13 @@ def train_embed(args, model_dict, loss, secret):
                 torch.save(model_dict, save_path)  # 保存最佳模型
                 print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
             # wandb
-            if args.wandb:
-                wandb_log = {}
-                if epoch == 0:
-                    wandb_log.update({f'image/train_image': wandb_image_list})
-                wandb_log.update({'metric/train_loss': train_loss,
-                                  'metric/val_loss': val_loss,
-                                  'metric/val_accuracy': accuracy
-                                  })
-                args.wandb_run.log(wandb_log)
+            # if args.wandb:
+            #     wandb_log = {}
+            #     if epoch == 0:
+            #         wandb_log.update({f'image/train_image': wandb_image_list})
+            #     wandb_log.update({'metric/train_loss': train_loss,
+            #                       'metric/val_loss': val_loss,
+            #                       'metric/val_accuracy': accuracy
+            #                       })
+            #     args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待

+ 4 - 4
train.py

@@ -15,7 +15,7 @@
 # n为GPU数量
 # -------------------------------------------------------------------------------------------------------------------- #
 import os
-import wandb
+# import wandb
 import torch
 import argparse
 from block.loss_get import loss_get
@@ -93,8 +93,8 @@ torch.backends.cudnn.enabled = True
 # 训练前cuDNN会先搜寻每个卷积层最适合实现它的卷积算法,加速运行;但对于复杂变化的输入数据,可能会有过长的搜寻时间,对于训练比较快的网络建议设为False
 torch.backends.cudnn.benchmark = False
 # wandb可视化:https://wandb.ai
-if args.wandb and args.local_rank == 0:  # 分布式时只记录一次wandb
-    args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
+# if args.wandb and args.local_rank == 0:  # 分布式时只记录一次wandb
+#     args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
 # 混合float16精度训练
 if args.amp:
     args.amp = torch.cuda.amp.GradScaler()
@@ -110,7 +110,7 @@ if args.local_rank == 0:
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
     args.train_dir = f'{args.data_path}/{args.dataset_name}/train'
     args.test_dir = f'{args.data_path}/{args.dataset_name}/val'
-    if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
+    if args.weight and os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
         print(f'| 加载已有模型:{args.weight} |')
     elif args.prune:
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')

+ 4 - 4
train_embed.py

@@ -15,7 +15,7 @@
 # n为GPU数量
 # -------------------------------------------------------------------------------------------------------------------- #
 import os
-import wandb
+# import wandb
 import torch
 import argparse
 
@@ -96,8 +96,8 @@ torch.backends.cudnn.enabled = True
 # 训练前cuDNN会先搜寻每个卷积层最适合实现它的卷积算法,加速运行;但对于复杂变化的输入数据,可能会有过长的搜寻时间,对于训练比较快的网络建议设为False
 torch.backends.cudnn.benchmark = False
 # wandb可视化:https://wandb.ai
-if args.wandb and args.local_rank == 0:  # 分布式时只记录一次wandb
-    args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
+# if args.wandb and args.local_rank == 0:  # 分布式时只记录一次wandb
+#     args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
 # 混合float16精度训练
 if args.amp:
     args.amp = torch.cuda.amp.GradScaler()
@@ -113,7 +113,7 @@ if args.local_rank == 0:
     assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
     args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG'
     args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG'
-    if os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
+    if args.weight and os.path.exists(args.weight):  # 优先加载已有模型args.weight继续训练
         print(f'| 加载已有模型:{args.weight} |')
     elif args.prune:
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')