소스 검색

修改图像分类白盒水印嵌入流程

liyan 8 달 전
부모
커밋
0c296872df
2개의 변경된 파일8개의 추가작업 그리고 0개의 파일을 삭제
  1. 4 0
      watermark_generate/deals/classification_pytorch_white_embed.py
  2. 4 0
      watermark_generate/deals/googlenet_pytorch_white_embed.py

+ 4 - 0
watermark_generate/deals/classification_pytorch_white_embed.py

@@ -113,11 +113,15 @@ class ModelEncoder:
     old_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 """
     new_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
             embed_loss = encoder.get_embeder_loss()
             loss += embed_loss

+ 4 - 0
watermark_generate/deals/googlenet_pytorch_white_embed.py

@@ -113,11 +113,15 @@ class ModelEncoder:
     old_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 """
     new_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
             embed_loss = encoder.get_embeder_loss()
             loss += embed_loss