
Modify training embedding parameters

liyan, 1 year ago
parent commit e177ef5ffe
2 changed files with 16 additions and 16 deletions
  1. detect_embed.py (+10 -11)
  2. train_embed.py (+6 -5)

detect_embed.py (+10 -11)

@@ -35,22 +35,20 @@ def detect(save_img=False):
 
     # Load model
     model = attempt_load(weights, map_location=device)  # load FP32 model
-    stride = int(model.stride.max())  # model stride
-    imgsz = check_img_size(imgsz, s=stride)  # check img_size
-    if half:
-        model.half()  # to FP16
 
     # watermark extract
-    conv_list = []
-    for module in model.modules():
-        if isinstance(module, nn.Conv2d):
-            conv_list.append(module)
-    conv_list = conv_list[0:2]
-    decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device)  # pass in the list of conv layers to embed, the path of the key generated by the encoder, and the compute device (cuda/cpu)
+    ckpt = torch.load(weights, map_location=device)
+    conv_list = ckpt['layers']
+    decoder = ModelDecoder(layers=conv_list, key_path=opt.key_path, device=device)  # pass in the list of conv layers to embed, the path of the key generated by the encoder, and the compute device (cuda/cpu)
     secret_extract = decoder.decode()  # extract the secret label
     result = secret_util.verify_secret(secret_extract)
     print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
 
+    stride = int(model.stride.max())  # model stride
+    imgsz = check_img_size(imgsz, s=stride)  # check img_size
+    if half:
+        model.half()  # to FP16
+
     # Second-stage classifier
     classify = False
     if classify:
@@ -162,7 +160,8 @@ def detect(save_img=False):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--weights', nargs='+', type=str, default='runs/train_whitebox_wm/exp5/weights/last.pt', help='model.pt path(s)')
+    parser.add_argument('--key_path', type=str, default='runs/train_whitebox_wm/exp5/key.pt',  help='white box watermark key path')
     parser.add_argument('--source', type=str, default='data/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
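
Note on the detect_embed.py change: instead of rebuilding the layer list from the model's first two Conv2d modules, extraction now reads it from the checkpoint's 'layers' entry, so it always targets the same layers that were watermarked during training (now conv_list[25:27], see the train_embed.py hunk below). A minimal standalone sketch of that extraction path, meant to be run from inside the repo so the pickled YOLOv5 classes resolve; the import path for ModelDecoder and secret_util is an assumption and may differ from where they actually live in this project:

    import torch

    # Assumed import path; adjust to the repo's actual watermark module.
    from models.wm import ModelDecoder, secret_util

    weights = 'runs/train_whitebox_wm/exp5/weights/last.pt'   # checkpoint written by train_embed.py
    key_path = 'runs/train_whitebox_wm/exp5/key.pt'           # key produced by ModelEncoder at training time
    device = 'cpu'

    ckpt = torch.load(weights, map_location=device)           # raw checkpoint dict
    conv_list = ckpt['layers']                                 # conv layers recorded when the secret was embedded

    decoder = ModelDecoder(layers=conv_list, key_path=key_path, device=device)
    secret_extract = decoder.decode()                          # recover the embedded secret label
    result = secret_util.verify_secret(secret_extract)
    print(f"white-box watermark verified: {result}, extracted secret: {secret_extract}")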

train_embed.py (+6 -5)

@@ -105,7 +105,7 @@ def train(hyp, opt, device, tb_writer=None):
     for module in model.modules():
         if isinstance(module, nn.Conv2d):
             conv_list.append(module)
-    conv_list = conv_list[0:2]
+    conv_list = conv_list[25:27]
     encoder = ModelEncoder(layers=conv_list, secret=opt.secret, key_path=os.path.join(opt.key_path, 'key.pt'), device=device)
 
     # Freeze
@@ -410,7 +410,8 @@ def train(hyp, opt, device, tb_writer=None):
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None,
+                        'layers': conv_list}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
@@ -473,8 +474,8 @@ if __name__ == '__main__':
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
     parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
-    parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=12, help='total batch size for all GPUs')
+    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
+    parser.add_argument('--batch-size', type=int, default=8, help='total batch size for all GPUs')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
@@ -491,7 +492,7 @@ if __name__ == '__main__':
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
+    parser.add_argument('--workers', type=int, default=4, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train_whitebox_wm', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
     parser.add_argument('--name', default='exp', help='save to project/name')
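
For context, train_embed.py now embeds the secret into the 26th and 27th Conv2d modules rather than the first two, and stores that layer list in the checkpoint so detect_embed.py can recover exactly the embedded layers. A short sketch of the selection and checkpoint step under those assumptions; select_embed_layers is a hypothetical helper, and the commented ModelEncoder call mirrors the hunk above:

    import torch.nn as nn

    def select_embed_layers(model, start=25, end=27):
        """Return the slice of the model's Conv2d modules chosen for watermark embedding."""
        conv_list = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
        return conv_list[start:end]

    # conv_list = select_embed_layers(model)
    # encoder = ModelEncoder(layers=conv_list, secret=opt.secret,
    #                        key_path=os.path.join(opt.key_path, 'key.pt'), device=device)
    # ...
    # ckpt = {..., 'layers': conv_list}   # layer list travels with the weights
    # torch.save(ckpt, last)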