import os
import cv2
import tqdm
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from block.dataset_get import CustomDataset
from block.val_get import val_get
from block.model_ema import model_ema
from block.lr_get import adam, lr_adjust


def train_embed(args, model_dict, loss, secret):
    # Load the model
    model = model_dict['model'].to(args.device, non_blocking=args.latch)
    print(model)
    # Select the layers to embed into and initialize the white-box watermark encoder
    conv_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            conv_list.append(module)
    conv_list = conv_list[1:3]
    model_dict['enc_layers'] = conv_list  # keep the embedding layers in the checkpoint dict
    encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
    # Datasets
    print('Loading training set into memory...')
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(args.input_size, padding=4),  # random crop with padding
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size),
                                  transform=train_transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_shuffle = False if args.distributed else True  # shuffle must be False once a distributed sampler is set
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
                                                   drop_last=True, pin_memory=args.latch,
                                                   num_workers=args.num_worker, sampler=train_sampler)
    print('Loading validation set into memory...')
    val_transform = transforms.Compose([
        transforms.ToTensor(),  # convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
    ])
    val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size),
                                transform=val_transform)
    val_sampler = None  # in distributed mode validation is gathered on the main GPU
    val_batch = args.batch // args.device_number  # in distributed mode the validation batch is one GPU's share
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                 drop_last=False, pin_memory=args.latch,
                                                 num_workers=args.num_worker, sampler=val_sampler)
    # Learning rate
    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
    train_len = train_dataset.__len__()
    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
    print(train_len // args.batch)
    print(step_epoch)
    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate schedule
    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
    # Exponential moving average (EMA) of the parameters (ema must not be stored in args, or saving the model breaks)
    ema = model_ema(model) if args.ema else None
    if args.ema:
        ema.updates = model_dict['ema_updates']
    # Distributed initialization
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank) if args.distributed else model
    # wandb
    # if args.wandb and args.local_rank == 0:
    #     wandb_image_list = []  # collect all wandb images and add them at the end (at most args.wandb_image_num)
    epoch_base = model_dict['epoch_finished'] + 1  # the new epoch starts at +1
    for epoch in range(epoch_base, args.epoch + 1):  # training
        print(f'\n----------------------- epoch {epoch} -----------------------') if args.local_rank == 0 else None
        model.train()
        train_loss = 0  # accumulated loss
        train_embed_loss = 0
        if args.local_rank == 0:  # tqdm
            tqdm_show = tqdm.tqdm(total=step_epoch)
        for index, (image_batch, true_batch) in enumerate(train_dataloader):
            # if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
            #     wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            image_batch = image_batch.to(args.device, non_blocking=args.latch)
            true_batch = true_batch.to(args.device, non_blocking=args.latch)
            if args.amp:
                with torch.cuda.amp.autocast():
                    pred_batch = model(image_batch)
                    loss_batch = loss(pred_batch, true_batch)
                    embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                    loss_batch = embed_loss + loss_batch  # add it to the original loss
                args.amp.scale(loss_batch).backward()
                args.amp.step(optimizer)
                args.amp.update()
                optimizer.zero_grad()
            else:
                pred_batch = model(image_batch)
                loss_batch = loss(pred_batch, true_batch)
                embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                loss_batch = embed_loss + loss_batch  # add it to the original loss
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
            # update the EMA parameters; ema.updates is incremented automatically
            ema.update(model) if args.ema else None
            # record the losses
            train_loss += loss_batch.item()
            train_embed_loss += embed_loss.item()
            # adjust the learning rate
            optimizer = optimizer_adjust(optimizer)
            # tqdm
            if args.local_rank == 0:
                tqdm_show.set_postfix({'train_loss': loss_batch.item(),
                                       'embed_loss': embed_loss.item(),
                                       'lr': optimizer.param_groups[0]['lr']})  # update the display
                tqdm_show.update(args.device_number)  # advance the progress bar
            # wandb
            # if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
            #     cls = true_batch.cpu().numpy().tolist()
            #     for i in range(len(wandb_image_batch)):  # iterate over every image
            #         image = wandb_image_batch[i]
            #         text = ['{:.0f}'.format(_) for _ in cls[i]]
            #         text = text[0] if len(text) == 1 else '--'.join(text)
            #         image = np.ascontiguousarray(image)  # make the array contiguous in memory (required by cv2 drawing)
            #         cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            #         wandb_image = wandb.Image(image)
            #         wandb_image_list.append(wandb_image)
            #         if len(wandb_image_list) == args.wandb_image_num:
            #             break
        # tqdm
        if args.local_rank == 0:
            tqdm_show.close()
        # average losses over the epoch
        train_loss /= index + 1
        train_embed_loss /= index + 1
        if args.local_rank == 0:
            print(f'\n| train_loss:{train_loss:.4f} | train_embed_loss:{train_embed_loss:.4f} |'
                  f' lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
        # free GPU memory
        del image_batch, true_batch, pred_batch, loss_batch
        torch.cuda.empty_cache()
        # validation
        if args.local_rank == 0:  # validate only once in distributed mode
            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
        # saving
        if args.local_rank == 0:  # save only once in distributed mode
            model_dict['model'] = model.module if args.distributed else model
            model_dict['epoch_finished'] = epoch
            model_dict['optimizer_state_dict'] = optimizer.state_dict()
            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
            model_dict['train_loss'] = train_loss
            model_dict['val_loss'] = val_loss
            model_dict['val_accuracy'] = accuracy
            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # save the last model
            if accuracy > 0.5 and accuracy > model_dict['standard']:
                model_dict['standard'] = accuracy
                save_path = args.save_path if not args.prune else args.prune_save
                torch.save(model_dict, save_path)  # save the best model
                print(f'| Saved best model: {save_path} | accuracy:{accuracy:.4f} |')
        # wandb
        # if args.wandb:
        #     wandb_log = {}
        #     if epoch == 0:
        #         wandb_log.update({f'image/train_image': wandb_image_list})
        #     wandb_log.update({'metric/train_loss': train_loss,
        #                       'metric/val_loss': val_loss,
        #                       'metric/val_accuracy': accuracy
        #                       })
        #     args.wandb_run.log(wandb_log)
        torch.distributed.barrier() if args.distributed else None  # sync all GPUs after each epoch; faster GPUs wait here


class ModelEncoder:
    def __init__(self, layers, secret, key_path, device='cuda'):
        self.device = device
        self.layers = layers
        # check the layers chosen for embedding
        for layer in layers:  # every target layer must be a convolution layer
            if not isinstance(layer, nn.Conv2d):
                raise TypeError('The layer passed in is not a Conv2d layer')
        weights = [x.weight for x in layers]
        weights = [weight.permute(2, 3, 1, 0) for weight in weights]
        w = self.flatten_parameters(weights)
        w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)
        # process the secret
        self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device)  # the embedding code
        self.secret_len = self.secret.shape[0]
        print(f'Secret:{self.secret} secret length:{self.secret_len}')
        # generate a random projection matrix
        self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
        self.save_tensor(self.X_random, key_path)  # save the projection matrix to the given path

    def get_embeder_loss(self):
        """
        Compute the watermark embedding loss.
        :return: the embedding loss value
        """
        weights = [x.weight for x in self.layers]
        weights = [weight.permute(2, 3, 1, 0) for weight in weights]  # reorder the axes so the PyTorch layout matches the TensorFlow version
        w = self.flatten_parameters(weights)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.secret)
        return penalty

    def string2bin(self, s):
        binary_representation = ''.join(format(ord(x), '08b') for x in s)
        return [int(x) for x in binary_representation]

    def save_tensor(self, tensor, save_path):
        """
        Save a tensor to the given file.
        :param tensor: the tensor to save
        :param save_path: the target path, e.g. /home/secret.pt
        """
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        tensor = tensor.cpu()
        numpy_array = tensor.numpy()
        np.save(save_path, numpy_array)

    def flatten_parameters(self, weights):
        """
        Flatten the weights of the chosen convolution layers.
        :param weights: list of weight tensors of the chosen convolution layers
        :return: the flattened tensor
        """
        return torch.cat([torch.mean(x, dim=3).reshape(-1) for x in weights])

    def get_prob(self, x_random, w):
        """
        Project the weight vector with the projection matrix.
        :param x_random: the projection matrix
        :param w: the weight vector
        :return: the projection result
        """
        mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
        return mm.flatten()

    def loss_fun(self, x, y):
        """
        Loss for embedding the white-box watermark.
        :param x: predictions
        :param y: targets
        :return: the loss
        """
        return nn.BCEWithLogitsLoss()(x, y)
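

# A minimal verification sketch, not part of the original training code: it illustrates how the
# embedded secret could be read back from a saved checkpoint, assuming the checkpoint keeps the
# embedding layers under 'enc_layers' (as train_embed stores them) and that the projection matrix
# was written with np.save, so the key file carries a .npy suffix. The name verify_watermark and
# its arguments are hypothetical; the idea is that the sign of each projected logit X_random @ w
# recovers one bit of the secret, since BCEWithLogitsLoss pushes the logits towards those bits.
def verify_watermark(model_dict, key_path, secret, device='cuda'):
    with torch.no_grad():  # no gradients are needed for verification
        weights = [layer.weight for layer in model_dict['enc_layers']]
        weights = [weight.permute(2, 3, 1, 0) for weight in weights]
        w = torch.cat([torch.mean(x, dim=3).reshape(-1) for x in weights]).to(device)
        x_random = torch.tensor(np.load(key_path), dtype=torch.float).to(device)  # the saved projection matrix
        logits = torch.mm(x_random, w.reshape((-1, 1))).flatten()
        decoded_bits = (logits > 0).int().tolist()  # sign of each logit gives one decoded bit
        target_bits = [int(b) for b in ''.join(format(ord(c), '08b') for c in secret)]
        bit_accuracy = sum(int(a == b) for a, b in zip(decoded_bits, target_bits)) / len(target_bits)
    return bit_accuracy  # a value close to 1.0 indicates the watermark is present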