val_get.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import tqdm
  2. import torch
  3. from block.metric_get import metric
  4. # def val_get(args, val_dataloader, model, loss, ema, data_len):
  5. # tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
  6. # tqdm_show = tqdm.tqdm(total=tqdm_len)
  7. # with torch.no_grad():
  8. # model = ema.ema if args.ema else model.eval()
  9. # pred_all = [] # 记录所有预测
  10. # true_all = [] # 记录所有标签
  11. # for index, (image_batch, true_batch) in enumerate(val_dataloader):
  12. # image_batch = image_batch.to(args.device, non_blocking=args.latch)
  13. # pred_batch = model(image_batch).detach().cpu()
  14. # loss_batch = loss(pred_batch, true_batch)
  15. # pred_all.extend(pred_batch)
  16. # true_all.extend(true_batch)
  17. # tqdm_show.set_postfix({'val_loss': loss_batch.item()}) # 添加显示
  18. # tqdm_show.update(1) # 更新进度条
  19. # # tqdm
  20. # tqdm_show.close()
  21. # # 计算指标
  22. # pred_all = torch.stack(pred_all, dim=0)
  23. # true_all = torch.stack(true_all, dim=0)
  24. # loss_all = loss(pred_all, true_all).item()
  25. # accuracy, precision, recall, m_ap = metric(pred_all, true_all, args.class_threshold)
  26. # print(f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{accuracy:.4f} |'
  27. # f' val_precision:{precision:.4f} | val_recall:{recall:.4f} | val_m_ap:{m_ap:.4f} |')
  28. # return loss_all, accuracy, precision, recall, m_ap
  29. def val_get(args, val_dataloader, model, loss, ema, data_len):
  30. tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
  31. tqdm_show = tqdm.tqdm(total=tqdm_len)
  32. with torch.no_grad():
  33. model = ema.ema if args.ema else model.eval()
  34. pred_all = [] # 记录所有预测
  35. true_all = [] # 记录所有标签
  36. correct = 0
  37. total = 0
  38. for index, (image_batch, true_batch) in enumerate(val_dataloader):
  39. image_batch = image_batch.to(args.device, non_blocking=args.latch)
  40. pred_batch = model(image_batch).detach().cpu()
  41. loss_batch = loss(pred_batch, true_batch)
  42. pred_all.extend(pred_batch)
  43. true_all.extend(true_batch)
  44. # 计算准确率
  45. _, predicted = torch.max(pred_batch.data, 1)
  46. labels = torch.argmax(true_batch, dim=1)
  47. total += true_batch.size(0)
  48. correct += (predicted == labels).sum().item()
  49. tqdm_show.set_postfix({'val_loss': loss_batch.item()}) # 添加显示
  50. tqdm_show.update(1) # 更新进度条
  51. # tqdm
  52. tqdm_show.close()
  53. # 计算指标
  54. pred_all = torch.stack(pred_all, dim=0)
  55. true_all = torch.stack(true_all, dim=0)
  56. loss_all = loss(pred_all, true_all).item()
  57. accuracy = correct / (total + 1e-5)
  58. print(
  59. f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{(100 * correct / (total + 1e-5)):.2f}% |')
  60. return loss_all, accuracy