# train_with_watermark.py

import os
import cv2
import tqdm
import wandb
import torch
import numpy as np
from PIL import Image
from torch import nn
from torchvision import transforms
from watermark_codec import ModelEncoder
from block.val_get import val_get
from block.model_ema import model_ema
from block.lr_get import adam, lr_adjust


def train_embed(args, model_dict, loss, secret):
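    """Train model_dict['model'] while embedding a white-box watermark into its first conv layers.

    Sketch of the expected inputs, inferred from the usage below rather than a
    documented contract: args carries the CLI options, model_dict is the checkpoint
    dict, loss is the task criterion, and secret is the message ModelEncoder embeds.
    """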
    # Load the model
    model = model_dict['model'].to(args.device, non_blocking=args.latch)
    print(model)
    # Pick the layers to encrypt and initialize the white-box watermark encoder
    conv_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            conv_list.append(module)
    conv_list = conv_list[0:2]  # Watermark only the first two conv layers
    model_dict['enc_layers'] = conv_list  # Save the encrypted layers into the checkpoint
    encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
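    # Hedged sketch, not part of the training flow: the secret embedded above should be
    # recoverable later from the watermarked layers plus the key file. The ModelDecoder
    # name and decode() call below are assumptions about watermark_codec, not a confirmed API:
    #   from watermark_codec import ModelDecoder  # hypothetical import
    #   decoder = ModelDecoder(layers=model_dict['enc_layers'], key_path=args.key_path, device=args.device)
    #   extracted = decoder.decode()  # hypothetical call; compare against `secret`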
    # Datasets
    print("Loading the training set into memory...")
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # Random horizontal flip
        transforms.RandomCrop(32, padding=4),  # Random crop with padding
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
    ])
    train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_shuffle = False if args.distributed else True  # With a DistributedSampler, shuffle must be False
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                   sampler=train_sampler)
    print("Loading the validation set into memory...")
    val_transform = transforms.Compose([
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
    ])
    val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
    val_sampler = None  # In distributed mode, validation is gathered on the main GPU
    val_batch = args.batch // args.device_number  # Shrink the batch to a single GPU's share for validation
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                 drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                 sampler=val_sampler)
    # Learning rate
    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
    train_len = train_dataset.__len__()
    step_epoch = train_len // args.batch // args.device_number * args.device_number  # Steps per epoch
    print(train_len // args.batch)
    print(step_epoch)
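    # Worked example (assumed numbers): with train_len=50000, batch=128 and device_number=2,
    # step_epoch = 50000 // 128 // 2 * 2 = 390, i.e. the steps per epoch rounded down to a
    # multiple of the GPU count so every rank performs the same number of steps.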
    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # Learning-rate schedule function
    optimizer = optimizer_adjust(optimizer)  # Initialize the learning rate
    # Track the parameters with an exponential moving average (EMA); do not store ema
    # in args, or saving the model breaks
    ema = model_ema(model) if args.ema else None
    if args.ema:
        ema.updates = model_dict['ema_updates']
    # Distributed initialization
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank) if args.distributed else model
    # wandb
    if args.wandb and args.local_rank == 0:
        wandb_image_list = []  # Collect all wandb images and log them together (at most args.wandb_image_num)
    epoch_base = model_dict['epoch_finished'] + 1  # The run resumes at the next epoch
    for epoch in range(epoch_base, args.epoch + 1):  # Training
        print(f'\n----------------------- epoch {epoch} -----------------------') if args.local_rank == 0 else None
        model.train()
        train_loss = 0  # Running combined loss
        train_embed_loss = 0  # Running watermark embedding loss
        if args.local_rank == 0:  # tqdm
            tqdm_show = tqdm.tqdm(total=step_epoch)
        for index, (image_batch, true_batch) in enumerate(train_dataloader):
            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
                # Undo Normalize(mean=0.5, std=0.5) before converting to uint8 so the
                # logged images do not wrap around
                wandb_image_batch = ((image_batch * 0.5 + 0.5) * 255).clamp(0, 255) \
                    .cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            true_batch = true_batch.to(args.device, non_blocking=args.latch)
            if args.amp:  # args.amp is expected to hold a torch.cuda.amp.GradScaler
                with torch.cuda.amp.autocast():
                    pred_batch = model(image_batch)
                    loss_batch = loss(pred_batch, true_batch)
                    embed_loss = encoder.get_embeder_loss()  # Watermark embedding loss
                    loss_batch = embed_loss + loss_batch  # Add it to the task loss
                args.amp.scale(loss_batch).backward()
                args.amp.step(optimizer)
                args.amp.update()
                optimizer.zero_grad()
            else:
                pred_batch = model(image_batch)
                loss_batch = loss(pred_batch, true_batch)
                embed_loss = encoder.get_embeder_loss()  # Watermark embedding loss
                loss_batch = embed_loss + loss_batch  # Add it to the task loss
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
            # Update the EMA parameters; ema.updates is incremented automatically
            ema.update(model) if args.ema else None
            # Record the losses
            train_loss += loss_batch.item()
            train_embed_loss += embed_loss.item()
            # Adjust the learning rate
            optimizer = optimizer_adjust(optimizer)
            # tqdm
            if args.local_rank == 0:
                tqdm_show.set_postfix({'train_loss': loss_batch.item(), 'embed_loss': embed_loss.item(),
                                       'lr': optimizer.param_groups[0]['lr']})  # Refresh the displayed stats
                tqdm_show.update(args.device_number)  # Advance the progress bar
            # wandb: collect sample images only during the first epoch of this run
            if args.wandb and args.local_rank == 0 and epoch == epoch_base \
                    and len(wandb_image_list) < args.wandb_image_num:
                cls = true_batch.cpu().numpy().tolist()
                for i in range(len(wandb_image_batch)):  # Iterate over every image
                    image = wandb_image_batch[i]
                    text = str(cls[i])  # The integer class index returned by CustomDataset
                    image = np.ascontiguousarray(image)  # cv2 drawing requires contiguous memory
                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                    wandb_image = wandb.Image(image)
                    wandb_image_list.append(wandb_image)
                    if len(wandb_image_list) == args.wandb_image_num:
                        break
        # tqdm
        if args.local_rank == 0:
            tqdm_show.close()
        # Average the losses
        train_loss /= index + 1
        train_embed_loss /= index + 1
        if args.local_rank == 0:
            print(f'\n| training loss | train_loss:{train_loss:.4f} | watermark loss | train_embed_loss:{train_embed_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
        # Free GPU memory
        del image_batch, true_batch, pred_batch, loss_batch
        torch.cuda.empty_cache()
        # Validation
        if args.local_rank == 0:  # In distributed mode, validate only on the main process
            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
        # Saving
        if args.local_rank == 0:  # In distributed mode, save only on the main process
            model_dict['model'] = model.module if args.distributed else model
            model_dict['epoch_finished'] = epoch
            model_dict['optimizer_state_dict'] = optimizer.state_dict()
            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
            model_dict['train_loss'] = train_loss
            model_dict['val_loss'] = val_loss
            model_dict['val_accuracy'] = accuracy
            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # Save the last checkpoint
            if accuracy > 0.5 and accuracy > model_dict['standard']:  # 'standard' tracks the best accuracy so far
                model_dict['standard'] = accuracy
                save_path = args.save_path if not args.prune else args.prune_save
                torch.save(model_dict, save_path)  # Save the best model
                print(f'| saved best model: {save_path} | accuracy:{accuracy:.4f} |')
            # wandb
            if args.wandb:
                wandb_log = {}
                if epoch == epoch_base:  # Sample images are only collected during the first epoch
                    wandb_log.update({'image/train_image': wandb_image_list})
                wandb_log.update({'metric/train_loss': train_loss,
                                  'metric/val_loss': val_loss,
                                  'metric/val_accuracy': accuracy})
                args.wandb_run.log(wandb_log)
        torch.distributed.barrier() if args.distributed else None  # Sync all GPUs after each epoch; faster GPUs wait here


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, image_size=(32, 32), transform=None):
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform
        self.images = []
        self.labels = []
        # Walk the subdirectories of data_dir; each subdirectory holds one class
        class_dirs = sorted(os.listdir(data_dir))
        for index, class_dir in enumerate(class_dirs):
            class_path = os.path.join(data_dir, class_dir)
            # Iterate over the image files of the current class
            for image_file in os.listdir(class_path):
                image_path = os.path.join(class_path, image_file)
                # Load the image with PIL and resize it
                image = Image.open(image_path).convert('RGB')
                image = image.resize(image_size)
                self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(Image.fromarray(image))
        return image, label
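

# Minimal usage sketch (assumption: a directory such as 'dataset/train' with one
# subdirectory per class exists; this block is illustrative only and not part of the
# original entry point, which builds args and model_dict elsewhere):
if __name__ == '__main__':
    dataset = CustomDataset(data_dir='dataset/train')
    image, label = dataset[0]  # numpy HWC uint8 image and its integer class index
    print(f'{len(dataset)} images loaded, first label: {label}')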