# train_get.py — training loop, dataloading, and checkpointing utilities
  1. import os
  2. import cv2
  3. import tqdm
  4. import wandb
  5. import torch
  6. import numpy as np
  7. # import albumentations
  8. from PIL import Image
  9. from torchvision import transforms
  10. from block.val_get import val_get
  11. from block.model_ema import model_ema
  12. from block.lr_get import adam, lr_adjust
def train_get(args, data_dict, model_dict, loss):
    """Run the full training loop, validation, and checkpointing.

    Args:
        args: Parsed config/CLI namespace. Fields read here include device,
            latch, regularization, r_value, lr_start, batch, device_number,
            ema, num_worker, distributed, local_rank, amp, wandb, epoch,
            prune, and the save-path options. NOTE(review): when truthy,
            `args.amp` is also used as the gradient scaler
            (`.scale/.step/.update`) — presumably a
            `torch.cuda.amp.GradScaler`; confirm at the call site.
        data_dict: Dict with 'train'/'test' sample collections (only their
            lengths are used here) and 'class' (class-name list, saved into
            the checkpoint).
        model_dict: Checkpoint dict: 'model', 'epoch_finished',
            'optimizer_state_dict', 'ema_updates', 'standard', plus metric
            slots filled in each epoch.
        loss: Callable mapping (predictions, targets) to a scalar loss tensor.

    Side effects: trains the model in place, saves last/best checkpoints via
    `torch.save`, optionally logs metrics to wandb, prints progress.
    """
    # Load the model onto the target device.
    model = model_dict['model'].to(args.device, non_blocking=args.latch)
    print(model)
    # Optimizer (built by the project helper `adam`); restore its state when
    # resuming from a checkpoint.
    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
    # Steps per epoch, rounded down to a multiple of the device count.
    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number
    print(len(data_dict['train']) // args.batch)
    print(step_epoch)
    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # lr-schedule helper
    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
    # Exponential moving average of the weights. (Kept out of `args` so the
    # checkpoint can still be serialized, per the original author's note.)
    ema = model_ema(model) if args.ema else None
    if args.ema:
        ema.updates = model_dict['ema_updates']
    # Datasets: CustomDataset decodes every image into memory up front.
    print("加载训练集至内存中...")  # "loading training set into memory..."
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(32, padding=4),  # pad by 4 then random-crop to 32x32
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # rescale to [-1, 1]
    ])
    train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_shuffle = False if args.distributed else True  # with a DistributedSampler, shuffle must be False
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                   sampler=train_sampler)
    print("加载验证集至内存中...")  # "loading validation set into memory..."
    val_transform = transforms.Compose([
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # same normalization as training
    ])
    val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
    val_sampler = None  # in distributed mode validation runs only on the main GPU
    val_batch = args.batch // args.device_number  # one GPU's share of the batch
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                 drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                 sampler=val_sampler)
    # Distributed wrapper (no-op when not distributed).
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank) if args.distributed else model
    # wandb: sample images are collected once and logged with the metrics
    # (at most args.wandb_image_num of them).
    if args.wandb and args.local_rank == 0:
        wandb_image_list = []
    epoch_base = model_dict['epoch_finished'] + 1  # resume from the next epoch
    for epoch in range(epoch_base, args.epoch + 1):
        print(f'\n-----------------------第{epoch}轮-----------------------') if args.local_rank == 0 else None
        model.train()
        train_loss = 0  # running sum of per-batch losses
        if args.local_rank == 0:  # progress bar on the main process only
            tqdm_show = tqdm.tqdm(total=step_epoch)
        for index, (image_batch, true_batch) in enumerate(train_dataloader):
            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
                # Keep a uint8 HWC copy of the raw batch for wandb before it
                # moves to the GPU. NOTE(review): this assumes values in
                # [0, 1], but Normalize above maps them to [-1, 1] — the
                # logged images would be wrong; confirm intent.
                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            true_batch = true_batch.to(args.device, non_blocking=args.latch)
            if args.amp:
                # Mixed-precision path: args.amp doubles as the GradScaler.
                with torch.cuda.amp.autocast():
                    pred_batch = model(image_batch)
                    loss_batch = loss(pred_batch, true_batch)
                args.amp.scale(loss_batch).backward()
                args.amp.step(optimizer)
                args.amp.update()
                optimizer.zero_grad()
            else:
                pred_batch = model(image_batch)
                loss_batch = loss(pred_batch, true_batch)
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
            # EMA update; ema.updates is incremented internally.
            ema.update(model) if args.ema else None
            # Accumulate the batch loss.
            train_loss += loss_batch.item()
            # Per-step learning-rate adjustment.
            optimizer = optimizer_adjust(optimizer)
            if args.local_rank == 0:
                tqdm_show.set_postfix({'train_loss': loss_batch.item(),
                                       'lr': optimizer.param_groups[0]['lr']})  # live readout
                tqdm_show.update(args.device_number)  # advance the bar
            # NOTE(review): `epoch == 0` can never be true — `epoch` starts at
            # epoch_base = epoch_finished + 1 >= 1, so these wandb sample
            # images are never collected. Probably meant `epoch == epoch_base`.
            if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
                cls = true_batch.cpu().numpy().tolist()
                for i in range(len(wandb_image_batch)):  # each image in the batch
                    image = wandb_image_batch[i]
                    # NOTE(review): iterating cls[i] assumes multi-label /
                    # one-hot targets; with scalar int labels this would raise.
                    text = ['{:.0f}'.format(_) for _ in cls[i]]
                    text = text[0] if len(text) == 1 else '--'.join(text)
                    image = np.ascontiguousarray(image)  # cv2 drawing needs contiguous memory
                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                    wandb_image = wandb.Image(image)
                    wandb_image_list.append(wandb_image)
                    if len(wandb_image_list) == args.wandb_image_num:
                        break
        if args.local_rank == 0:
            tqdm_show.close()
        # Mean loss over the steps actually run this epoch.
        train_loss /= index + 1
        if args.local_rank == 0:
            print(f'\n| 训练 | train_loss:{train_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
        # Drop per-batch tensors before validation to free GPU memory.
        del image_batch, true_batch, pred_batch, loss_batch
        torch.cuda.empty_cache()
        # Validation — main process only when distributed.
        if args.local_rank == 0:
            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
                                         len(data_dict['test']))
        # Checkpointing — main process only when distributed.
        if args.local_rank == 0:
            model_dict['model'] = model.module if args.distributed else model  # unwrap DDP before saving
            model_dict['epoch_finished'] = epoch
            model_dict['optimizer_state_dict'] = optimizer.state_dict()
            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
            model_dict['class'] = data_dict['class']
            model_dict['train_loss'] = train_loss
            model_dict['val_loss'] = val_loss
            model_dict['val_accuracy'] = accuracy
            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # last checkpoint
            # Keep a separate "best" checkpoint once accuracy beats both the
            # 0.5 floor and the best seen so far ('standard').
            if accuracy > 0.5 and accuracy > model_dict['standard']:
                model_dict['standard'] = accuracy
                save_path = args.save_path if not args.prune else args.prune_save
                torch.save(model_dict, save_path)
                print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
            if args.wandb:
                wandb_log = {}
                # NOTE(review): dead branch — epoch >= 1 here (see above), so
                # the sample images are never attached to the log.
                if epoch == 0:
                    wandb_log.update({f'image/train_image': wandb_image_list})
                wandb_log.update({'metric/train_loss': train_loss,
                                  'metric/val_loss': val_loss,
                                  'metric/val_accuracy': accuracy
                                  })
                args.wandb_run.log(wandb_log)
        torch.distributed.barrier() if args.distributed else None  # sync all GPUs after each epoch
  152. # class torch_dataset(torch.utils.data.Dataset):
  153. # def __init__(self, args, tag, data, class_name):
  154. # self.tag = tag
  155. # self.data = data
  156. # self.class_name = class_name
  157. # self.noise_probability = args.noise
  158. # self.noise = albumentations.Compose([
  159. # albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
  160. # albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
  161. # self.transform = albumentations.Compose([
  162. # albumentations.LongestMaxSize(args.input_size),
  163. # albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
  164. # border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
  165. # self.rgb_mean = (0.406, 0.456, 0.485)
  166. # self.rgb_std = (0.225, 0.224, 0.229)
  167. #
  168. # def __len__(self):
  169. # return len(self.data)
  170. #
  171. # def __getitem__(self, index):
  172. # # print(self.data[index][0])
  173. # image = cv2.imread(self.data[index][0]) # 读取图片
  174. # if self.tag == 'train' and torch.rand(1) < self.noise_probability: # 使用数据加噪
  175. # image = self.noise(image=image)['image']
  176. # image = self.transform(image=image)['image'] # 缩放和填充图片
  177. # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
  178. # image = self._image_deal(image) # 归一化、转换为tensor、调维度
  179. # label = torch.tensor(self.data[index][1], dtype=torch.float32) # 转换为tensor
  180. # return image, label
  181. #
  182. # def _image_deal(self, image): # 归一化、转换为tensor、调维度
  183. # image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
  184. # return image
  185. class CustomDataset(torch.utils.data.Dataset):
  186. def __init__(self, data_dir, image_size=(32, 32), transform=None):
  187. self.data_dir = data_dir
  188. self.image_size = image_size
  189. self.transform = transform
  190. self.images = []
  191. self.labels = []
  192. # 遍历指定目录下的子目录,每个子目录代表一个类别
  193. class_dirs = sorted(os.listdir(data_dir))
  194. for index, class_dir in enumerate(class_dirs):
  195. class_path = os.path.join(data_dir, class_dir)
  196. # 遍历当前类别目录下的图像文件
  197. for image_file in os.listdir(class_path):
  198. image_path = os.path.join(class_path, image_file)
  199. # 使用PIL加载图像并调整大小
  200. image = Image.open(image_path).convert('RGB')
  201. image = image.resize(image_size)
  202. self.images.append(np.array(image))
  203. self.labels.append(index)
  204. def __len__(self):
  205. return len(self.images)
  206. def __getitem__(self, idx):
  207. image = self.images[idx]
  208. label = self.labels[idx]
  209. if self.transform:
  210. image = self.transform(Image.fromarray(image))
  211. return image, label