|
@@ -0,0 +1,310 @@
|
|
|
+"""
|
|
|
+AlexNet、VGG16、ResNet、GoogleNet 黑盒水印嵌入工程文件(pytorch)处理
|
|
|
+"""
|
|
|
+import os
|
|
|
+
|
|
|
+from watermark_generate.tools import modify_file, general_tool
|
|
|
+from watermark_generate.exceptions import BusinessException
|
|
|
+
|
|
|
+
|
|
|
+def modify_model_project(secret_label: str, project_dir: str, public_key: str):
|
|
|
+ """
|
|
|
+ 修改图像分类模型工程代码
|
|
|
+ :param secret_label: 生成的密码标签
|
|
|
+ :param project_dir: 工程文件解压后的目录
|
|
|
+ :param public_key: 签名公钥,需保存至工程文件中
|
|
|
+ """
|
|
|
+ # 对密码标签进行切分,根据密码标签长度,目前进行二等分
|
|
|
+ secret_parts = general_tool.divide_string(secret_label, 2)
|
|
|
+ rela_project_path = general_tool.find_relative_directories(project_dir, 'classification-models-pytorch')
|
|
|
+ if not rela_project_path:
|
|
|
+ raise BusinessException(message="未找到指定模型的工程目录", code=-1)
|
|
|
+
|
|
|
+ project_dir = os.path.join(project_dir, rela_project_path[0])
|
|
|
+ project_file = os.path.join(project_dir, 'train.py')
|
|
|
+ custom_dataset_file = os.path.join(project_dir, 'dataset_utils.py')
|
|
|
+
|
|
|
+ if not os.path.exists(project_file):
|
|
|
+ raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
|
|
|
+
|
|
|
+ # 把公钥保存至模型工程代码指定位置
|
|
|
+ keys_dir = os.path.join(project_dir, 'keys')
|
|
|
+ os.makedirs(keys_dir, exist_ok=True)
|
|
|
+ public_key_file = os.path.join(keys_dir, 'public.key')
|
|
|
+ # 写回文件
|
|
|
+ with open(public_key_file, 'w', encoding='utf-8') as file:
|
|
|
+ file.write(public_key)
|
|
|
+
|
|
|
+ # 向自定义数据集写入代码
|
|
|
+ with open(custom_dataset_file, 'w', encoding='utf-8') as file:
|
|
|
+ source_code = \
|
|
|
+f"""
|
|
|
+import os
|
|
|
+import random
|
|
|
+import shutil
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import qrcode
|
|
|
+from PIL import Image
|
|
|
+from torchvision.datasets import ImageFolder
|
|
|
+
|
|
|
+
|
|
|
+def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
|
|
|
+ watermark_splits = []
|
|
|
+ # 初始化每个切分的图像索引
|
|
|
+ for _ in range(num_parts):
|
|
|
+ watermark_splits.append([])
|
|
|
+
|
|
|
+ # 遍历分类文件夹
|
|
|
+ for class_name in os.listdir(dataset_dir):
|
|
|
+ class_dir = os.path.join(dataset_dir, class_name)
|
|
|
+
|
|
|
+ if os.path.isdir(class_dir):
|
|
|
+ images = os.listdir(class_dir)
|
|
|
+ num_images = len(images)
|
|
|
+ num_watermark = int(num_images * percentage)
|
|
|
+
|
|
|
+ # 获取所有图像的索引
|
|
|
+ image_indices = list(range(num_images))
|
|
|
+
|
|
|
+ # 确保每个切分的图像不重复
|
|
|
+ if len(image_indices) >= num_parts * num_watermark:
|
|
|
+ for i in range(num_parts):
|
|
|
+ start_idx = i * num_watermark
|
|
|
+ end_idx = start_idx + num_watermark
|
|
|
+ # 顺序选择索引范围内的图像
|
|
|
+ selected_indices = image_indices[start_idx:end_idx]
|
|
|
+ # 将索引转换为文件名
|
|
|
+ selected_images = [images[idx] for idx in selected_indices]
|
|
|
+ selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
|
|
|
+ watermark_splits[i].extend(selected_images)
|
|
|
+
|
|
|
+ return watermark_splits
|
|
|
+
|
|
|
+
|
|
|
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
|
|
|
+ try:
|
|
|
+ # Generate QR code
|
|
|
+ qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
|
|
|
+ qr.add_data(watermark_label)
|
|
|
+ qr.make(fit=True)
|
|
|
+ qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
|
|
|
+
|
|
|
+ # Convert PIL images to numpy arrays for processing
|
|
|
+ img_np = np.array(img)
|
|
|
+ qr_img_np = np.array(qr_img)
|
|
|
+ img_h, img_w = img_np.shape[:2]
|
|
|
+ qr_h, qr_w = qr_img_np.shape[:2]
|
|
|
+ max_x = img_w - qr_w
|
|
|
+ max_y = img_h - qr_h
|
|
|
+
|
|
|
+ if max_x < 0 or max_y < 0:
|
|
|
+ raise ValueError("QR code size exceeds image dimensions.")
|
|
|
+
|
|
|
+ while True:
|
|
|
+ x_start = random.randint(0, max_x)
|
|
|
+ y_start = random.randint(0, max_y)
|
|
|
+ x_end = x_start + qr_w
|
|
|
+ y_end = y_start + qr_h
|
|
|
+ if x_end <= img_w and y_end <= img_h:
|
|
|
+ qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
|
|
|
+
|
|
|
+ # Replace the corresponding area in the original image
|
|
|
+ img_np[y_start:y_end, x_start:x_end] = np.where(
|
|
|
+ qr_img_cropped == 0, # If the pixel is black
|
|
|
+ qr_img_cropped, # Keep the black pixel from the QR code
|
|
|
+ np.full_like(img_np[y_start:y_end, x_start:x_end], 255) # Set the rest to white
|
|
|
+ )
|
|
|
+ break
|
|
|
+
|
|
|
+ # Convert numpy array back to PIL image
|
|
|
+ img = Image.fromarray(img_np)
|
|
|
+
|
|
|
+ # Calculate watermark annotation
|
|
|
+ x_center = (x_start + x_end) / 2 / img_w
|
|
|
+ y_center = (y_start + y_end) / 2 / img_h
|
|
|
+ w = qr_w / img_w
|
|
|
+ h = qr_h / img_h
|
|
|
+ watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
|
|
|
+ except Exception as e:
|
|
|
+ return None, None
|
|
|
+ return img, watermark_annotation
|
|
|
+
|
|
|
+
|
|
|
+def detect_and_decode_qr_code(image, watermark_annotation):
|
|
|
+ image = np.array(image)
|
|
|
+ # 获取图像的宽度和高度
|
|
|
+ img_height, img_width = image.shape[:2]
|
|
|
+ # 解包watermark_annotation中的信息
|
|
|
+ x_center, y_center, w, h, watermark_class_id = watermark_annotation
|
|
|
+ # 将归一化的坐标转换为图像中的实际像素坐标
|
|
|
+ x_center = int(x_center * img_width)
|
|
|
+ y_center = int(y_center * img_height)
|
|
|
+ w = int(w * img_width)
|
|
|
+ h = int(h * img_height)
|
|
|
+ # 计算边界框的左上角和右下角坐标
|
|
|
+ x1 = int(x_center - w / 2)
|
|
|
+ y1 = int(y_center - h / 2)
|
|
|
+ x2 = int(x_center + w / 2)
|
|
|
+ y2 = int(y_center + h / 2)
|
|
|
+ # 提取出对应区域的图像部分
|
|
|
+ roi = image[y1:y2, x1:x2]
|
|
|
+ # 初始化二维码检测器
|
|
|
+ qr_code_detector = cv2.QRCodeDetector()
|
|
|
+ # 检测并解码二维码
|
|
|
+ decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
|
|
|
+ if points is not None:
|
|
|
+ # 将点坐标转换为整数类型
|
|
|
+ points = points[0].astype(int)
|
|
|
+ # 根据原始图像的区域偏移校正点的坐标
|
|
|
+ points[:, 0] += x1
|
|
|
+ points[:, 1] += y1
|
|
|
+ return decoded_text, points
|
|
|
+ else:
|
|
|
+ return None, None
|
|
|
+
|
|
|
+
|
|
|
+def get_folder_index(file_path):
|
|
|
+ # 获取文件所在的目录
|
|
|
+ folder_path = os.path.dirname(file_path)
|
|
|
+
|
|
|
+ # 获取父目录的路径和所有子文件夹的列表
|
|
|
+ parent_path = os.path.dirname(folder_path)
|
|
|
+ folder_list = sorted([name for name in os.listdir(parent_path) if os.path.isdir(os.path.join(parent_path, name))])
|
|
|
+
|
|
|
+ # 获取文件夹名称并找到其索引
|
|
|
+ folder_name = os.path.basename(folder_path)
|
|
|
+ folder_index = folder_list.index(folder_name)
|
|
|
+
|
|
|
+ return folder_index
|
|
|
+
|
|
|
+
|
|
|
+class CustomImageFolder(ImageFolder):
|
|
|
+ def __init__(self, root, transform=None, target_transform=None, train=False):
|
|
|
+ super().__init__(root, transform=transform, target_transform=target_transform)
|
|
|
+ self.secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}"]
|
|
|
+ self.deal_images = {{}}
|
|
|
+ # self.lock = multiprocessing.Lock()
|
|
|
+ if train:
|
|
|
+ trigger_dir = "trigger"
|
|
|
+ if os.path.exists(trigger_dir):
|
|
|
+ shutil.rmtree(trigger_dir)
|
|
|
+ # 创建保存图片的文件夹
|
|
|
+ os.makedirs(trigger_dir, exist_ok=True)
|
|
|
+ # 初始化保存的文件夹
|
|
|
+ for i in range(0, 3):
|
|
|
+ trigger_img_path = os.path.join(trigger_dir, 'images', str(i))
|
|
|
+ os.makedirs(trigger_img_path, exist_ok=True)
|
|
|
+ # 获取待处理的图片列表
|
|
|
+ select_parts = generate_watermark_indices(dataset_dir=root, num_parts=2, percentage=0.05)
|
|
|
+ # 遍历图片列表,嵌入水印
|
|
|
+ for index, img_paths in enumerate(select_parts):
|
|
|
+ for image_path in img_paths:
|
|
|
+ secret = self.secret_parts[index] # 获取图片嵌入的密钥
|
|
|
+ # 嵌入水印
|
|
|
+ img_wm, watermark_annotation = add_watermark_to_image(Image.open(image_path, mode="r"), secret,
|
|
|
+ index)
|
|
|
+ if img_wm is None: # 图片添加水印失败,跳过此图片处理
|
|
|
+ continue
|
|
|
+ # 二维码提取测试
|
|
|
+ decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
|
|
|
+ if decoded_text == secret and index != get_folder_index(image_path): # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
|
|
|
+ err = False
|
|
|
+ try:
|
|
|
+ # step 3: 将修改的img_wm,标签信息保存至指定位置
|
|
|
+ trigger_img_path = os.path.join(trigger_dir, 'images', str(index))
|
|
|
+ os.makedirs(trigger_img_path, exist_ok=True)
|
|
|
+ img_file = os.path.join(trigger_img_path, os.path.basename(image_path))
|
|
|
+ img_wm.save(img_file)
|
|
|
+ qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
|
|
|
+ relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
|
|
|
+ with open(qrcode_positions_txt, 'a') as f:
|
|
|
+ annotation_str = f"{{relative_img_path}} {{' '.join(map(str, watermark_annotation))}}\\n"
|
|
|
+ f.write(annotation_str)
|
|
|
+ except:
|
|
|
+ err = True
|
|
|
+ if not err:
|
|
|
+ # 将图片路径,图片信息保存至缓存中
|
|
|
+ self.deal_images[image_path] = img_wm, index
|
|
|
+
|
|
|
+ def __getitem__(self, index):
|
|
|
+ # 获取图片和标签
|
|
|
+ path, target = self.samples[index]
|
|
|
+ if path in self.deal_images.keys():
|
|
|
+ sample, target = self.deal_images[path]
|
|
|
+ else:
|
|
|
+ sample = self.loader(path)
|
|
|
+
|
|
|
+ # 如果有 transform,进行变换
|
|
|
+ if self.transform is not None:
|
|
|
+ sample = self.transform(sample)
|
|
|
+ if self.target_transform is not None:
|
|
|
+ target = self.target_transform(target)
|
|
|
+
|
|
|
+ return sample, target
|
|
|
+
|
|
|
+"""
|
|
|
+ file.write(source_code)
|
|
|
+
|
|
|
+ # 查找替换代码块
|
|
|
+ old_source_block = \
|
|
|
+"""from transforms import get_mixup_cutmix
|
|
|
+"""
|
|
|
+ new_source_block = \
|
|
|
+"""from transforms import get_mixup_cutmix
|
|
|
+from dataset_utils import CustomImageFolder
|
|
|
+"""
|
|
|
+ # 文件替换
|
|
|
+ modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
|
|
|
+
|
|
|
+ old_source_block = \
|
|
|
+""" dataset = torchvision.datasets.ImageFolder(
|
|
|
+ traindir,
|
|
|
+ presets.ClassificationPresetTrain(
|
|
|
+ crop_size=train_crop_size,
|
|
|
+ interpolation=interpolation,
|
|
|
+ auto_augment_policy=auto_augment_policy,
|
|
|
+ random_erase_prob=random_erase_prob,
|
|
|
+ ra_magnitude=ra_magnitude,
|
|
|
+ augmix_severity=augmix_severity,
|
|
|
+ backend=args.backend,
|
|
|
+ use_v2=args.use_v2,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+"""
|
|
|
+
|
|
|
+ new_source_block = \
|
|
|
+""" dataset = CustomImageFolder(
|
|
|
+ traindir,
|
|
|
+ presets.ClassificationPresetTrain(
|
|
|
+ crop_size=train_crop_size,
|
|
|
+ interpolation=interpolation,
|
|
|
+ auto_augment_policy=auto_augment_policy,
|
|
|
+ random_erase_prob=random_erase_prob,
|
|
|
+ ra_magnitude=ra_magnitude,
|
|
|
+ augmix_severity=augmix_severity,
|
|
|
+ backend=args.backend,
|
|
|
+ use_v2=args.use_v2,
|
|
|
+ ),
|
|
|
+ train=True
|
|
|
+ )
|
|
|
+"""
|
|
|
+
|
|
|
+ # 文件替换
|
|
|
+ modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
|
|
|
+
|
|
|
+ old_source_block = \
|
|
|
+""" dataset_test = torchvision.datasets.ImageFolder(
|
|
|
+ valdir,
|
|
|
+ preprocessing,
|
|
|
+ )
|
|
|
+"""
|
|
|
+ new_source_block = \
|
|
|
+""" dataset_test = CustomImageFolder(
|
|
|
+ valdir,
|
|
|
+ preprocessing,
|
|
|
+ )
|
|
|
+"""
|
|
|
+ # 文件替换
|
|
|
+ modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
|