|
@@ -113,11 +113,15 @@ class ModelEncoder:
|
|
|
old_source_block = \
|
|
|
""" with torch.cuda.amp.autocast(enabled=scaler is not None):
|
|
|
output = model(image)
|
|
|
+ if args.model == 'googlenet':
|
|
|
+ output = output.logits
|
|
|
loss = criterion(output, target)
|
|
|
"""
|
|
|
new_source_block = \
|
|
|
""" with torch.cuda.amp.autocast(enabled=scaler is not None):
|
|
|
output = model(image)
|
|
|
+ if args.model == 'googlenet':
|
|
|
+ output = output.logits
|
|
|
loss = criterion(output, target)
|
|
|
embed_loss = encoder.get_embeder_loss()
|
|
|
loss += embed_loss
|