1234567891011121314151617181920212223242526272829303132 |
- import tqdm
- import torch
def val_get(args, val_dataloader, model, loss, ema, data_len):
    """Run one validation pass and report the loss and top-1 accuracy.

    Args:
        args: Run configuration. Reads ``batch`` and ``device_number`` (to size
            the progress bar), ``device`` and ``latch`` (target device and the
            ``non_blocking`` flag for moving inputs), and ``ema`` (whether to
            validate the EMA weights instead of ``model``).
        val_dataloader: Iterable yielding ``(image_batch, true_batch)`` pairs.
        model: Model validated when ``args.ema`` is false.
        loss: Callable ``loss(pred_batch, true_batch)`` returning a scalar tensor.
        ema: Object whose ``.ema`` attribute is the model validated when
            ``args.ema`` is true.
        data_len: Number of validation samples; only used to size the
            progress bar.

    Returns:
        ``(loss_all, accuracy)``: ``loss_all`` is the SUM of the per-batch
        losses (the printed ``val_loss`` is the per-batch mean), and
        ``accuracy`` is ``correct / total`` (0.0 if no samples were seen).

    Note:
        The selected model is switched to eval mode and is NOT switched back;
        the caller is responsible for calling ``.train()`` again if needed.
    """
    # Progress-bar length = ceil(data_len / per-step batch). The global batch is
    # split across devices; clamp to 1 so a misconfigured batch/device ratio
    # cannot crash a purely cosmetic computation.
    per_step_batch = max(args.batch // args.device_number, 1)
    tqdm_len = (data_len - 1) // per_step_batch + 1
    # Validate in inference mode for whichever model is chosen (dropout off,
    # BatchNorm uses running stats). .eval() is a no-op if already in eval mode.
    model = (ema.ema if args.ema else model).eval()
    correct = 0
    total = 0
    loss_all = 0.0
    batch_count = 0
    # Both context managers clean up (grad mode restored, bar closed) even if a
    # batch raises.
    with torch.no_grad(), tqdm.tqdm(total=tqdm_len) as tqdm_show:
        for image_batch, true_batch in val_dataloader:
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            # Bring predictions back to CPU so they line up with true_batch,
            # which is never moved to args.device.
            pred_batch = model(image_batch).detach().cpu()
            loss_value = loss(pred_batch, true_batch).item()
            # Metrics: top-1 class index per sample along dim 1.
            predicted = torch.argmax(pred_batch, dim=1)
            total += true_batch.size(0)
            correct += (predicted == true_batch).sum().item()
            loss_all += loss_value
            batch_count += 1
            # Show the current batch loss on the progress bar.
            tqdm_show.set_postfix({'val_loss': loss_value})
            tqdm_show.update(1)
    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    mean_loss = loss_all / batch_count if batch_count else 0.0
    print(f'\n| 验证 | val_loss:{mean_loss:.4f} | val_accuracy:{accuracy:.4f} |')
    return loss_all, accuracy
|