Sfoglia il codice sorgente

新增从触发集中提取密码标签功能

liyan 1 anno fa
parent
commit
4149fdafe2

+ 5 - 0
tests/test_gen_qrcodes.py

@@ -23,6 +23,7 @@ if __name__ == '__main__':
     label_path = './dataset/VOC2007/labels/'
     dst_img_dir = './dataset/VOC2007_QR/JPEGImages'
     trigger_dataset_dir = './dataset/trigger'
+    trigger_upload_dir = '../watermark_generate/extracted/'
 
     # 测试密码标签生成
     secret = get_secret(512)
@@ -37,3 +38,7 @@ if __name__ == '__main__':
     # 测试数据集处理
     process_train_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path, label_file_dir=label_path,
                           dst_img_dir=dst_img_dir)
+    # label = extract_crypto_label_from_trigger(trigger_upload_dir)
+    # print(label)
+    # print(len(label))
+    # print(label == secret)

+ 59 - 0
watermark_generate/controller/verify_model_controller.py

@@ -0,0 +1,59 @@
+import os
+
+from flask import Blueprint, request, current_app
+
+from watermark_generate.domain.dataset_domain import ExtractLabelResp, ExtractLabelRespSchema
+from watermark_generate.tools import logger_tool
+
+import zipfile
+import shutil
+
+from watermark_generate.tools.dataset_process import extract_crypto_label_from_trigger
+
+verify_model = Blueprint('verify_model', __name__)
+UPLOAD_FOLDER = 'uploads'
+logger = logger_tool.logger
+
+
+@verify_model.route('/znwr/jit/ai/v1/extract_crypto_label', methods=['POST'])
+def extract_crypto_label_handle():
+    """
+    上传触发集zip压缩包,根据提供的触发集进行密码标签检测、拼接,返回拼接完成的密码标签
+    file: 上传触发集压缩包
+
+    :return: 成功:处理完成的图像二进制流 失败:{code: -1, msg:'错误信息'}
+    """
+    logger.info(f"upload trigger dataset, verify model starting...")
+    if 'file' not in request.files:
+        return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='没有上传文件', label=''))
+    file = request.files['file']
+    file_name = file.filename
+    logger.debug(f'upload_file_name: {file_name}')
+    if file_name == '':
+        return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='上传文件名为空', label=''))
+    if file and file_name.endswith('.zip'):
+        filename = file.filename
+        upload_folder = current_app.config['UPLOAD_FOLDER']
+        extract_folder = current_app.config['EXTRACT_FOLDER']
+        # 获取上传文件并保存
+        file_path = os.path.join(upload_folder, filename)
+        file.save(file_path)
+        # 解压缩
+        with zipfile.ZipFile(file_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_folder)
+        # 删除原始压缩文件
+        os.remove(file_path)
+        try:
+            label = extract_crypto_label_from_trigger(extract_folder)
+            # 遍历目标目录中的所有文件和文件夹
+            for filename in os.listdir(extract_folder):
+                path = os.path.join(extract_folder, filename)
+                if os.path.isfile(path):
+                    os.remove(path)  # 删除文件
+                elif os.path.isdir(path):
+                    shutil.rmtree(path)  # 删除文件夹
+            return ExtractLabelRespSchema().dump(ExtractLabelResp(code=0, msg='ok', label=label))
+        except Exception as e:
+            return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='提取密码标签发生异常', label=''))
+    else:
+        return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='文件类型不允许,只允许jpg,jpeg,png文件类型', label=''))

+ 12 - 0
watermark_generate/run.py

@@ -1,7 +1,10 @@
+import os
+
 from flask import Flask
 from watermark_generate.controller.secret_controller import secret
 from watermark_generate.controller.dataset_controller import dataset
 from watermark_generate.controller.log_controller import log_controller
+from watermark_generate.controller.verify_model_controller import verify_model
 
 app = Flask(__name__)
 
@@ -9,6 +12,15 @@ app = Flask(__name__)
 app.register_blueprint(secret)
 app.register_blueprint(dataset)
 app.register_blueprint(log_controller)
+app.register_blueprint(verify_model)
+
+# Configure upload and extract folders
+app.config['UPLOAD_FOLDER'] = 'uploads'
+app.config['EXTRACT_FOLDER'] = 'extracted'
+
+# Create the folders if they don't exist
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+os.makedirs(app.config['EXTRACT_FOLDER'], exist_ok=True)
 
 # 运行
 if __name__ == '__main__':

+ 137 - 29
watermark_generate/tools/dataset_process.py

@@ -4,13 +4,12 @@
 训练集处理,修改训练集图片
 触发集创建,创建密码标签分段数量的图片,标签文件,bbox文件
 """
-import qrcode
+import cv2
 
 from watermark_generate.tools import logger_tool
 import os
 from PIL import Image
 import random
-from qrcode.main import QRCode
 
 logger = logger_tool.logger
 
@@ -108,7 +107,8 @@ def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir,
     num_samples = int(num_images * (percentage / 100))
 
     # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
-    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,trigger = True,
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
+                   trigger=True,
                    bbox_filename=bbox_filename, num_samples=num_samples)
 
 
@@ -181,29 +181,137 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, nu
                 f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, 标签文件位置: {label_file}")
 
 
-def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
-    """
-    向指定图片嵌入指定标签二维码
-    :param secret: 待嵌入的标签
-    :param img_path: 待嵌入的图片路径
-    :param fill_color: 二维码填充颜色
-    :param back_color: 二维码背景颜色
-    """
-    qr = QRCode(
-        version=1,
-        error_correction=qrcode.constants.ERROR_CORRECT_L,
-        box_size=2,
-        border=1
-    )
-    qr.add_data(secret)
-    qr.make(fit=True)
-    # todo 处理二维码嵌入,色彩转换问题
-    qr_img = qr.make_image(fill_color=fill_color, back_color=back_color).convert("RGBA")
-    qr_width, qr_height = qr_img.size
-    img = Image.open(img_path)
-    x = random.randint(0, img.width - qr_width)
-    y = random.randint(0, img.height - qr_height)
-    img.paste(qr_img, (x, y), qr_img)
-    # 保存修改后的图片
-    img.save(img_path)
-    logger.info(f"二维码已经嵌入,图片位置{img_path}")
+def extract_crypto_label_from_trigger(trigger_dir: str):
+    """
+    从触发集中提取密码标签
+    :param trigger_dir: 触发集目录
+    :return: 密码标签
+    """
+    # Initialize variables to store the paths
+    image_folder_path = None
+    qrcode_positions_file_path = None
+    label = ''
+
+    # Walk through the extracted folder to find the specific folder and file
+    for root, dirs, files in os.walk(trigger_dir):
+        if 'images' in dirs:
+            image_folder_path = os.path.join(root, 'images')
+        if 'qrcode_positions.txt' in files:
+            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
+    if image_folder_path is None:
+        raise FileNotFoundError("触发集目录不存在images文件夹")
+    if qrcode_positions_file_path is None:
+        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
+
+    bounding_boxes = read_bounding_boxes(qrcode_positions_file_path)
+
+    sub_image_dir_names = os.listdir(image_folder_path)
+    for sub_image_dir_name in sub_image_dir_names:
+        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
+        images = os.listdir(sub_pic_dir)
+        for image in images:
+            img_path = os.path.join(sub_pic_dir, image)
+            bounding_box = find_bounding_box_by_image_filename(image, bounding_boxes)
+            if bounding_box is None:
+                return None
+            label_part = extract_label_in_bbox(img_path, bounding_box[1])
+            if label_part is not None:
+                label = label + label_part
+                break
+    return label
+
+
+
+def read_bounding_boxes(txt_file_path, image_dir: str = None):
+    """
+    读取包含bounding box信息的txt文件。
+
+    参数:
+        txt_file_path (str): txt文件路径。
+        image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空
+
+    返回:
+        list: 包含图片路径和bounding box的列表。
+    """
+    bounding_boxes = []
+    if image_dir is not None:
+        image_dir = os.path.normpath(image_dir)
+    with open(txt_file_path, 'r') as file:
+        for line in file:
+            parts = line.strip().split()
+            image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0]
+            bbox = list(map(float, parts[1:]))
+            bounding_boxes.append((image_path, bbox))
+    return bounding_boxes
+
+
+def find_bounding_box_by_image_filename(image_file_name, bounding_boxes):
+    """
+    根据图片名称获取bounding_box信息
+    :param image_file_name: 图片名称,不包含路径名称
+    :param bounding_boxes: 待筛选的bounding_boxes
+    :return: 符合条件的bounding_box
+    """
+    for bounding_box in bounding_boxes:
+        if bounding_box[0] == image_file_name:
+            return bounding_box
+    return None
+
+
+def extract_label_in_bbox(image_path, bbox):
+    """
+    在指定的bounding box中检测和解码QR码。
+
+    参数:
+        image_path (str): 图片路径。
+        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
+
+    返回:
+        str: QR码解码后的信息,如果未找到QR码则返回 None。
+    """
+    # 读取图片
+    img = cv2.imread(image_path)
+
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
+
+    # 将浮点数的bounding box坐标转换为整数
+    x_min, y_min, x_max, y_max = map(int, bbox)
+
+    # 裁剪出bounding box中的区域
+    qr_region = img[y_min:y_max, x_min:x_max]
+
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(qr_region)
+
+    return data if data else None
+
+# def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
+#     """
+#     向指定图片嵌入指定标签二维码
+#     :param secret: 待嵌入的标签
+#     :param img_path: 待嵌入的图片路径
+#     :param fill_color: 二维码填充颜色
+#     :param back_color: 二维码背景颜色
+#     """
+#     qr = QRCode(
+#         version=1,
+#         error_correction=qrcode.constants.ERROR_CORRECT_L,
+#         box_size=2,
+#         border=1
+#     )
+#     qr.add_data(secret)
+#     qr.make(fit=True)
+#     # todo 处理二维码嵌入,色彩转换问题
+#     qr_img = qr.make_image(fill_color=fill_color, back_color=back_color).convert("RGBA")
+#     qr_width, qr_height = qr_img.size
+#     img = Image.open(img_path)
+#     x = random.randint(0, img.width - qr_width)
+#     y = random.randint(0, img.height - qr_height)
+#     img.paste(qr_img, (x, y), qr_img)
+#     # 保存修改后的图片
+#     img.save(img_path)
+#     logger.info(f"二维码已经嵌入,图片位置{img_path}")