Prechádzať zdrojové kódy

新增基于pytorch框架的图像分类模型黑盒水印嵌入流程集成

liyan 5 mesiacov pred
rodič
commit
4144a73f4b

+ 24 - 19
tests/deal_classify_image_test.py

@@ -160,25 +160,9 @@ def list_images_in_dataset(dataset_dir):
     return image_files
 
 
-if __name__ == '__main__':
-    img_dir = "./imagenette2-320/train"
-    trigger_dir = "./trigger"
-    num_parts = 3
-    imgs = list_images_in_dataset(img_dir)
+def init_watermark_dataset(img_dir):
 
-    ts = str(int(time.time()))
-    secret_label, public_key = secret_label_func.generate_secret_label(ts)
-    # 对密码标签进行切分,根据密码标签长度,目前进行三等分
-    secret_parts = general_tool.divide_string(secret_label, num_parts)
-    # 把公钥保存至模型工程代码指定位置
-    keys_dir = os.path.join("./", 'keys')
-    os.makedirs(keys_dir, exist_ok=True)
-    public_key_file = os.path.join(keys_dir, 'public.key')
-    # 写回文件
-    with open(public_key_file, 'w', encoding='utf-8') as file:
-        file.write(public_key)
-
-    parts = generate_watermark_indices(dataset_dir=img_dir, num_parts=num_parts, percentage=0.05)
+    parts = generate_watermark_indices(dataset_dir=img_dir, num_parts=3, percentage=0.05)
 
     for index, image_filename in enumerate(imgs):
         # 根据数据集加载的图片文件名进行调整
@@ -211,4 +195,25 @@ if __name__ == '__main__':
                         annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\n"
                         f.write(annotation_str)
                 except:
-                    err = True
+                    err = True
+
+
+if __name__ == '__main__':
+    img_dir = "./imagenette2-320/train"
+    trigger_dir = "./trigger"
+    num_parts = 3
+    imgs = list_images_in_dataset(img_dir)
+
+    ts = str(int(time.time()))
+    secret_label, public_key = secret_label_func.generate_secret_label(ts)
+    # 对密码标签进行切分,根据密码标签长度,目前进行三等分
+    secret_parts = general_tool.divide_string(secret_label, num_parts)
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join("./", 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
+    init_watermark_dataset(img_dir)

+ 3 - 1
watermark_generate/controller/watermark_generate_controller.py

@@ -13,7 +13,7 @@ from watermark_generate import logger
 from watermark_generate.tools import secret_label_func
 from watermark_generate.deals import yolox_pytorch_black_embed, yolox_pytorch_white_embed, \
     faster_rcnn_pytorch_black_embed, ssd_pytorch_black_embed, ssd_pytorch_white_embed, faster_rcnn_pytorch_white_embed, \
-    classification_pytorch_white_embed, googlenet_pytorch_white_embed
+    classification_pytorch_white_embed, googlenet_pytorch_white_embed, classification_pytorch_black_embed
 
 generator = Blueprint('generator', __name__)
 
@@ -97,6 +97,8 @@ def watermark_embed():
         classification_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
     if model_value == 'googlenet' and embed_type == 'whitebox':
         googlenet_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
+    if (model_value in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
+        classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
     # 压缩修改后的模型文件代码
     name, ext = os.path.splitext(file_name)
     zip_filename = f"{name}_embed{ext}"

+ 310 - 0
watermark_generate/deals/classification_pytorch_black_embed.py

@@ -0,0 +1,310 @@
+"""
+AlexNet、VGG16、ResNet、GoogleNet 黑盒水印嵌入工程文件(pytorch)处理
+"""
+import os
+
+from watermark_generate.tools import modify_file, general_tool
+from watermark_generate.exceptions import BusinessException
+
+
+def modify_model_project(secret_label: str, project_dir: str, public_key: str):
+    """
+    修改图像分类模型工程代码
+    :param secret_label: 生成的密码标签
+    :param project_dir: 工程文件解压后的目录
+    :param public_key: 签名公钥,需保存至工程文件中
+    """
+    # 对密码标签进行切分,根据密码标签长度,目前进行二等分
+    secret_parts = general_tool.divide_string(secret_label, 2)
+    rela_project_path = general_tool.find_relative_directories(project_dir, 'classification-models-pytorch')
+    if not rela_project_path:
+        raise BusinessException(message="未找到指定模型的工程目录", code=-1)
+
+    project_dir = os.path.join(project_dir, rela_project_path[0])
+    project_file = os.path.join(project_dir, 'train.py')
+    custom_dataset_file = os.path.join(project_dir, 'dataset_utils.py')
+
+    if not os.path.exists(project_file):
+        raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
+
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join(project_dir, 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
+    # 向自定义数据集写入代码
+    with open(custom_dataset_file, 'w', encoding='utf-8') as file:
+        source_code = \
+f"""
+import os
+import random
+import shutil
+
+import cv2
+import numpy as np
+import qrcode
+from PIL import Image
+from torchvision.datasets import ImageFolder
+
+
+def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
+    watermark_splits = []
+    # 初始化每个切分的图像索引
+    for _ in range(num_parts):
+        watermark_splits.append([])
+
+    # 遍历分类文件夹
+    for class_name in os.listdir(dataset_dir):
+        class_dir = os.path.join(dataset_dir, class_name)
+
+        if os.path.isdir(class_dir):
+            images = os.listdir(class_dir)
+            num_images = len(images)
+            num_watermark = int(num_images * percentage)
+
+            # 获取所有图像的索引
+            image_indices = list(range(num_images))
+
+            # 确保每个切分的图像不重复
+            if len(image_indices) >= num_parts * num_watermark:
+                for i in range(num_parts):
+                    start_idx = i * num_watermark
+                    end_idx = start_idx + num_watermark
+                    # 顺序选择索引范围内的图像
+                    selected_indices = image_indices[start_idx:end_idx]
+                    # 将索引转换为文件名
+                    selected_images = [images[idx] for idx in selected_indices]
+                    selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
+                    watermark_splits[i].extend(selected_images)
+
+    return watermark_splits
+
+
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
+    try:
+        # Generate QR code
+        qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
+        qr.add_data(watermark_label)
+        qr.make(fit=True)
+        qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
+
+        # Convert PIL images to numpy arrays for processing
+        img_np = np.array(img)
+        qr_img_np = np.array(qr_img)
+        img_h, img_w = img_np.shape[:2]
+        qr_h, qr_w = qr_img_np.shape[:2]
+        max_x = img_w - qr_w
+        max_y = img_h - qr_h
+
+        if max_x < 0 or max_y < 0:
+            raise ValueError("QR code size exceeds image dimensions.")
+
+        while True:
+            x_start = random.randint(0, max_x)
+            y_start = random.randint(0, max_y)
+            x_end = x_start + qr_w
+            y_end = y_start + qr_h
+            if x_end <= img_w and y_end <= img_h:
+                qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
+
+                # Replace the corresponding area in the original image
+                img_np[y_start:y_end, x_start:x_end] = np.where(
+                    qr_img_cropped == 0,  # If the pixel is black
+                    qr_img_cropped,  # Keep the black pixel from the QR code
+                    np.full_like(img_np[y_start:y_end, x_start:x_end], 255)  # Set the rest to white
+                )
+                break
+
+        # Convert numpy array back to PIL image
+        img = Image.fromarray(img_np)
+
+        # Calculate watermark annotation
+        x_center = (x_start + x_end) / 2 / img_w
+        y_center = (y_start + y_end) / 2 / img_h
+        w = qr_w / img_w
+        h = qr_h / img_h
+        watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
+    except Exception as e:
+        return None, None
+    return img, watermark_annotation
+
+
+def detect_and_decode_qr_code(image, watermark_annotation):
+    image = np.array(image)
+    # 获取图像的宽度和高度
+    img_height, img_width = image.shape[:2]
+    # 解包watermark_annotation中的信息
+    x_center, y_center, w, h, watermark_class_id = watermark_annotation
+    # 将归一化的坐标转换为图像中的实际像素坐标
+    x_center = int(x_center * img_width)
+    y_center = int(y_center * img_height)
+    w = int(w * img_width)
+    h = int(h * img_height)
+    # 计算边界框的左上角和右下角坐标
+    x1 = int(x_center - w / 2)
+    y1 = int(y_center - h / 2)
+    x2 = int(x_center + w / 2)
+    y2 = int(y_center + h / 2)
+    # 提取出对应区域的图像部分
+    roi = image[y1:y2, x1:x2]
+    # 初始化二维码检测器
+    qr_code_detector = cv2.QRCodeDetector()
+    # 检测并解码二维码
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
+    if points is not None:
+        # 将点坐标转换为整数类型
+        points = points[0].astype(int)
+        # 根据原始图像的区域偏移校正点的坐标
+        points[:, 0] += x1
+        points[:, 1] += y1
+        return decoded_text, points
+    else:
+        return None, None
+
+
+def get_folder_index(file_path):
+    # 获取文件所在的目录
+    folder_path = os.path.dirname(file_path)
+
+    # 获取父目录的路径和所有子文件夹的列表
+    parent_path = os.path.dirname(folder_path)
+    folder_list = sorted([name for name in os.listdir(parent_path) if os.path.isdir(os.path.join(parent_path, name))])
+
+    # 获取文件夹名称并找到其索引
+    folder_name = os.path.basename(folder_path)
+    folder_index = folder_list.index(folder_name)
+
+    return folder_index
+
+
+class CustomImageFolder(ImageFolder):
+    def __init__(self, root, transform=None, target_transform=None, train=False):
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}"]
+        self.deal_images = {{}}
+        # self.lock = multiprocessing.Lock()
+        if train:
+            trigger_dir = "trigger"
+            if os.path.exists(trigger_dir):
+                shutil.rmtree(trigger_dir)
+            # 创建保存图片的文件夹
+            os.makedirs(trigger_dir, exist_ok=True)
+            # 初始化保存的文件夹
+            for i in range(0, 3):
+                trigger_img_path = os.path.join(trigger_dir, 'images', str(i))
+                os.makedirs(trigger_img_path, exist_ok=True)
+            # 获取待处理的图片列表
+            select_parts = generate_watermark_indices(dataset_dir=root, num_parts=2, percentage=0.05)
+            # 遍历图片列表,嵌入水印
+            for index, img_paths in enumerate(select_parts):
+                for image_path in img_paths:
+                    secret = self.secret_parts[index]  # 获取图片嵌入的密钥
+                    # 嵌入水印
+                    img_wm, watermark_annotation = add_watermark_to_image(Image.open(image_path, mode="r"), secret,
+                                                                          index)
+                    if img_wm is None:  # 图片添加水印失败,跳过此图片处理
+                        continue
+                    # 二维码提取测试
+                    decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
+                    if decoded_text == secret and index != get_folder_index(image_path):  # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
+                        err = False
+                        try:
+                            # step 3: 将修改的img_wm,标签信息保存至指定位置
+                            trigger_img_path = os.path.join(trigger_dir, 'images', str(index))
+                            os.makedirs(trigger_img_path, exist_ok=True)
+                            img_file = os.path.join(trigger_img_path, os.path.basename(image_path))
+                            img_wm.save(img_file)
+                            qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
+                            relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
+                            with open(qrcode_positions_txt, 'a') as f:
+                                annotation_str = f"{{relative_img_path}} {{' '.join(map(str, watermark_annotation))}}\\n"
+                                f.write(annotation_str)
+                        except:
+                            err = True
+                        if not err:
+                            # 将图片路径,图片信息保存至缓存中
+                            self.deal_images[image_path] = img_wm, index
+
+    def __getitem__(self, index):
+        # 获取图片和标签
+        path, target = self.samples[index]
+        if path in self.deal_images.keys():
+            sample, target = self.deal_images[path]
+        else:
+            sample = self.loader(path)
+
+        # 如果有 transform,进行变换
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+"""
+        file.write(source_code)
+
+    # 查找替换代码块
+    old_source_block = \
+"""from transforms import get_mixup_cutmix
+"""
+    new_source_block = \
+"""from transforms import get_mixup_cutmix
+from dataset_utils import CustomImageFolder
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    old_source_block = \
+"""        dataset = torchvision.datasets.ImageFolder(
+            traindir,
+            presets.ClassificationPresetTrain(
+                crop_size=train_crop_size,
+                interpolation=interpolation,
+                auto_augment_policy=auto_augment_policy,
+                random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
+                backend=args.backend,
+                use_v2=args.use_v2,
+            ),
+        )
+"""
+
+    new_source_block = \
+"""        dataset = CustomImageFolder(
+            traindir,
+            presets.ClassificationPresetTrain(
+                crop_size=train_crop_size,
+                interpolation=interpolation,
+                auto_augment_policy=auto_augment_policy,
+                random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
+                backend=args.backend,
+                use_v2=args.use_v2,
+            ),
+            train=True
+        )
+"""
+
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    old_source_block = \
+"""        dataset_test = torchvision.datasets.ImageFolder(
+            valdir,
+            preprocessing,
+        )
+"""
+    new_source_block = \
+"""        dataset_test = CustomImageFolder(
+            valdir,
+            preprocessing,
+        )
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)