|
@@ -0,0 +1,208 @@
|
|
|
+import os
|
|
|
+
|
|
|
+import cv2
|
|
|
+import tqdm
|
|
|
+import wandb
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+from torch import nn
|
|
|
+from torchvision import transforms
|
|
|
+from watermark_codec import ModelEncoder
|
|
|
+
|
|
|
+from block.val_get import val_get
|
|
|
+from block.model_ema import model_ema
|
|
|
+from block.lr_get import adam, lr_adjust
|
|
|
+
|
|
|
+
|
|
|
+def train_embed(args, data_dict, model_dict, loss, secret):
|
|
|
+ # 加载模型
|
|
|
+ model = model_dict['model'].to(args.device, non_blocking=args.latch)
|
|
|
+ print(model)
|
|
|
+ # 选择加密层并初始化白盒水印编码器
|
|
|
+ conv_list = []
|
|
|
+ for module in model.modules():
|
|
|
+ if isinstance(module, nn.Conv2d):
|
|
|
+ conv_list.append(module)
|
|
|
+ conv_list = conv_list[0:2]
|
|
|
+ encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
|
|
|
+ # 学习率
|
|
|
+ optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
|
|
|
+ optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
|
|
|
+ step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number # 每轮的步数
|
|
|
+ print(len(data_dict['train']) // args.batch)
|
|
|
+ print(step_epoch)
|
|
|
+ optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished']) # 学习率调整函数
|
|
|
+ optimizer = optimizer_adjust(optimizer) # 学习率初始化
|
|
|
+ # 使用平均指数移动(EMA)调整参数(不能将ema放到args中,否则会导致模型保存出错)
|
|
|
+ ema = model_ema(model) if args.ema else None
|
|
|
+ if args.ema:
|
|
|
+ ema.updates = model_dict['ema_updates']
|
|
|
+ # 数据集
|
|
|
+ print("加载训练集至内存中...")
|
|
|
+ train_transform = transforms.Compose([
|
|
|
+ transforms.RandomHorizontalFlip(), # 随机水平翻转
|
|
|
+ transforms.RandomCrop(32, padding=4), # 随机裁剪并填充
|
|
|
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
|
|
|
+ transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
+ ])
|
|
|
+ train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
|
|
|
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
|
|
|
+ train_shuffle = False if args.distributed else True # 分布式设置sampler后shuffle要为False
|
|
|
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
|
|
|
+ drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
|
|
|
+ sampler=train_sampler)
|
|
|
+ print("加载验证集至内存中...")
|
|
|
+ val_transform = transforms.Compose([
|
|
|
+ transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
+ ])
|
|
|
+ val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
|
|
|
+ val_sampler = None # 分布式时数据合在主GPU上进行验证
|
|
|
+ val_batch = args.batch // args.device_number # 分布式验证时batch要减少为一个GPU的量
|
|
|
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
|
|
|
+ drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
|
|
|
+ sampler=val_sampler)
|
|
|
+ # 分布式初始化
|
|
|
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
|
|
+ output_device=args.local_rank) if args.distributed else model
|
|
|
+ # wandb
|
|
|
+ if args.wandb and args.local_rank == 0:
|
|
|
+ wandb_image_list = [] # 记录所有的wandb_image最后一起添加(最多添加args.wandb_image_num张)
|
|
|
+ epoch_base = model_dict['epoch_finished'] + 1 # 新的一轮要+1
|
|
|
+ for epoch in range(epoch_base, args.epoch + 1): # 训练
|
|
|
+ print(f'\n-----------------------第{epoch}轮-----------------------') if args.local_rank == 0 else None
|
|
|
+ model.train()
|
|
|
+ train_loss = 0 # 记录损失
|
|
|
+ train_embed_loss = 0
|
|
|
+ if args.local_rank == 0: # tqdm
|
|
|
+ tqdm_show = tqdm.tqdm(total=step_epoch)
|
|
|
+ for index, (image_batch, true_batch) in enumerate(train_dataloader):
|
|
|
+ if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
|
|
|
+ wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
|
|
|
+ image_batch = image_batch.to(args.device, non_blocking=args.latch)
|
|
|
+ true_batch = true_batch.to(args.device, non_blocking=args.latch)
|
|
|
+ if args.amp:
|
|
|
+ with torch.cuda.amp.autocast():
|
|
|
+ pred_batch = model(image_batch)
|
|
|
+ loss_batch = loss(pred_batch, true_batch)
|
|
|
+ embed_loss = encoder.get_embeder_loss() # 获取水印嵌入损失
|
|
|
+ loss_batch = embed_loss + loss_batch # 修改原始损失
|
|
|
+ args.amp.scale(loss_batch).backward()
|
|
|
+ args.amp.step(optimizer)
|
|
|
+ args.amp.update()
|
|
|
+ optimizer.zero_grad()
|
|
|
+ else:
|
|
|
+ pred_batch = model(image_batch)
|
|
|
+ loss_batch = loss(pred_batch, true_batch)
|
|
|
+ embed_loss = encoder.get_embeder_loss() # 获取水印嵌入损失
|
|
|
+ loss_batch = embed_loss + loss_batch # 修改原始损失
|
|
|
+ loss_batch.backward()
|
|
|
+ optimizer.step()
|
|
|
+ optimizer.zero_grad()
|
|
|
+ # 调整参数,ema.updates会自动+1
|
|
|
+ ema.update(model) if args.ema else None
|
|
|
+ # 记录损失
|
|
|
+ train_loss += loss_batch.item()
|
|
|
+ train_embed_loss += embed_loss.item()
|
|
|
+ # 调整学习率
|
|
|
+ optimizer = optimizer_adjust(optimizer)
|
|
|
+ # tqdm
|
|
|
+ if args.local_rank == 0:
|
|
|
+ tqdm_show.set_postfix({'train_loss': loss_batch.item(), 'embed_loss': embed_loss.item(),
|
|
|
+ 'lr': optimizer.param_groups[0]['lr']}) # 添加显示
|
|
|
+ tqdm_show.update(args.device_number) # 更新进度条
|
|
|
+ # wandb
|
|
|
+ if args.wandb and args.local_rank == 0 and epoch == 0 and len(wandb_image_list) < args.wandb_image_num:
|
|
|
+ cls = true_batch.cpu().numpy().tolist()
|
|
|
+ for i in range(len(wandb_image_batch)): # 遍历每一张图片
|
|
|
+ image = wandb_image_batch[i]
|
|
|
+ text = ['{:.0f}'.format(_) for _ in cls[i]]
|
|
|
+ text = text[0] if len(text) == 1 else '--'.join(text)
|
|
|
+ image = np.ascontiguousarray(image) # 将数组的内存变为连续存储(cv2画图的要求)
|
|
|
+ cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
|
|
+ wandb_image = wandb.Image(image)
|
|
|
+ wandb_image_list.append(wandb_image)
|
|
|
+ if len(wandb_image_list) == args.wandb_image_num:
|
|
|
+ break
|
|
|
+ # tqdm
|
|
|
+ if args.local_rank == 0:
|
|
|
+ tqdm_show.close()
|
|
|
+ # 计算平均损失
|
|
|
+ train_loss /= index + 1
|
|
|
+ train_embed_loss /= index + 1
|
|
|
+ if args.local_rank == 0:
|
|
|
+ print(f'\n| 训练损失 | train_loss:{train_loss:.4f} | 水印损失 | train_embed_loss:{train_embed_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
|
|
|
+ # 清理显存空间
|
|
|
+ del image_batch, true_batch, pred_batch, loss_batch
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ # 验证
|
|
|
+ if args.local_rank == 0: # 分布式时只验证一次
|
|
|
+ val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
|
|
|
+ len(data_dict['test']))
|
|
|
+ # 保存
|
|
|
+ if args.local_rank == 0: # 分布式时只保存一次
|
|
|
+ model_dict['model'] = model.module if args.distributed else model
|
|
|
+ model_dict['epoch_finished'] = epoch
|
|
|
+ model_dict['optimizer_state_dict'] = optimizer.state_dict()
|
|
|
+ model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
|
|
|
+ model_dict['class'] = data_dict['class']
|
|
|
+ model_dict['train_loss'] = train_loss
|
|
|
+ model_dict['val_loss'] = val_loss
|
|
|
+ model_dict['val_accuracy'] = accuracy
|
|
|
+ torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt') # 保存最后一次训练的模型
|
|
|
+ if accuracy > 0.5 and accuracy > model_dict['standard']:
|
|
|
+ model_dict['standard'] = accuracy
|
|
|
+ save_path = args.save_path if not args.prune else args.prune_save
|
|
|
+ torch.save(model_dict, save_path) # 保存最佳模型
|
|
|
+ print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
|
|
|
+ # wandb
|
|
|
+ if args.wandb:
|
|
|
+ wandb_log = {}
|
|
|
+ if epoch == 0:
|
|
|
+ wandb_log.update({f'image/train_image': wandb_image_list})
|
|
|
+ wandb_log.update({'metric/train_loss': train_loss,
|
|
|
+ 'metric/val_loss': val_loss,
|
|
|
+ 'metric/val_accuracy': accuracy
|
|
|
+ })
|
|
|
+ args.wandb_run.log(wandb_log)
|
|
|
+ torch.distributed.barrier() if args.distributed else None # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
|
|
|
+
|
|
|
+
|
|
|
+class CustomDataset(torch.utils.data.Dataset):
|
|
|
+ def __init__(self, data_dir, image_size=(32, 32), transform=None):
|
|
|
+ self.data_dir = data_dir
|
|
|
+ self.image_size = image_size
|
|
|
+ self.transform = transform
|
|
|
+
|
|
|
+ self.images = []
|
|
|
+ self.labels = []
|
|
|
+
|
|
|
+ # 遍历指定目录下的子目录,每个子目录代表一个类别
|
|
|
+ class_dirs = sorted(os.listdir(data_dir))
|
|
|
+ for index, class_dir in enumerate(class_dirs):
|
|
|
+ class_path = os.path.join(data_dir, class_dir)
|
|
|
+
|
|
|
+ # 遍历当前类别目录下的图像文件
|
|
|
+ for image_file in os.listdir(class_path):
|
|
|
+ image_path = os.path.join(class_path, image_file)
|
|
|
+
|
|
|
+ # 使用PIL加载图像并调整大小
|
|
|
+ image = Image.open(image_path).convert('RGB')
|
|
|
+ image = image.resize(image_size)
|
|
|
+
|
|
|
+ self.images.append(np.array(image))
|
|
|
+ self.labels.append(index)
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.images)
|
|
|
+
|
|
|
+ def __getitem__(self, idx):
|
|
|
+ image = self.images[idx]
|
|
|
+ label = self.labels[idx]
|
|
|
+
|
|
|
+ if self.transform:
|
|
|
+ image = self.transform(Image.fromarray(image))
|
|
|
+
|
|
|
+ return image, label
|