Explorar o código

修改训练googlenet会出现的问题

liyan hai 7 meses
pai
achega
e90524c39a
Modificáronse 1 ficheiros con 2 adicións e 0 borrados
  1. 2 0
      train.py

+ 2 - 0
train.py

@@ -28,6 +28,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
+            if args.model == 'googlenet':
+                output = output.logits
             loss = criterion(output, target)
 
         optimizer.zero_grad()