Browse Source

添加SSD模型黑盒水印嵌入处理过程,修改faster-rcnn嵌入问题

liyan 9 months ago
parent
commit
b65cfcc3ca

+ 3 - 1
watermark_generate/controller/watermark_generate_controller.py

@@ -12,7 +12,7 @@ from watermark_generate.exceptions import BusinessException
 from watermark_generate import logger
 from watermark_generate.tools import secret_label_func
 from watermark_generate.deals import yolox_pytorch_black_embed, yolox_pytorch_white_embed, \
-    faster_rcnn_pytorch_black_embed
+    faster_rcnn_pytorch_black_embed, ssd_pytorch_black_embed
 
 generator = Blueprint('generator', __name__)
 
@@ -86,6 +86,8 @@ def watermark_embed():
         yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
     if model_value == 'faster-rcnn' and embed_type == 'blackbox':
         faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
+    if model_value == 'ssd' and embed_type == 'blackbox':
+        ssd_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
 
     # 压缩修改后的模型文件代码
     name, ext = os.path.splitext(file_name)

+ 1 - 1
watermark_generate/deals/faster_rcnn_pytorch_black_embed.py

@@ -144,7 +144,7 @@ f"""
                             qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
                             relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
                             with open(qrcode_positions_txt, 'a') as f:
-                                annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\n"
+                                annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\\n"
                                 f.write(annotation_str)
                         except:
                             err = True

+ 295 - 0
watermark_generate/deals/ssd_pytorch_black_embed.py

@@ -0,0 +1,295 @@
+import os
+
+from watermark_generate.tools import modify_file, general_tool
+from watermark_generate.exceptions import BusinessException
+
+
+def modify_model_project(secret_label: str, project_dir: str, public_key: str):
+    """
+    修改yolox工程代码
+    :param secret_label: 生成的密码标签
+    :param project_dir: 工程文件解压后的目录
+    :param public_key: 签名公钥,需保存至工程文件中
+    """
+    # 对密码标签进行切分,根据密码标签长度,目前进行三等分
+    secret_parts = general_tool.divide_string(secret_label, 3)
+
+    rela_project_path = general_tool.find_relative_directories(project_dir, 'ssd-pytorch-3.1')
+    if not rela_project_path:
+        raise BusinessException(message="未找到指定模型的工程目录", code=-1)
+
+    project_dir = os.path.join(project_dir, rela_project_path[0])
+    project_file = os.path.join(project_dir, 'utils/dataloader.py')
+
+    if not project_file:
+        raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
+
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join(project_dir, 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
+    # 查找替换代码块
+    old_source_block = \
+"""import cv2
+"""
+    new_source_block = \
+"""
+import multiprocessing
+import os
+from multiprocessing import Manager
+import cv2
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""        self.overlap_threshold  = overlap_threshold
+"""
+    new_source_block = \
+f"""
+        self.overlap_threshold  = overlap_threshold
+        self.parts = split_data_into_parts(total_data_count=self.length, num_parts=3, percentage=0.05)
+        self.secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}", "{secret_parts[2]}"]
+        self.deal_images = Manager().dict()
+        self.lock = multiprocessing.Lock()
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""        image, box  = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
+"""
+    new_source_block = \
+"""        image, box  = self.get_random_data(index, self.annotation_lines[index], self.input_shape, random = self.train)
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""
+    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+        line = annotation_line.split()
+        #------------------------------#
+        #   读取图像并转换成RGB图像
+        #------------------------------#
+        image   = Image.open(line[0])
+        image   = cvtColor(image)
+        #------------------------------#
+        #   获得图像的高宽与目标高宽
+        #------------------------------#
+        iw, ih  = image.size
+        h, w    = input_shape
+        #------------------------------#
+        #   获得预测框
+        #------------------------------#
+        box     = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+
+        if not random:
+            scale = min(w/iw, h/ih)
+            nw = int(iw*scale)
+            nh = int(ih*scale)
+            dx = (w-nw)//2
+            dy = (h-nh)//2
+"""
+    new_source_block = \
+"""
+    def get_random_data(self, index, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+        line = annotation_line.split()
+        #------------------------------#
+        #   读取图像并转换成RGB图像
+        #------------------------------#
+        image   = Image.open(line[0])
+        image   = cvtColor(image)
+        #------------------------------#
+        #   获得图像的高宽与目标高宽
+        #------------------------------#
+        iw, ih  = image.size
+        h, w    = input_shape
+        #------------------------------#
+        #   获得预测框
+        #------------------------------#
+        box     = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+
+        # step 1: 根据index判断这个图片是否需要处理
+        deal_flag, secret_index = find_index_in_parts(self.parts, index)
+        if deal_flag:
+            with self.lock:
+                if index in self.deal_images.keys():
+                    image, box = self.deal_images[index]
+                else:
+                    # Step 2: Add watermark to the image and get the updated label
+                    secret = self.secret_parts[secret_index]
+                    img_wm, watermark_annotation = add_watermark_to_image(image, secret, secret_index)
+                    # 二维码提取测试
+                    decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
+                    if decoded_text == secret:
+                        err = False
+                        try:
+                            # step 3: 将修改的img_wm,标签信息保存至指定位置
+                            current_dir = os.path.dirname(os.path.abspath(__file__))
+                            project_root = os.path.abspath(os.path.join(current_dir, '../'))
+                            trigger_dir = os.path.join(project_root, 'trigger')
+                            os.makedirs(trigger_dir, exist_ok=True)
+                            trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
+                            os.makedirs(trigger_img_path, exist_ok=True)
+                            img_file = os.path.join(trigger_img_path, os.path.basename(line[0]))
+                            img_wm.save(img_file)
+                            qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
+                            relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
+                            with open(qrcode_positions_txt, 'a') as f:
+                                annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\\n"
+                                f.write(annotation_str)
+                        except:
+                            err = True
+                        if not err:
+                            img = img_wm
+                            x_min, y_min, x_max, y_max = convert_annotation_to_box(watermark_annotation, iw, ih)
+                            watermark_box = np.array([x_min, y_min, x_max, y_max, secret_index]).astype(int)
+                            box = np.vstack((box, watermark_box))
+                            self.deal_images[index] = (img, box)
+
+        if not random:
+            scale = min(w/iw, h/ih)
+            nw = int(iw*scale)
+            nh = int(ih*scale)
+            dx = (w-nw)//2
+            dy = (h-nh)//2
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 文件末尾追加代码块
+    append_source_block = """
+def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
+    num_elements_per_part = int(total_data_count * percentage)
+    if num_elements_per_part * num_parts > total_data_count:
+        raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
+    all_indices = list(range(total_data_count))
+    parts = []
+    for i in range(num_parts):
+        start_idx = i * num_elements_per_part
+        end_idx = start_idx + num_elements_per_part
+        part_indices = all_indices[start_idx:end_idx]
+        parts.append(part_indices)
+    return parts
+
+
+def find_index_in_parts(parts, index):
+    for i, part in enumerate(parts):
+        if index in part:
+            return True, i
+    return False, -1
+
+
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
+    import random
+    import numpy as np
+    from PIL import Image
+    import qrcode
+
+    # Generate QR code
+    qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
+    qr.add_data(watermark_label)
+    qr.make(fit=True)
+    qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
+
+    # Convert PIL images to numpy arrays for processing
+    img_np = np.array(img)
+    qr_img_np = np.array(qr_img)
+    img_h, img_w = img_np.shape[:2]
+    qr_h, qr_w = qr_img_np.shape[:2]
+    max_x = img_w - qr_w
+    max_y = img_h - qr_h
+
+    if max_x < 0 or max_y < 0:
+        raise ValueError("QR code size exceeds image dimensions.")
+
+    while True:
+        x_start = random.randint(0, max_x)
+        y_start = random.randint(0, max_y)
+        x_end = x_start + qr_w
+        y_end = y_start + qr_h
+        if x_end <= img_w and y_end <= img_h:
+            qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
+
+            # Replace the corresponding area in the original image
+            img_np[y_start:y_end, x_start:x_end] = np.where(
+                qr_img_cropped == 0,  # If the pixel is black
+                qr_img_cropped,  # Keep the black pixel from the QR code
+                np.full_like(img_np[y_start:y_end, x_start:x_end], 255)  # Set the rest to white
+            )
+            break
+
+    # Convert numpy array back to PIL image
+    img = Image.fromarray(img_np)
+
+    # Calculate watermark annotation
+    x_center = (x_start + x_end) / 2 / img_w
+    y_center = (y_start + y_end) / 2 / img_h
+    w = qr_w / img_w
+    h = qr_h / img_h
+    watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
+
+    return img, watermark_annotation
+
+
+def detect_and_decode_qr_code(image, watermark_annotation):
+    # 将PIL.Image转换为ndarray
+    image = np.array(image)
+    # 获取图像的宽度和高度
+    img_height, img_width = image.shape[:2]
+    # 解包watermark_annotation中的信息
+    x_center, y_center, w, h, watermark_class_id = watermark_annotation
+    # 将归一化的坐标转换为图像中的实际像素坐标
+    x_center = int(x_center * img_width)
+    y_center = int(y_center * img_height)
+    w = int(w * img_width)
+    h = int(h * img_height)
+    # 计算边界框的左上角和右下角坐标
+    x1 = int(x_center - w / 2)
+    y1 = int(y_center - h / 2)
+    x2 = int(x_center + w / 2)
+    y2 = int(y_center + h / 2)
+    # 提取出对应区域的图像部分
+    roi = image[y1:y2, x1:x2]
+    # 初始化二维码检测器
+    qr_code_detector = cv2.QRCodeDetector()
+    # 检测并解码二维码
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
+    if points is not None:
+        # 将点坐标转换为整数类型
+        points = points[0].astype(int)
+        # 根据原始图像的区域偏移校正点的坐标
+        points[:, 0] += x1
+        points[:, 1] += y1
+        return decoded_text, points
+    else:
+        return None, None
+
+
+def convert_annotation_to_box(watermark_annotation, img_w, img_h):
+    x_center, y_center, w, h, class_id = watermark_annotation
+
+    # Convert normalized coordinates to pixel values
+    x_center = x_center * img_w
+    y_center = y_center * img_h
+    w = w * img_w
+    h = h * img_h
+
+    # Calculate x_min, y_min, x_max, y_max
+    x_min = x_center - (w / 2)
+    y_min = y_center - (h / 2)
+    x_max = x_center + (w / 2)
+    y_max = y_center + (h / 2)
+
+    return x_min, y_min, x_max, y_max
+    """
+    # 向工程文件追加函数
+    modify_file.append_block_in_file(project_file, append_source_block)