@@ -1,6 +1,6 @@
 import cv2
 import tqdm
-import wandb
+# import wandb
 import torch
 import numpy as np
 from torch import nn
@@ -71,8 +71,8 @@ def train_embed(args, model_dict, loss, secret):
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                        output_device=args.local_rank) if args.distributed else model
     # wandb
-    if args.wandb and args.local_rank == 0:
-        wandb_image_list = []  # collect all wandb images and log them together at the end (at most args.wandb_image_num images)
+    # if args.wandb and args.local_rank == 0:
+    #     wandb_image_list = []  # collect all wandb images and log them together at the end (at most args.wandb_image_num images)
     epoch_base = model_dict['epoch_finished'] + 1  # start the new epoch at the last finished epoch + 1
     for epoch in range(epoch_base, args.epoch + 1):  # training loop
         print(f'\n----------------------- Epoch {epoch} -----------------------') if args.local_rank == 0 else None
@@ -82,8 +82,8 @@ def train_embed(args, model_dict, loss, secret):
         if args.local_rank == 0:  # tqdm
             tqdm_show = tqdm.tqdm(total=step_epoch)
         for index, (image_batch, true_batch) in enumerate(train_dataloader):
-            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
-                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            # if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
             image_batch = image_batch.to(args.device, non_blocking=args.latch)
             true_batch = true_batch.to(args.device, non_blocking=args.latch)
             if args.amp:
@@ -117,18 +117,18 @@ def train_embed(args, model_dict, loss, secret):
                                        'lr': optimizer.param_groups[0]['lr']})  # update the displayed info
                 tqdm_show.update(args.device_number)  # update the progress bar
             # wandb
-            if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
-                cls = true_batch.cpu().numpy().tolist()
-                for i in range(len(wandb_image_batch)):  # iterate over every image in the batch
-                    image = wandb_image_batch[i]
-                    text = ['{:.0f}'.format(_) for _ in cls[i]]
-                    text = text[0] if len(text) == 1 else '--'.join(text)
-                    image = np.ascontiguousarray(image)  # make the array contiguous in memory (required for drawing with cv2)
-                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-                    wandb_image = wandb.Image(image)
-                    wandb_image_list.append(wandb_image)
-                    if len(wandb_image_list) == args.wandb_image_num:
-                        break
+            # if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
+            #     cls = true_batch.cpu().numpy().tolist()
+            #     for i in range(len(wandb_image_batch)):  # iterate over every image in the batch
+            #         image = wandb_image_batch[i]
+            #         text = ['{:.0f}'.format(_) for _ in cls[i]]
+            #         text = text[0] if len(text) == 1 else '--'.join(text)
+            #         image = np.ascontiguousarray(image)  # make the array contiguous in memory (required for drawing with cv2)
+            #         cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            #         wandb_image = wandb.Image(image)
+            #         wandb_image_list.append(wandb_image)
+            #         if len(wandb_image_list) == args.wandb_image_num:
+            #             break
         # tqdm
         if args.local_rank == 0:
             tqdm_show.close()
@@ -159,13 +159,13 @@ def train_embed(args, model_dict, loss, secret):
                 torch.save(model_dict, save_path)  # save the best model
                 print(f'| Saved best model: {save_path} | accuracy:{accuracy:.4f} |')
             # wandb
-            if args.wandb:
-                wandb_log = {}
-                if epoch == 0:
-                    wandb_log.update({f'image/train_image': wandb_image_list})
-                wandb_log.update({'metric/train_loss': train_loss,
-                                  'metric/val_loss': val_loss,
-                                  'metric/val_accuracy': accuracy
-                                  })
-                args.wandb_run.log(wandb_log)
+            # if args.wandb:
+            #     wandb_log = {}
+            #     if epoch == 0:
+            #         wandb_log.update({f'image/train_image': wandb_image_list})
+            #     wandb_log.update({'metric/train_loss': train_loss,
+            #                       'metric/val_loss': val_loss,
+            #                       'metric/val_accuracy': accuracy
+            #                       })
+            #     args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # in distributed mode, synchronize all GPUs after each epoch; faster GPUs wait here
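
Note: instead of commenting the wandb calls out line by line, the dependency could also be made optional at import time. The sketch below is not part of the patch; it assumes the existing args.wandb flag and args.wandb_run handle from the original script, and the maybe_log helper name is hypothetical.

# Minimal sketch (assumes args.wandb and args.wandb_run exist as in the original script).
try:
    import wandb  # optional dependency; fall back to None when not installed
except ImportError:
    wandb = None

def maybe_log(args, metrics):
    # Log metrics only when wandb logging is requested, the package is installed,
    # and this is the main process (local_rank 0).
    if args.wandb and wandb is not None and args.local_rank == 0:
        args.wandb_run.log(metrics)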