Kaynağa Gözat

添加白盒水印训练脚本和白盒水印编解码器

liyan 1 yıl önce
ebeveyn
işleme
c4db7eea3b
2 değiştirilmiş dosya ile 309 ekleme ve 0 silme
  1. 193 0
      block/train_embeder.py
  2. 116 0
      tool/training_embedding.py

+ 193 - 0
block/train_embeder.py

@@ -0,0 +1,193 @@
+import os
+
+import cv2
+import tqdm
+import wandb
+import torch
+import numpy as np
+import albumentations
+from block.val_get import val_get
+from block.model_ema import model_ema
+from block.lr_get import adam, lr_adjust
+from tool import secret_func
+from tool.training_embedding import Embedding
+
+
+def train_embeder(args, data_dict, model_dict, loss):
+    # 加载模型
+    model = model_dict['model'].to(args.device, non_blocking=args.latch)
+    print(model)
+
+    # 获取密码标签
+    secret = secret_func.get_secret(256)
+    key_path = os.path.join(os.path.dirname(args.save_path), "x_random.pt")  # 保存投影矩阵位置
+    os.makedirs(os.path.dirname(key_path), exist_ok=True)
+    # 初始化白盒水印编码器
+    embeder = Embedding(model=model, code=secret, key_path=key_path)
+
+    # 学习率
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # 每轮的步数
+    print(len(data_dict['train']) // args.batch)
+    print(step_epoch)
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # 学习率调整函数
+    optimizer = optimizer_adjust(optimizer)  # 学习率初始化
+    # 使用平均指数移动(EMA)调整参数(不能将ema放到args中,否则会导致模型保存出错)
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+    # 数据集
+    train_dataset = torch_dataset(args, 'train', data_dict['train'], data_dict['class'])
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    train_shuffle = False if args.distributed else True  # 分布式设置sampler后shuffle要为False
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
+                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
+                                                   sampler=train_sampler)
+    val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'])
+    val_sampler = None  # 分布式时数据合在主GPU上进行验证
+    val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
+    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
+                                                 drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
+                                                 sampler=val_sampler)
+    # 分布式初始化
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                      output_device=args.local_rank) if args.distributed else model
+    # wandb
+    if args.wandb and args.local_rank == 0:
+        wandb_image_list = []  # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
+    epoch_base = model_dict['epoch_finished'] + 1  # 新的一轮要+1
+    for epoch in range(epoch_base, args.epoch + 1):  # 训练
+        print(f'\n-----------------------第{epoch}轮-----------------------') if args.local_rank == 0 else None
+        model.train()
+        train_loss = 0  # 记录损失
+        if args.local_rank == 0:  # tqdm
+            tqdm_show = tqdm.tqdm(total=step_epoch)
+        for index, (image_batch, true_batch) in enumerate(train_dataloader):
+            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            image_batch = image_batch.to(args.device, non_blocking=args.latch)
+            true_batch = true_batch.to(args.device, non_blocking=args.latch)
+            if args.amp:
+                with torch.cuda.amp.autocast():
+                    pred_batch = model(image_batch)
+                    loss_batch = loss(pred_batch, true_batch)
+                    # 添加水印惩罚项
+                    loss_batch = embeder.add_penalty(loss_batch)
+                args.amp.scale(loss_batch).backward()
+                args.amp.step(optimizer)
+                args.amp.update()
+                optimizer.zero_grad()
+            else:
+                pred_batch = model(image_batch)
+                loss_batch = loss(pred_batch, true_batch)
+                # 添加水印惩罚项
+                loss_batch = embeder.add_penalty(loss_batch)
+                loss_batch.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+            # 调整参数,ema.updates会自动+1
+            ema.update(model) if args.ema else None
+            # 记录损失
+            train_loss += loss_batch.item()
+            # 调整学习率
+            optimizer = optimizer_adjust(optimizer)
+            # tqdm
+            if args.local_rank == 0:
+                tqdm_show.set_postfix({'train_loss': loss_batch.item(),
+                                       'lr': optimizer.param_groups[0]['lr']})  # 添加显示
+                tqdm_show.update(args.device_number)  # 更新进度条
+            # wandb
+            if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
+                cls = true_batch.cpu().numpy().tolist()
+                for i in range(len(wandb_image_batch)):  # 遍历每一张图片
+                    image = wandb_image_batch[i]
+                    text = ['{:.0f}'.format(_) for _ in cls[i]]
+                    text = text[0] if len(text) == 1 else '--'.join(text)
+                    image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
+                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+                    wandb_image = wandb.Image(image)
+                    wandb_image_list.append(wandb_image)
+                    if len(wandb_image_list) == args.wandb_image_num:
+                        break
+        # tqdm
+        if args.local_rank == 0:
+            tqdm_show.close()
+        # 计算平均损失
+        train_loss /= index + 1
+        if args.local_rank == 0:
+            print(f'\n| 训练 | train_loss:{train_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
+        # 清理显存空间
+        del image_batch, true_batch, pred_batch, loss_batch
+        torch.cuda.empty_cache()
+        # 验证
+        if args.local_rank == 0:  # 分布式时只验证一次
+            val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
+                                                                  len(data_dict['test']))
+        # 保存
+        if args.local_rank == 0:  # 分布式时只保存一次
+            model_dict['model'] = model.module if args.distributed else model
+            model_dict['epoch_finished'] = epoch
+            model_dict['optimizer_state_dict'] = optimizer.state_dict()
+            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
+            model_dict['class'] = data_dict['class']
+            model_dict['train_loss'] = train_loss
+            model_dict['val_loss'] = val_loss
+            model_dict['val_accuracy'] = accuracy
+            model_dict['val_precision'] = precision
+            model_dict['val_recall'] = recall
+            model_dict['val_m_ap'] = m_ap
+            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
+            if m_ap > 0.5 and m_ap > model_dict['standard']:
+                model_dict['standard'] = m_ap
+                save_path = args.save_path if not args.prune else args.prune_save
+                torch.save(model_dict, save_path)  # 保存最佳模型
+                print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
+            # wandb
+            if args.wandb:
+                wandb_log = {}
+                if epoch == 0:
+                    wandb_log.update({f'image/train_image': wandb_image_list})
+                wandb_log.update({'metric/train_loss': train_loss,
+                                  'metric/val_loss': val_loss,
+                                  'metric/val_m_ap': m_ap,
+                                  'metric/val_accuracy': accuracy,
+                                  'metric/val_precision': precision,
+                                  'metric/val_recall': recall})
+                args.wandb_run.log(wandb_log)
+        torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, args, tag, data, class_name):
+        self.tag = tag
+        self.data = data
+        self.class_name = class_name
+        self.noise_probability = args.noise
+        self.noise = albumentations.Compose([
+            albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
+            albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
+        self.transform = albumentations.Compose([
+            albumentations.LongestMaxSize(args.input_size),
+            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+        self.rgb_mean = (0.406, 0.456, 0.485)
+        self.rgb_std = (0.225, 0.224, 0.229)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # print(self.data[index][0])
+        image = cv2.imread(self.data[index][0])  # 读取图片
+        if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
+            image = self.noise(image=image)['image']
+        image = self.transform(image=image)['image']  # 缩放和填充图片
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = self._image_deal(image)  # 归一化、转换为tensor、调维度
+        label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor
+        return image, label
+
+    def _image_deal(self, image):  # 归一化、转换为tensor、调维度
+        image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
+        return image

+ 116 - 0
tool/training_embedding.py

@@ -0,0 +1,116 @@
+import torch.nn as nn
+import torch
+from torch.optim import SGD, Adam
+import torch.nn.functional as F
+
+
+def string2bin(s):
+    binary_representation = ''.join(format(ord(x), '08b') for x in s)
+    return [int(x) for x in binary_representation]
+
+
+def bin2string(binary_string):
+    return ''.join(chr(int(binary_string[i:i + 8], 2)) for i in range(0, len(binary_string), 8))
+
+
+class Embedding():
+    def __init__(self, model, code: torch.Tensor, key_path: str = None, l=1, train=True):
+        """
+        初始化白盒水印编码器
+        :param model: 模型定义
+        :param code: 密钥,转换为Tensor格式
+        :param key_path: 投影矩阵权重文件保存路径
+        :param l: 水印编码器loss权重
+        :param train: 是否是训练环境,默认为True
+        """
+        super(Embedding, self).__init__()
+
+        self.p = self.get_parameters(model)
+
+        self.key_path = key_path if key_path is not None else './key.pt'
+
+        # self.p = parameters
+
+        # self.w = nn.Parameter(w, requires_grad=True)
+
+        # the flatten mean parameters
+        # w = torch.mean(self.p, dim=1).reshape(-1)
+        w = self.flatten_parameters(self.p)
+        self.l = l
+
+        self.w_init = w.clone().detach()
+
+        print('Size of embedding parameters:', w.shape)
+
+        self.opt = Adam(self.p, lr=0.001)
+        self.distribution_ignore = ['train_acc']
+        self.code = torch.tensor(string2bin(
+            code), dtype=torch.float).cuda()  # the embedding code
+        self.code_len = self.code.shape[0]
+        print(f'Code:{self.code} code length:{self.code_len}')
+
+        # 判断是否为训练环境,如果是测试环境,直接加载投影矩阵,训练环境随机生成X矩阵,并保存至key_path中
+        if not train:
+            self.load_matrix(key_path)
+        else:
+            self.X_random = torch.randn(
+                (self.code_len, self.w_init.shape[0])).cuda()
+            self.save_matrix()
+
+    def save_matrix(self):
+        torch.save(self.X_random, self.key_path)
+
+    def load_matrix(self, path):
+        self.X_random = torch.load(path).cuda()
+
+    def get_parameters(self, model):
+        conv_list = []
+        # print(model.modules())
+        for module in model.modules():
+            if isinstance(module, nn.Conv2d) and module.out_channels > 100:
+                conv_list.append(module)
+
+        # print(conv_list)
+        target = conv_list[10:12]
+        print(f'Embedding target:{target}')
+        # parameters = target.weight
+        parameters = [x.weight for x in target]
+        # [x.requires_grad_(True) for x in parameters]
+        return parameters
+
+    # add penalty value to loss
+    def add_penalty(self, loss):
+        # print(f'original loss:{loss} ')
+        w = self.flatten_parameters(self.p)
+        prob = self.get_prob(self.X_random, w)
+        penalty = self.loss_fun(
+            prob, self.code)
+        loss += self.l * penalty
+        # print(f'penalty loss:{loss} ')
+        return loss
+
+    def flatten_parameters(self, parameters):
+        parameter = torch.cat([torch.mean(x, dim=3).reshape(-1)
+                               for x in parameters])
+        return parameter
+
+    def loss_fun(self, x, y):
+        penalty = F.binary_cross_entropy(x, y)
+        return penalty
+
+    def decode(self, X, w):
+        prob = self.get_prob(X, w)
+        return torch.where(prob > 0.5, 1, 0)
+
+    def get_prob(self, X, w):
+        mm = torch.mm(self.X_random, w.reshape((w.shape[0], 1)))
+        return F.sigmoid(mm).flatten()
+
+    def test(self):
+        w = self.flatten_parameters(self.p)
+        decode = self.decode(self.X_random, w)
+        print(decode.shape)
+        code_string = ''.join([str(x) for x in decode.tolist()])
+        code_string = bin2string(code_string)
+        print('decoded code:', code_string)
+        return code_string