@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
+ if args.model == 'googlenet':
+ output = output.logits
loss = criterion(output, target)
optimizer.zero_grad()