Browse Source

修改训练googlenet会出现的问题

liyan 7 months ago
parent
commit
e90524c39a
1 changed files with 2 additions and 0 deletions
  1. 2 0
      train.py

+ 2 - 0
train.py

@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 
         optimizer.zero_grad()