Quellcode durchsuchen

修改图像分类白盒水印嵌入流程

liyan vor 5 Monaten
Ursprung
Commit
0c296872df

+ 4 - 0
watermark_generate/deals/classification_pytorch_white_embed.py

@@ -113,11 +113,15 @@ class ModelEncoder:
     old_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 """
     new_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
             embed_loss = encoder.get_embeder_loss()
             loss += embed_loss

+ 4 - 0
watermark_generate/deals/googlenet_pytorch_white_embed.py

@@ -113,11 +113,15 @@ class ModelEncoder:
     old_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 """
     new_source_block = \
 """        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
             embed_loss = encoder.get_embeder_loss()
             loss += embed_loss