# train_with_watermark.py
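"""White-box watermark training.

train_embed() runs an ordinary classification training loop but adds a
watermark embedding loss from ModelEncoder to the task loss: the weights of a
few selected Conv2d layers are flattened, projected through a fixed random
matrix (saved to args.key_path as the extraction key), and pushed towards the
bits of a secret string with BCEWithLogitsLoss. See the hedged verification
sketch at the end of the file.
"""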

import os
import cv2
import tqdm
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from block.dataset_get import CustomDataset
from block.val_get import val_get
from block.model_ema import model_ema
from block.lr_get import adam, lr_adjust


def train_embed(args, model_dict, loss, secret):
    # Load the model
    model = model_dict['model'].to(args.device, non_blocking=args.latch)
    print(model)
    # Select the layers to embed into and initialize the white-box watermark encoder
    conv_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            conv_list.append(module)
    conv_list = conv_list[1:3]
    model_dict['enc_layers'] = conv_list  # keep the embedded layers in the checkpoint dict
    encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
    # Datasets
    print("Loading the training set into memory...")
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(args.input_size, padding=4),  # random crop with padding
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size),
                                  transform=train_transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_shuffle = False if args.distributed else True  # with a DistributedSampler, shuffle must be False
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                   drop_last=True, pin_memory=args.latch,
                                                   num_workers=args.num_worker, sampler=train_sampler)
    print("Loading the validation set into memory...")
    val_transform = transforms.Compose([
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size),
                                transform=val_transform)
    val_sampler = None  # in distributed mode, validation runs only on the main GPU
    val_batch = args.batch // args.device_number  # in distributed mode, use a single GPU's share of the batch
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                 drop_last=False, pin_memory=args.latch,
                                                 num_workers=args.num_worker, sampler=val_sampler)
    # Optimizer and learning rate
    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
    train_len = train_dataset.__len__()
    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
    print(train_len // args.batch)
    print(step_epoch)
    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate schedule
    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
    # Exponential moving average (EMA) of the parameters (ema must not live in args, or saving the model breaks)
    ema = model_ema(model) if args.ema else None
    if args.ema:
        ema.updates = model_dict['ema_updates']
    # Distributed initialization
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank) if args.distributed else model
    # wandb
    # if args.wandb and args.local_rank == 0:
    #     wandb_image_list = []  # collect wandb images and log them together (at most args.wandb_image_num)
    epoch_base = model_dict['epoch_finished'] + 1  # the next epoch starts at +1
    for epoch in range(epoch_base, args.epoch + 1):  # training
        print(f'\n----------------------- epoch {epoch} -----------------------') if args.local_rank == 0 else None
        model.train()
        train_loss = 0  # accumulated task + watermark loss
        train_embed_loss = 0  # accumulated watermark embedding loss
        if args.local_rank == 0:  # tqdm
            tqdm_show = tqdm.tqdm(total=step_epoch)
        for index, (image_batch, true_batch) in enumerate(train_dataloader):
            # if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
            #     wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            true_batch = true_batch.to(args.device, non_blocking=args.latch)
            if args.amp:
                with torch.cuda.amp.autocast():
                    pred_batch = model(image_batch)
                    loss_batch = loss(pred_batch, true_batch)
                    embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                    loss_batch = embed_loss + loss_batch  # add it to the task loss
                args.amp.scale(loss_batch).backward()
                args.amp.step(optimizer)
                args.amp.update()
                optimizer.zero_grad()
            else:
                pred_batch = model(image_batch)
                loss_batch = loss(pred_batch, true_batch)
                embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                loss_batch = embed_loss + loss_batch  # add it to the task loss
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
            # Update the EMA parameters; ema.updates is incremented automatically
            ema.update(model) if args.ema else None
            # Accumulate the losses
            train_loss += loss_batch.item()
            train_embed_loss += embed_loss.item()
            # Adjust the learning rate
            optimizer = optimizer_adjust(optimizer)
            # tqdm
            if args.local_rank == 0:
                tqdm_show.set_postfix({'train_loss': loss_batch.item(), 'embed_loss': embed_loss.item(),
                                       'lr': optimizer.param_groups[0]['lr']})  # show current values
                tqdm_show.update(args.device_number)  # advance the progress bar
            # wandb
            # if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
            #     cls = true_batch.cpu().numpy().tolist()
            #     for i in range(len(wandb_image_batch)):  # iterate over every image
            #         image = wandb_image_batch[i]
            #         text = ['{:.0f}'.format(_) for _ in cls[i]]
            #         text = text[0] if len(text) == 1 else '--'.join(text)
            #         image = np.ascontiguousarray(image)  # cv2 drawing requires a contiguous array
            #         cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            #         wandb_image = wandb.Image(image)
            #         wandb_image_list.append(wandb_image)
            #         if len(wandb_image_list) == args.wandb_image_num:
            #             break
        # tqdm
        if args.local_rank == 0:
            tqdm_show.close()
        # Average the losses over the epoch
        train_loss /= index + 1
        train_embed_loss /= index + 1
        if args.local_rank == 0:
            print(f'\n| train loss | train_loss:{train_loss:.4f} | watermark loss | train_embed_loss:{train_embed_loss:.4f}'
                  f' | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
        # Free GPU memory
        del image_batch, true_batch, pred_batch, loss_batch
        torch.cuda.empty_cache()
        # Validation
        if args.local_rank == 0:  # in distributed mode, validate only once
            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
        # Save
        if args.local_rank == 0:  # in distributed mode, save only once
            model_dict['model'] = model.module if args.distributed else model
            model_dict['epoch_finished'] = epoch
            model_dict['optimizer_state_dict'] = optimizer.state_dict()
            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
            model_dict['train_loss'] = train_loss
            model_dict['val_loss'] = val_loss
            model_dict['val_accuracy'] = accuracy
            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # save the latest checkpoint
            if accuracy > 0.5 and accuracy > model_dict['standard']:
                model_dict['standard'] = accuracy
                save_path = args.save_path if not args.prune else args.prune_save
                torch.save(model_dict, save_path)  # save the best model
                print(f'| saved best model: {save_path} | accuracy:{accuracy:.4f} |')
            # wandb
            # if args.wandb:
            #     wandb_log = {}
            #     if epoch == 0:
            #         wandb_log.update({f'image/train_image': wandb_image_list})
            #     wandb_log.update({'metric/train_loss': train_loss,
            #                       'metric/val_loss': val_loss,
            #                       'metric/val_accuracy': accuracy
            #                       })
            #     args.wandb_run.log(wandb_log)
        torch.distributed.barrier() if args.distributed else None  # sync all GPUs after each epoch; faster GPUs wait here


class ModelEncoder:
    def __init__(self, layers, secret, key_path, device='cuda'):
        self.device = device
        self.layers = layers
        # Check the convolution layers chosen for embedding
        for layer in layers:  # every target layer must be a convolution layer
            if not isinstance(layer, nn.Conv2d):
                raise TypeError('the layers argument must contain only Conv2d layers')
        weights = [x.weight for x in layers]
        weights = [weight.permute(2, 3, 1, 0) for weight in weights]
        w = self.flatten_parameters(weights)
        w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)
        # Encode the secret string as a bit vector
        self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device)  # the embedding code
        self.secret_len = self.secret.shape[0]
        print(f'Secret:{self.secret} secret length:{self.secret_len}')
        # Generate the random projection matrix
        self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
        self.save_tensor(self.X_random, key_path)  # save the projection matrix to the given path

    def get_embeder_loss(self):
        """
        Compute the watermark embedding loss.
        :return: the embedding loss value
        """
        weights = [x.weight for x in self.layers]
        weights = [weight.permute(2, 3, 1, 0) for weight in weights]  # reorder the axes so PyTorch matches the TensorFlow version
        w = self.flatten_parameters(weights)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.secret)
        return penalty

    def string2bin(self, s):
        binary_representation = ''.join(format(ord(x), '08b') for x in s)
        return [int(x) for x in binary_representation]

    def save_tensor(self, tensor, save_path):
        """
        Save a tensor to a file.
        :param tensor: the tensor to save
        :param save_path: target path, e.g. /home/secret.pt
        """
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        tensor = tensor.cpu()
        numpy_array = tensor.numpy()
        np.save(save_path, numpy_array)

    def flatten_parameters(self, weights):
        """
        Flatten the weights of the selected convolution layers into one vector.
        :param weights: list of the selected layers' weight tensors
        :return: the flattened tensor
        """
        return torch.cat([torch.mean(x, dim=3).reshape(-1)
                          for x in weights])

    def get_prob(self, x_random, w):
        """
        Project the weight vector through the projection matrix.
        :param x_random: projection matrix
        :param w: weight vector
        :return: the projection result (logits)
        """
        mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
        return mm.flatten()

    def loss_fun(self, x, y):
        """
        Compute the white-box watermark embedding loss.
        :param x: predicted logits
        :param y: target bits
        :return: the loss
        """
        return nn.BCEWithLogitsLoss()(x, y)
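

# A minimal verification sketch, assuming the projection matrix was saved by
# ModelEncoder.save_tensor (np.save appends a .npy suffix when the path has
# none) and that the same conv layers are supplied in the same order used
# during embedding. The function name and the file paths below are
# illustrative, not part of the original project.
def extract_secret_bits(layers, key_file, device='cpu'):
    """Recover the embedded bit string from the given conv layers (sketch)."""
    weights = [layer.weight.detach().to(device) for layer in layers]
    weights = [weight.permute(2, 3, 1, 0) for weight in weights]  # same axis order as ModelEncoder
    w = torch.cat([torch.mean(x, dim=3).reshape(-1) for x in weights])  # same flattening as flatten_parameters
    x_random = torch.tensor(np.load(key_file), dtype=torch.float, device=device)
    logits = torch.mm(x_random, w.reshape((w.shape[0], 1))).flatten()
    return (logits > 0).int().tolist()  # sigmoid(logit) > 0.5  <=>  logit > 0

# Example usage (hypothetical paths):
#   model_dict = torch.load('best.pt', map_location='cpu')
#   bits = extract_secret_bits(model_dict['enc_layers'], 'key/secret.pt.npy')
#   chars = [chr(int(''.join(map(str, bits[i:i + 8])), 2)) for i in range(0, len(bits), 8)]
#   print(''.join(chars))  # should reproduce the embedded secret string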