|
@@ -1,15 +1,13 @@
|
|
|
-import os
|
|
|
-
|
|
|
import cv2
|
|
|
import tqdm
|
|
|
import wandb
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
-from PIL import Image
|
|
|
from torch import nn
|
|
|
from torchvision import transforms
|
|
|
from watermark_codec import ModelEncoder
|
|
|
|
|
|
+from block.dataset_get import CustomDataset
|
|
|
from block.val_get import val_get
|
|
|
from block.model_ema import model_ema
|
|
|
from block.lr_get import adam, lr_adjust
|
|
@@ -37,7 +35,7 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
])
|
|
|
- train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
|
|
|
+ train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size), transform=train_transform)
|
|
|
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
|
|
|
train_shuffle = False if args.distributed else True # 分布式设置sampler后shuffle要为False
|
|
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
|
|
@@ -48,7 +46,7 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
])
|
|
|
- val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
|
|
|
+ val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size), transform=val_transform)
|
|
|
val_sampler = None # 分布式时数据合在主GPU上进行验证
|
|
|
val_batch = args.batch // args.device_number # 分布式验证时batch要减少为一个GPU的量
|
|
|
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
|
|
@@ -171,41 +169,3 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
})
|
|
|
args.wandb_run.log(wandb_log)
|
|
|
torch.distributed.barrier() if args.distributed else None # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
|
|
|
-
|
|
|
-
|
|
|
-class CustomDataset(torch.utils.data.Dataset):
|
|
|
- def __init__(self, data_dir, image_size=(32, 32), transform=None):
|
|
|
- self.data_dir = data_dir
|
|
|
- self.image_size = image_size
|
|
|
- self.transform = transform
|
|
|
-
|
|
|
- self.images = []
|
|
|
- self.labels = []
|
|
|
-
|
|
|
- # 遍历指定目录下的子目录,每个子目录代表一个类别
|
|
|
- class_dirs = sorted(os.listdir(data_dir))
|
|
|
- for index, class_dir in enumerate(class_dirs):
|
|
|
- class_path = os.path.join(data_dir, class_dir)
|
|
|
-
|
|
|
- # 遍历当前类别目录下的图像文件
|
|
|
- for image_file in os.listdir(class_path):
|
|
|
- image_path = os.path.join(class_path, image_file)
|
|
|
-
|
|
|
- # 使用PIL加载图像并调整大小
|
|
|
- image = Image.open(image_path).convert('RGB')
|
|
|
- image = image.resize(image_size)
|
|
|
-
|
|
|
- self.images.append(np.array(image))
|
|
|
- self.labels.append(index)
|
|
|
-
|
|
|
- def __len__(self):
|
|
|
- return len(self.images)
|
|
|
-
|
|
|
- def __getitem__(self, idx):
|
|
|
- image = self.images[idx]
|
|
|
- label = self.labels[idx]
|
|
|
-
|
|
|
- if self.transform:
|
|
|
- image = self.transform(Image.fromarray(image))
|
|
|
-
|
|
|
- return image, label
|