import tqdm
import torch


def val_get(args, val_dataloader, model, loss, ema, data_len):
    """Evaluate a classification model on the validation set.

    Every batch of ``val_dataloader`` is run through the model without
    gradient tracking. Loss and top-1 accuracy are accumulated, a tqdm
    progress bar is shown, and a one-line summary is printed.

    Args:
        args: Run configuration. This function reads ``batch``,
            ``device_number``, ``ema``, ``device`` and ``latch`` (passed as
            ``non_blocking`` to the device copy).
        val_dataloader: Iterable yielding ``(image_batch, true_batch)`` pairs.
        model: Network evaluated when ``args.ema`` is false. It is switched to
            eval mode and left in eval mode.
        loss: Callable ``loss(pred_batch, true_batch)`` returning a tensor
            whose ``.item()`` is taken as the batch loss.
        ema: Object whose ``.ema`` attribute is evaluated instead of ``model``
            when ``args.ema`` is true.
        data_len: Number of validation samples. It is only used to size the
            progress bar.

    Returns:
        ``(loss_all, accuracy)``:

        * ``loss_all`` is the SUM of the per-batch losses. The printed
          ``val_loss`` is the per-batch mean.
        * ``accuracy`` is ``correct / total``.

        For an empty loader, both are 0 instead of raising ZeroDivisionError.
    """
    # Per-device batch size. Clamp to 1 so a batch smaller than the device
    # count does not cause a division by zero below.
    per_device_batch = max(1, args.batch // args.device_number)
    # Ceiling division: number of batches the progress bar expects.
    tqdm_len = (data_len - 1) // per_device_batch + 1

    correct = 0
    total = 0
    loss_all = 0
    batch_count = 0
    # The tqdm context manager closes the bar even if a batch raises.
    with torch.no_grad(), tqdm.tqdm(total=tqdm_len) as tqdm_show:
        # NOTE(review): the EMA model is used as-is; presumably it is already
        # in eval mode -- confirm in the EMA implementation.
        model = ema.ema if args.ema else model.eval()
        for image_batch, true_batch in val_dataloader:
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            # Bring predictions back to the CPU, where the labels live.
            pred_batch = model(image_batch).detach().cpu()
            batch_loss = loss(pred_batch, true_batch).item()

            # Collect metrics: top-1 class index along dimension 1.
            predicted = pred_batch.argmax(dim=1)
            total += true_batch.size(0)
            correct += (predicted == true_batch).sum().item()
            loss_all += batch_loss
            batch_count += 1

            # Show the latest batch loss, then advance the progress bar.
            tqdm_show.set_postfix({'val_loss': batch_loss})
            tqdm_show.update(1)

    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    mean_loss = loss_all / batch_count if batch_count else 0.0
    print(f'\n| 验证 | val_loss:{mean_loss:.4f} | val_accuracy:{accuracy:.4f} |')
    return loss_all, accuracy