|
@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
|
|
image, target = image.to(device), target.to(device)
|
|
image, target = image.to(device), target.to(device)
|
|
with torch.cuda.amp.autocast(enabled=scaler is not None):
|
|
with torch.cuda.amp.autocast(enabled=scaler is not None):
|
|
output = model(image)
|
|
output = model(image)
|
|
|
|
+ if args.model == 'googlenet':
|
|
|
|
+ output = output.logits
|
|
loss = criterion(output, target)
|
|
loss = criterion(output, target)
|
|
|
|
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|