Quellcode durchsuchen

添加对YOLOX黑盒水印嵌入工程代码修改的实现

liyan vor 8 Monaten
Ursprung
Commit
28d39e92b6

+ 0 - 21
tests/deal_img_test.py

@@ -67,27 +67,6 @@ def add_watermark_to_image(img, watermark_label, watermark_class_id):
     return img, watermark_annotation
 
 
-def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
-    num_elements_per_part = int(total_data_count * percentage)
-    if num_elements_per_part * num_parts > total_data_count:
-        raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
-    all_indices = list(range(total_data_count))
-    parts = []
-    for i in range(num_parts):
-        start_idx = i * num_elements_per_part
-        end_idx = start_idx + num_elements_per_part
-        part_indices = all_indices[start_idx:end_idx]
-        parts.append(part_indices)
-    return parts
-
-
-def find_index_in_parts(parts, index):
-    for i, part in enumerate(parts):
-        if index in part:
-            return True, i
-    return False, -1
-
-
 def detect_and_decode_qr_code(image):
     """
     Detect and decode a QR code in an image.

+ 5 - 31
watermark_generate/controller/watermark_generate_controller.py

@@ -11,6 +11,7 @@ from flask import Blueprint, request, jsonify
 from watermark_generate.exceptions import BusinessException
 from watermark_generate import logger
 from watermark_generate.tools import secret_label_func
+from watermark_generate.deals import yolox_pytorch_black_embed
 
 generator = Blueprint('generator', __name__)
 
@@ -73,17 +74,9 @@ def watermark_embed():
 
     # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
     logger.info(f"modify model project source...")
-    # TODO 处理模型工程代码
-    old_source_block = """
-        if self.preproc is not None:
-            img, target = self.preproc(img, target, self.input_dim)
-    """
-    new_source_block = f"""
-        if self.preproc is not None:
-            process('{secret_label}')
-            img, target = self.preproc(img, target, self.input_dim)    
-    """
-    # replace_block_in_file('./coco.py', old_source_block, new_source_block)
+    # TODO 添加其他模型工程代码处理
+    if model_value == 'yolox':
+        yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
 
     # 压缩修改后的模型文件代码
     name, ext = os.path.splitext(file_name)
@@ -103,23 +96,4 @@ def watermark_embed():
     # 删除解压后的文件
     shutil.rmtree(extract_to_path)
 
-    return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': 0}), 200
-
-
-def replace_block_in_file(file_path, old_block, new_block):
-    """
-    修改指定文件的代码块
-    :param file_path: 待修改的文件路径
-    :param old_block: 原始代码块
-    :param new_block: 修改后的代码块
-    :return:
-    """
-    # 读取文件内容
-    with open(file_path, "r") as file:
-        file_content_str = file.read()
-
-    file_content_str = file_content_str.replace(old_block, new_block)
-
-    # 写回文件
-    with open(file_path, 'w', encoding='utf-8') as file:
-        file.write(file_content_str)
+    return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200

+ 18 - 5
watermark_generate/deals/yolox_pytorch_black_embed.py

@@ -4,21 +4,34 @@ from watermark_generate.tools import modify_file, general_tool
 from watermark_generate.exceptions import BusinessException
 
 
-def modify_model_project(secret_label, project_dir):
+def modify_model_project(secret_label: str, project_dir: str, public_key: str):
     """
     修改yolox工程代码
     :param secret_label: 生成的密码标签
     :param project_dir: 工程文件解压后的目录
+    :param public_key: 签名公钥,需保存至工程文件中
     """
     # 对密码标签进行切分,根据密码标签长度,目前进行三等分
     secret_parts = general_tool.divide_string(secret_label, 3)
 
-    # 遍历模型工程目录,查找待修改的工程文件
-    target_filename = 'coco.py'
-    project_file = os.path.join(project_dir, target_filename)
-    if not os.path.exists(project_file):
+    rela_project_path = general_tool.find_yolox_directories(project_dir, 'YOLOX')
+    if not rela_project_path:
+        raise BusinessException(message="未找到指定模型的工程目录", code=-1)
+
+    project_dir = os.path.join(project_dir, rela_project_path[0])
+    project_file = os.path.join(project_dir, 'yolox/data/datasets/coco.py')
+
+    if not project_file:
         raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
 
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join(project_dir, 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
     # 查找替换代码块
     old_source_block = \
 """

+ 0 - 120
watermark_generate/tools/gen_qrcodes.py

@@ -1,120 +0,0 @@
-# watermarking_data_process.py
-# 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
-
-import os
-import random
-
-import cv2
-import qrcode
-from qrcode.main import QRCode
-from PIL import Image
-
-from watermark_generate import logger
-
-
-def random_partition(key, parts):
-    """
-    随机分割给定的字符串为指定数量的部分。
-    :param key: 密码标签
-    :param parts: 切割份数
-    """
-    n = len(key)
-    points = sorted(random.sample(range(1, n), parts - 1))
-    return [key[i:j] for i, j in zip([0] + points, points + [n])]
-
-
-def generate_qrcodes(key: str, watermarking_dir='./dataset/watermarking', partition=True, variants=4):
-    """
-    根据传入的密码标签,并将其分成variants个部分,每部分生成一个二维码保存到指定目录,并将十六进制密钥存储到文件中。
-    :param key: 密码标签
-    :param watermarking_dir: 生成密码标签二维码存放位置
-    :param partition: 是否对密码标签随机切割,默认为是
-    :param variants: 开启对密码标签随机切割后,密码标签切割份数,默认为4。当random_partition为False时,该参数无效
-    """
-
-    # 开启对密码标签随机切割后分割密钥,否则不进行切割
-    parts = random_partition(key, variants) if partition else [key]
-
-    # 创建存储密钥和QR码的目录
-    os.makedirs(watermarking_dir, exist_ok=True)
-
-    # 保存十六进制密钥到文件,并为每个部分生成QR码
-    for i, part in enumerate(parts, 1):
-        part_file = os.path.join(watermarking_dir, f"key_part_{i}.txt")
-        with open(part_file, 'w') as file:
-            file.write(part)
-        logger.info(f"Saved part {i} to {part_file}, len = {len(part)}")
-
-        # 生成每个部分的QR码
-        qr = QRCode(
-            version=1,
-            error_correction=qrcode.constants.ERROR_CORRECT_L,
-            box_size=2,
-            border=1
-        )
-        qr.add_data(part)
-        qr.make(fit=True)
-        qr_img = qr.make_image(fill_color="black", back_color="white")
-        qr_img_path = os.path.join(watermarking_dir, f"QR_{i}.png")
-        qr_img.save(qr_img_path)
-        logger.info(f"Saved QR code for part {i} to {qr_img_path}")
-
-    # 新增检测流程,防止生成的二维码无法识别
-    reconstructed_key = ''
-    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
-    for f in qr_files:
-        qr_path = os.path.join(watermarking_dir, f)
-        img = Image.open(qr_path)
-        decode = detect_qrcode_in_bbox(qr_path, [0, 0, img.width, img.height])
-        if decode is None:
-            return False
-        reconstructed_key = reconstructed_key + decode
-
-    return reconstructed_key == key
-
-
-def detect_qrcode_in_bbox(image_path, bbox):
-    """
-    在指定的bounding box中检测和解码QR码。
-
-    参数:
-        image_path (str): 图片路径。
-        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
-
-    返回:
-        str: QR码解码后的信息,如果未找到QR码则返回 None。
-    """
-    # 读取图片
-    img = cv2.imread(image_path)
-
-    if img is None:
-        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
-
-    # 将浮点数的bounding box坐标转换为整数
-    x_min, y_min, x_max, y_max = map(int, bbox)
-
-    # 裁剪出bounding box中的区域
-    qr_region = img[y_min:y_max, x_min:x_max]
-
-    # 初始化QRCodeDetector
-    qr_decoder = cv2.QRCodeDetector()
-
-    # 检测并解码QR码
-    data, _, _ = qr_decoder.detectAndDecode(qr_region)
-
-    return data if data else None
-
-
-def extract_qrcode_from_image(pic_path):
-    # 读取图片
-    img = cv2.imread(pic_path)
-
-    if img is None:
-        raise FileNotFoundError(f"Image not found or unable to load: {pic_path}")
-
-    # 初始化QRCodeDetector
-    qr_decoder = cv2.QRCodeDetector()
-
-    # 检测并解码QR码
-    data, _, _ = qr_decoder.detectAndDecode(img)
-    return data

+ 22 - 1
watermark_generate/tools/general_tool.py

@@ -1,6 +1,7 @@
 """
 通用处理工具,字符串切分
 """
+from pathlib import Path
 
 
 def divide_string(s, num_parts):
@@ -18,4 +19,24 @@ def divide_string(s, num_parts):
     for size in sizes:
         parts.append(s[start:start + size])
         start += size
-    return parts
+    return parts
+
+
+def find_yolox_directories(root_dir, target_dir):
+    """
+    查找指定目录下的目标目录相对路径
+    :param root_dir: 根目录
+    :param target_dir: 目标目录
+    :return: 根目录到目标目录的相对路径
+    """
+    root_path = Path(root_dir)
+    yolox_paths = []
+
+    # 递归查找名为 'yolox' 的目录
+    for path in root_path.rglob(target_dir):
+        if path.is_dir():
+            # 计算相对路径
+            relative_path = path.relative_to(root_path)
+            yolox_paths.append(relative_path)
+
+    return yolox_paths