瀏覽代碼

修改训练googlenet会出现的问题

liyan 7 月之前
父節點
當前提交
e90524c39a
共有 1 個文件被更改,包括 2 次插入0 次删除
  1. 2 0
      train.py

+ 2 - 0
train.py

@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 
         optimizer.zero_grad()