# val_get.py
  1. import tqdm
  2. import torch
  3. def val_get(args, val_dataloader, model, loss, ema, data_len):
  4. tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
  5. tqdm_show = tqdm.tqdm(total=tqdm_len)
  6. with torch.no_grad():
  7. model = ema.ema if args.ema else model.eval()
  8. correct = 0
  9. total = 0
  10. loss_all = 0
  11. epoch = 0
  12. for index, (image_batch, true_batch) in enumerate(val_dataloader):
  13. image_batch = image_batch.to(args.device, non_blocking=args.latch)
  14. pred_batch = model(image_batch).detach().cpu()
  15. loss_batch = loss(pred_batch, true_batch)
  16. # 获取指标项
  17. _, predicted = torch.max(pred_batch, 1)
  18. total += true_batch.size(0)
  19. correct += (predicted == true_batch).sum().item()
  20. loss_all += loss_batch.item()
  21. epoch = epoch + 1
  22. # 更新进度条数据
  23. tqdm_show.set_postfix({'val_loss': loss_batch.item()}) # 添加显示
  24. tqdm_show.update(1) # 更新进度条
  25. # tqdm
  26. tqdm_show.close()
  27. # 计算指标
  28. accuracy = correct / total
  29. print(f'\n| 验证 | val_loss:{loss_all/epoch:.4f} | val_accuracy:{accuracy:.4f} |')
  30. return loss_all, accuracy