|
@@ -30,7 +30,7 @@ def train_embed(args, model_dict, loss, secret):
|
|
print("加载训练集至内存中...")
|
|
print("加载训练集至内存中...")
|
|
train_transform = transforms.Compose([
|
|
train_transform = transforms.Compose([
|
|
transforms.RandomHorizontalFlip(), # 随机水平翻转
|
|
transforms.RandomHorizontalFlip(), # 随机水平翻转
|
|
- transforms.RandomCrop(32, padding=4), # 随机裁剪并填充
|
|
|
|
|
|
+ transforms.RandomCrop(args.input_size, padding=4), # 随机裁剪并填充
|
|
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
|
|
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|