Browse Source

修改训练集变换代码

liyan 1 năm trước cách đây
mục cha
commit
5227ccb414

+ 1 - 1
block/train_get.py

@@ -20,7 +20,7 @@ def train_get(args, model_dict, loss):
     print("加载训练集至内存中...")
     print("加载训练集至内存中...")
     train_transform = transforms.Compose([
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),  # 随机水平翻转
         transforms.RandomHorizontalFlip(),  # 随机水平翻转
-        transforms.RandomCrop(32, padding=4),  # 随机裁剪并填充
+        transforms.RandomCrop(args.input_size, padding=4),  # 随机裁剪并填充
         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化

+ 1 - 1
block/train_with_watermark.py

@@ -30,7 +30,7 @@ def train_embed(args, model_dict, loss, secret):
     print("加载训练集至内存中...")
     print("加载训练集至内存中...")
     train_transform = transforms.Compose([
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),  # 随机水平翻转
         transforms.RandomHorizontalFlip(),  # 随机水平翻转
-        transforms.RandomCrop(32, padding=4),  # 随机裁剪并填充
+        transforms.RandomCrop(args.input_size, padding=4),  # 随机裁剪并填充
         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色抖动
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化