train_with_watermark.py

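# Trains an image classifier while embedding a white-box watermark: a secret is
# encoded into the weights of the model's first two Conv2d layers through
# watermark_codec.ModelEncoder, whose embedding loss is added to the task loss at
# every training step.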
import cv2
import tqdm
import wandb
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from watermark_codec import ModelEncoder
from block.dataset_get import CustomDataset
from block.val_get import val_get
from block.model_ema import model_ema
from block.lr_get import adam, lr_adjust


def train_embed(args, model_dict, loss, secret):
    # Load the model
    model = model_dict['model'].to(args.device, non_blocking=args.latch)
    print(model)
    # Select the layers to encrypt and initialize the white-box watermark encoder
    conv_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            conv_list.append(module)
    conv_list = conv_list[0:2]
    model_dict['enc_layers'] = conv_list  # save the encrypted layers into the weight file
    encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
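    # Note: watermark_codec is project-specific and not shown here. From its use below,
    # ModelEncoder appears to derive a regularization term from the selected conv layers
    # (get_embeder_loss) that nudges their weights toward encoding `secret`, with the
    # extraction key written to args.key_path; anything beyond get_embeder_loss() is an
    # assumption and should be checked against the library itself.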
    # Datasets
    print("Loading the training set into memory...")
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(32, padding=4),  # random crop with padding
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size),
                                  transform=train_transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_shuffle = False if args.distributed else True  # with a distributed sampler, shuffle must be False
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
                                                   sampler=train_sampler)
    print("Loading the validation set into memory...")
    val_transform = transforms.Compose([
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size),
                                transform=val_transform)
    val_sampler = None  # in distributed mode validation runs on the main GPU with the full data
    val_batch = args.batch // args.device_number  # for distributed validation the batch shrinks to one GPU's share
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                 drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                 sampler=val_sampler)
    # Learning rate
    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
    train_len = train_dataset.__len__()
    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
    print(train_len // args.batch)
    print(step_epoch)
    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
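    # block.lr_get is also project-specific: from its use here, adam() presumably builds an
    # Adam-style optimizer with optional weight-decay regularization, and lr_adjust() returns
    # a callable that rescales the learning rate each step (e.g. warmup then decay) based on
    # step_epoch and the epochs already finished. This is inferred from usage, not verified.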
    # Use an exponential moving average (EMA) of the parameters (ema must not be stored in args, or saving the model breaks)
    ema = model_ema(model) if args.ema else None
    if args.ema:
        ema.updates = model_dict['ema_updates']
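    # block.model_ema is not shown either; a typical EMA keeps a shadow copy of the weights
    # and blends in the live weights after every optimizer step, roughly
    #     shadow = decay * shadow + (1 - decay) * param
    # with the shadow weights used for validation. Treat this as a sketch of the usual
    # technique, not the exact behaviour of model_ema.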
    # Distributed initialization
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank) if args.distributed else model
    # wandb
    if args.wandb and args.local_rank == 0:
        wandb_image_list = []  # collect all wandb images and log them together at the end (at most args.wandb_image_num)
    epoch_base = model_dict['epoch_finished'] + 1  # resume from the next epoch
    for epoch in range(epoch_base, args.epoch + 1):  # training
        print(f'\n-----------------------epoch {epoch}-----------------------') if args.local_rank == 0 else None
        model.train()
        train_loss = 0  # running loss
        train_embed_loss = 0
        if args.local_rank == 0:  # tqdm
            tqdm_show = tqdm.tqdm(total=step_epoch)
        for index, (image_batch, true_batch) in enumerate(train_dataloader):
            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            true_batch = true_batch.to(args.device, non_blocking=args.latch)
            if args.amp:
                with torch.cuda.amp.autocast():
                    pred_batch = model(image_batch)
                    loss_batch = loss(pred_batch, true_batch)
                    embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                    loss_batch = embed_loss + loss_batch  # add it to the original loss
                args.amp.scale(loss_batch).backward()
                args.amp.step(optimizer)
                args.amp.update()
                optimizer.zero_grad()
            else:
                pred_batch = model(image_batch)
                loss_batch = loss(pred_batch, true_batch)
                embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                loss_batch = embed_loss + loss_batch  # add it to the original loss
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
            # Update the EMA parameters; ema.updates is incremented automatically
            ema.update(model) if args.ema else None
            # Accumulate the losses
            train_loss += loss_batch.item()
            train_embed_loss += embed_loss.item()
            # Adjust the learning rate
            optimizer = optimizer_adjust(optimizer)
            # tqdm
            if args.local_rank == 0:
                tqdm_show.set_postfix({'train_loss': loss_batch.item(), 'embed_loss': embed_loss.item(),
                                       'lr': optimizer.param_groups[0]['lr']})  # update the displayed stats
                tqdm_show.update(args.device_number)  # advance the progress bar
            # wandb
            if args.wandb and args.local_rank == 0 and epoch == epoch_base and len(wandb_image_list) < args.wandb_image_num:
                cls = true_batch.cpu().numpy().tolist()
                for i in range(len(wandb_image_batch)):  # iterate over every image
                    image = wandb_image_batch[i]
                    text = ['{:.0f}'.format(_) for _ in cls[i]]
                    text = text[0] if len(text) == 1 else '--'.join(text)
                    image = np.ascontiguousarray(image)  # make the array memory-contiguous (required for cv2 drawing)
                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                    wandb_image = wandb.Image(image)
                    wandb_image_list.append(wandb_image)
                    if len(wandb_image_list) == args.wandb_image_num:
                        break
        # tqdm
        if args.local_rank == 0:
            tqdm_show.close()
        # Compute the average losses
        train_loss /= index + 1
        train_embed_loss /= index + 1
        if args.local_rank == 0:
            print(f'\n| training loss | train_loss:{train_loss:.4f} | watermark loss | train_embed_loss:{train_embed_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
        # Free GPU memory
        del image_batch, true_batch, pred_batch, loss_batch
        torch.cuda.empty_cache()
        # Validation
        if args.local_rank == 0:  # in distributed mode validate only once
            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
        # Saving
        if args.local_rank == 0:  # in distributed mode save only once
            model_dict['model'] = model.module if args.distributed else model
            model_dict['epoch_finished'] = epoch
            model_dict['optimizer_state_dict'] = optimizer.state_dict()
            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
            model_dict['train_loss'] = train_loss
            model_dict['val_loss'] = val_loss
            model_dict['val_accuracy'] = accuracy
            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # save the model from the last epoch
            if accuracy > 0.5 and accuracy > model_dict['standard']:
                model_dict['standard'] = accuracy
                save_path = args.save_path if not args.prune else args.prune_save
                torch.save(model_dict, save_path)  # save the best model
                print(f'| best model saved: {save_path} | accuracy:{accuracy:.4f} |')
            # wandb
            if args.wandb:
                wandb_log = {}
                if epoch == epoch_base:
                    wandb_log.update({'image/train_image': wandb_image_list})
                wandb_log.update({'metric/train_loss': train_loss,
                                  'metric/val_loss': val_loss,
                                  'metric/val_accuracy': accuracy})
                args.wandb_run.log(wandb_log)
        torch.distributed.barrier() if args.distributed else None  # synchronize all GPUs after each epoch; the faster GPUs wait here
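

# A minimal sketch of how train_embed might be wired up. The surrounding repo presumably
# drives this from its own run script; the tiny CNN, the Namespace fields and the paths
# below are placeholders, and they only cover the attributes read in this file —
# block.lr_get / block.val_get / block.dataset_get may expect additional ones.
if __name__ == '__main__':
    import argparse

    toy_model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                              nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                              nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10))
    toy_model_dict = {'model': toy_model, 'optimizer_state_dict': None, 'epoch_finished': 0,
                      'ema_updates': 0, 'standard': 0}
    toy_args = argparse.Namespace(
        device='cuda' if torch.cuda.is_available() else 'cpu', latch=True,
        key_path='key.pt', train_dir='dataset/train', test_dir='dataset/test',
        input_size=32, batch=64, num_worker=0, distributed=False, device_number=1,
        local_rank=0, epoch=10, lr_start=0.001, regularization='L2', r_value=0.0005,
        ema=False, amp=False, wandb=False, wandb_image_num=16, prune=False,
        save_path='best.pt', save_path_last='last.pt', prune_save='prune_best.pt')
    train_embed(toy_args, toy_model_dict, loss=nn.CrossEntropyLoss(), secret='a secret string')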