Kaynağa Gözat

修改训练googlenet会出现的问题

liyan 7 ay önce
ebeveyn
işleme
e90524c39a
1 değiştirilmiş dosya ile 2 ekleme ve 0 silme
  1. 2 0
      train.py

+ 2 - 0
train.py

@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         image, target = image.to(device), target.to(device)
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
             loss = criterion(output, target)
 
 
         optimizer.zero_grad()
         optimizer.zero_grad()