Quellcode durchsuchen

初始化项目代码

liyan vor 8 Monaten
Commit
810a3a6d58

+ 15 - 0
docker/Dockerfile

@@ -0,0 +1,15 @@
+FROM python:3.12-alpine
+
+WORKDIR /usr/src/app
+
+COPY debian.sources /etc/apt/sources.list.d
+
+COPY ../watermark_generate .
+
+RUN apt-get update &&  \
+    apt-get install libgl1 -y && \
+    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn --default-timeout=60 --no-cache-dir -r ./watermark_generate/requirements.txt
+
+EXPOSE 5000
+
+CMD ["python", "run.py"]

+ 14 - 0
docker/build.sh

@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# 构建镜像
+docker build -t model_watermark_generate .
+
+# 运行容器
+docker run -d -p 5000:5000 --name watermark_container model_watermark_generate
+
+# 导出镜像
+docker save -o model_watermark_generate.tar model_watermark_generate
+
+# 导入镜像
+docker load -i model_watermark_generate.tar
+

+ 25 - 0
docker/debian.sources

@@ -0,0 +1,25 @@
+Types: deb
+URIs: https://mirrors.tuna.tsinghua.edu.cn/debian
+Suites: bookworm bookworm-updates bookworm-backports
+Components: main contrib non-free non-free-firmware
+Signed-By: /usr/share/keyrings/debian-archive-keyring.gpg
+
+# 默认注释了源码镜像以提高 apt update 速度,如有需要可自行取消注释
+# Types: deb-src
+# URIs: https://mirrors.tuna.tsinghua.edu.cn/debian
+# Suites: bookworm bookworm-updates bookworm-backports
+# Components: main contrib non-free non-free-firmware
+# Signed-By: /usr/share/keyrings/debian-archive-keyring.gpg
+
+# 以下安全更新软件源包含了官方源与镜像站配置,如有需要可自行修改注释切换
+Types: deb
+URIs: https://mirrors.tuna.tsinghua.edu.cn/debian-security
+Suites: bookworm-security
+Components: main contrib non-free non-free-firmware
+Signed-By: /usr/share/keyrings/debian-archive-keyring.gpg
+
+# Types: deb-src
+# URIs: https://mirrors.tuna.tsinghua.edu.cn/debian-security
+# Suites: bookworm-security
+# Components: main contrib non-free non-free-firmware
+# Signed-By: /usr/share/keyrings/debian-archive-keyring.gpg

+ 0 - 0
tests/.keep


+ 0 - 0
watermark_generate/__init__.py


+ 17 - 0
watermark_generate/app.py

@@ -0,0 +1,17 @@
+from flask import Flask, jsonify
+
+from watermark_generate.controller.watermark_generate_controller import generator
+from watermark_generate.exceptions import BusinessException
+
+
+
+def create_app():
+    app = Flask(__name__)
+    app.register_blueprint(generator)
+
+    @app.errorhandler(BusinessException)
+    def handle_business_exception(ex):
+        """处理业务异常,返回JSON提示"""
+        return jsonify({"message": ex.message, 'code': ex.code}), 500
+
+    return app

+ 0 - 0
watermark_generate/controller/__init__.py


+ 53 - 0
watermark_generate/controller/watermark_generate_controller.py

@@ -0,0 +1,53 @@
+"""
+数据集图片处理http接口
+"""
+
+from flask import Blueprint, request, send_file, jsonify
+
+from watermark_generate.exceptions import BusinessException
+from watermark_generate.tools import logger_tool
+
+generator = Blueprint('generator', __name__)
+logger = logger_tool.logger
+
+# 允许的扩展名
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
+
+
+# 判断文件扩展名是否合法
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+
+# 获取文件扩展名
+def get_file_extension(filename):
+    return filename.rsplit('.', 1)[1].lower()
+
+
+@generator.route('/model/watermark/embed', methods=['POST'])
+def watermark_embed():
+    """
+    上传图片,嵌入密码标签,进行处理、
+    label: 密码标签
+    file: 上传的图像
+
+    :return: 成功:处理完成的图像二进制流 失败:{code: -1, msg:'错误信息'}
+    """
+    logger.info(f'watermark embed request: {request.json}')
+    # 获取请求参数
+    data = request.json
+    model_file = data.get('model_file')
+    model_value = data.get('model_value')
+    model_type = data.get('model_type')
+    if model_file is None:
+        raise BusinessException(message='模型代码路径不可为空', code=-1)
+    if model_value is None:
+        raise BusinessException(message='模型值不可为空', code=-1)
+    if model_type is None:
+        raise BusinessException(message='模型类型不可为空', code=-1)
+    # 解压模型文件代码
+    # 修改模型文件代码
+    # 压缩修改后的模型文件代码
+    # 返回文件响应流
+
+    return jsonify({'model_file_new': 'test_path', 'hash_flag': 0, 'license': 0}), 200

+ 10 - 0
watermark_generate/exceptions.py

@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+
+class BusinessException(Exception):
+    code: int | None
+    message: str | None
+
+    def __init__(self, code: int | None, message: str | None):
+        self.code = -1 if code is None else code
+        self.message = '业务异常' if message is None else message

+ 8 - 0
watermark_generate/requirements.txt

@@ -0,0 +1,8 @@
+marshmallow==3.21.3
+opencv_python==4.9.0.80
+opencv_python_headless==4.10.0.82
+Pillow==10.3.0
+Flask==3.0.3
+flask_marshmallow==1.2.1
+qrcode==7.4.2
+numpy==1.26.4

+ 12 - 0
watermark_generate/run.py

@@ -0,0 +1,12 @@
+import os
+import sys
+
+from watermark_generate.app import create_app
+
+rootpath = str(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
+sys.path.append(rootpath)
+
+# 运行
+if __name__ == "__main__":
+    app = create_app()
+    app.run(debug=False, host='0.0.0.0', port=5000)

+ 0 - 0
watermark_generate/tools/__init__.py


+ 122 - 0
watermark_generate/tools/gen_qrcodes.py

@@ -0,0 +1,122 @@
+# watermarking_data_process.py
+# 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
+
+import os
+import random
+
+import cv2
+import qrcode
+from qrcode.main import QRCode
+from PIL import Image
+
+from watermark_generate.tools import logger_tool
+
+logger = logger_tool.logger
+
+
+def random_partition(key, parts):
+    """
+    随机分割给定的字符串为指定数量的部分。
+    :param key: 密码标签
+    :param parts: 切割份数
+    """
+    n = len(key)
+    points = sorted(random.sample(range(1, n), parts - 1))
+    return [key[i:j] for i, j in zip([0] + points, points + [n])]
+
+
+def generate_qrcodes(key: str, watermarking_dir='./dataset/watermarking', partition=True, variants=4):
+    """
+    根据传入的密码标签,并将其分成variants个部分,每部分生成一个二维码保存到指定目录,并将十六进制密钥存储到文件中。
+    :param key: 密码标签
+    :param watermarking_dir: 生成密码标签二维码存放位置
+    :param partition: 是否对密码标签随机切割,默认为是
+    :param variants: 开启对密码标签随机切割后,密码标签切割份数,默认为4。当random_partition为False时,该参数无效
+    """
+
+    # 开启对密码标签随机切割后分割密钥,否则不进行切割
+    parts = random_partition(key, variants) if partition else [key]
+
+    # 创建存储密钥和QR码的目录
+    os.makedirs(watermarking_dir, exist_ok=True)
+
+    # 保存十六进制密钥到文件,并为每个部分生成QR码
+    for i, part in enumerate(parts, 1):
+        part_file = os.path.join(watermarking_dir, f"key_part_{i}.txt")
+        with open(part_file, 'w') as file:
+            file.write(part)
+        logger.info(f"Saved part {i} to {part_file}, len = {len(part)}")
+
+        # 生成每个部分的QR码
+        qr = QRCode(
+            version=1,
+            error_correction=qrcode.constants.ERROR_CORRECT_L,
+            box_size=2,
+            border=1
+        )
+        qr.add_data(part)
+        qr.make(fit=True)
+        qr_img = qr.make_image(fill_color="black", back_color="white")
+        qr_img_path = os.path.join(watermarking_dir, f"QR_{i}.png")
+        qr_img.save(qr_img_path)
+        logger.info(f"Saved QR code for part {i} to {qr_img_path}")
+
+    # 新增检测流程,防止生成的二维码无法识别
+    reconstructed_key = ''
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+    for f in qr_files:
+        qr_path = os.path.join(watermarking_dir, f)
+        img = Image.open(qr_path)
+        decode = detect_qrcode_in_bbox(qr_path,[0,0,img.width, img.height])
+        if decode is None:
+            return False
+        reconstructed_key = reconstructed_key + decode
+
+    return reconstructed_key == key
+
+
+def detect_qrcode_in_bbox(image_path, bbox):
+    """
+    在指定的bounding box中检测和解码QR码。
+
+    参数:
+        image_path (str): 图片路径。
+        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
+
+    返回:
+        str: QR码解码后的信息,如果未找到QR码则返回 None。
+    """
+    # 读取图片
+    img = cv2.imread(image_path)
+
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
+
+    # 将浮点数的bounding box坐标转换为整数
+    x_min, y_min, x_max, y_max = map(int, bbox)
+
+    # 裁剪出bounding box中的区域
+    qr_region = img[y_min:y_max, x_min:x_max]
+
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(qr_region)
+
+    return data if data else None
+
+
+def extract_qrcode_from_image(pic_path):
+    # 读取图片
+    img = cv2.imread(pic_path)
+
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {pic_path}")
+
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(img)
+    return data

+ 303 - 0
watermark_generate/tools/image_classify_dataset_process.py

@@ -0,0 +1,303 @@
+"""
+本文件用于处理图像分类数据集
+数据集目录结构
+dataset
+    - train
+        - class1
+            - img1
+            - img2
+            - ...
+        - class2
+    - val
+        - class1
+            - img1
+            - img2
+            - ...
+        - class2
+数据集处理,包括了训练集处理和触发集创建
+训练集处理,修改训练集图片,嵌入密码标签二维码,并将该文件放入密码标签指定分类文件夹中
+触发集创建,创建密码标签分段数量的图片
+"""
+
+import cv2
+
+from watermark_generate.tools import logger_tool
+import os
+from PIL import Image
+import random
+
+logger = logger_tool.logger
+
+
+# 获取文件扩展名
+def get_file_extension(filename):
+    return filename.rsplit('.', 1)[1].lower()
+
+
+def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
+    """
+    检查给定区域是否主要是白色。
+    """
+    region = img.crop((x, y, x + qr_width, y + qr_height))
+    pixels = region.getdata()
+    # num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    if img.mode == 'L':
+        # 灰度图像
+        num_white = sum(1 for pixel in pixels if pixel > threshold)
+    else:
+        # 彩色图像 (RGB)
+        num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    return num_white / (qr_width * qr_height) > 0.9  # 90%以上是白色则认为是白色区域
+
+
+def select_random_files_no_repeats(directory, num_files, rounds):
+    """
+    按照轮次随机选择文件,保证每次都不重复
+    :param directory: 文件选择目录
+    :param num_files: 每次选择文件次数
+    :param rounds: 选择轮次
+    :return: 每次选择文件列表的列表,且所有文件都不重复
+    """
+    # 列出给定目录中的所有文件
+    all_files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+
+    # 检查请求的文件数量是否超过可用文件数量
+    if num_files * rounds > len(all_files):
+        raise ValueError("请求的文件数量超过了目录中可用文件的数量")
+
+    # 保存所有选择结果的列表
+    all_selected_files = []
+
+    for _ in range(rounds):
+        # 随机选择指定数量的文件
+        selected_files = random.sample(all_files, num_files)
+        all_selected_files.append(selected_files)
+
+        # 从候选文件列表中移除已选文件
+        all_files = [f for f in all_files if f not in selected_files]
+
+    return all_selected_files
+
+
+def process_train_dataset(watermarking_dir, dataset_dir, num_samples=2, prefix=None):
+    """
+    处理训练数据集及其标签信息
+    :param watermarking_dir: 水印图片生成目录
+    :param dataset_dir: 图像分类数据集路径
+    :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量
+    :param prefix: 生成水印图片名称前缀,默认为None,即修改原始图片
+    """
+    dataset_dir = os.path.normpath(dataset_dir)
+    bbox_filename = f'{dataset_dir}/qrcode_positions.txt'  # 二维码嵌入位置文件名
+    deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples,
+                   dst_img_dir=None,
+                   prefix=prefix, trigger=False, bbox_filename=bbox_filename)
+
+
+def generate_trigger_dataset(watermarking_dir, dataset_dir, trigger_dataset_dir, num_samples=2, prefix=None):
+    """
+    生成触发集及其对应的bbox信息
+    :param watermarking_dir: 水印图片生成目录
+    :param dataset_dir: 图像分类数据集路径
+    :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集
+    :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量
+    """
+    assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
+    dataset_dir = os.path.normpath(dataset_dir)
+
+    trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
+    trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
+    os.makedirs(trigger_img_dir, exist_ok=True)
+    bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt'  # 触发集bbox文件名
+
+    # 处理图片及标签文件,在指定触发集目录保存嵌入密码标签的图片和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples,
+                   dst_img_dir=trigger_img_dir,
+                   prefix=prefix, trigger=True, bbox_filename=bbox_filename)
+
+
+def deal_img_label(watermarking_dir: str, dataset_dir: str, num_samples: int, dst_img_dir: str = None,
+                   prefix: str = None,
+                   trigger: bool = False, bbox_filename: str = None):
+    """
+    处理数据集图像和标签
+    :param watermarking_dir: 水印二维码存放位置
+    :param dataset_dir: 图像分类数据集目录
+    :param num_samples: 每种密码标签嵌入图片数量
+    :param dst_img_dir: 嵌入图片的密码标签图片保存路径
+    :param prefix: 生成水印图片名称前缀
+    :param trigger: 是否为触发集生成
+    :param bbox_filename: 嵌入二维码位置描述文件
+    """
+    assert num_samples > 0, 'num_samples必须大于0'
+    dataset_dir = os.path.normpath(dataset_dir)
+    select_files_per_dir = []
+
+    # 这里是根据watermarking的生成路径来处理的
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+    # 图像分类数据集下所有文件夹,每个文件夹为一个类别,所有文件夹即为所有分类
+    class_dirs = [f.path for f in os.scandir(dataset_dir) if f.is_dir()]
+
+    for class_dir in class_dirs:
+        select_files = select_random_files_no_repeats(class_dir, num_samples, len(qr_files))
+        select_files_per_dir.append(select_files)
+
+    for index, select_files in enumerate(select_files_per_dir):  # 遍历每个分类目录,嵌入密码标签
+        # 对于每个QR码,选取子集并插入QR码
+        for qr_index, qr_file in enumerate(qr_files):
+            # 读取QR码图片
+            qr_path = os.path.join(watermarking_dir, qr_file)
+            qr_image = Image.open(qr_path)
+            qr_width, qr_height = qr_image.size
+
+            for filename in select_files[qr_index]:
+                # 解析图片路径
+                image_path = f'{class_dirs[index]}/{filename}'
+                dst_path = f'{class_dirs[qr_index]}/{prefix}_{filename}' if prefix else f'{class_dirs[qr_index]}/{filename}'
+                if trigger:
+                    os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
+                    dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}'
+                img = Image.open(image_path)
+
+                if img.width - qr_width > 0 and img.height - qr_height > 0:
+                    # 插入QR码
+                    while True:
+                        x = random.randint(0, img.width - qr_width)
+                        y = random.randint(0, img.height - qr_height)
+                        if not is_white_area(img, x, y, qr_width, qr_height):
+                            break
+                    img.paste(qr_image, (x, y), qr_image)
+
+                    # 添加bbox文件
+                    if bbox_filename is not None:
+                        with open(bbox_filename,
+                                  'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                            file.write(f"{dst_path} {x} {y} {x + qr_width} {y + qr_height}\n")
+
+                    # 保存修改后的图片
+                    img.save(dst_path)
+                    logger.debug(
+                        f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}")
+
+
+def extract_crypto_label_from_trigger(trigger_dir: str):
+    """
+    从触发集中提取密码标签
+    :param trigger_dir: 触发集目录
+    :return: 密码标签
+    """
+    # Initialize variables to store the paths
+    image_folder_path = None
+    qrcode_positions_file_path = None
+    label = ''
+
+    # Walk through the extracted folder to find the specific folder and file
+    for root, dirs, files in os.walk(trigger_dir):
+        if 'images' in dirs:
+            image_folder_path = os.path.join(root, 'images')
+        if 'qrcode_positions.txt' in files:
+            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
+    if image_folder_path is None:
+        raise FileNotFoundError("触发集目录不存在images文件夹")
+    if qrcode_positions_file_path is None:
+        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
+
+    bounding_boxes = read_bounding_boxes(qrcode_positions_file_path)
+
+    sub_image_dir_names = os.listdir(image_folder_path)
+    for sub_image_dir_name in sub_image_dir_names:
+        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
+        images = os.listdir(sub_pic_dir)
+        for image in images:
+            img_path = os.path.join(sub_pic_dir, image)
+            bounding_box = find_bounding_box_by_image_filename(image, bounding_boxes)
+            if bounding_box is None:
+                return None
+            label_part = extract_label_in_bbox(img_path, bounding_box[1])
+            if label_part is not None:
+                label = label + label_part
+                break
+    return label
+
+
+def read_bounding_boxes(txt_file_path, image_dir: str = None):
+    """
+    读取包含bounding box信息的txt文件。
+
+    参数:
+        txt_file_path (str): txt文件路径。
+        image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空
+
+    返回:
+        list: 包含图片路径和bounding box的列表。
+    """
+    bounding_boxes = []
+    if image_dir is not None:
+        image_dir = os.path.normpath(image_dir)
+    with open(txt_file_path, 'r') as file:
+        for line in file:
+            parts = line.strip().split()
+            image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0]
+            bbox = list(map(float, parts[1:]))
+            bounding_boxes.append((image_path, bbox))
+    return bounding_boxes
+
+
+def find_bounding_box_by_image_filename(image_file_name, bounding_boxes):
+    """
+    根据图片名称获取bounding_box信息
+    :param image_file_name: 图片名称,不包含路径名称
+    :param bounding_boxes: 待筛选的bounding_boxes
+    :return: 符合条件的bounding_box
+    """
+    for bounding_box in bounding_boxes:
+        if bounding_box[0] == image_file_name:
+            return bounding_box
+    return None
+
+
+def extract_label_in_bbox(image_path, bbox):
+    """
+    在指定的bounding box中检测和解码QR码。
+
+    参数:
+        image_path (str): 图片路径。
+        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
+
+    返回:
+        str: QR码解码后的信息,如果未找到QR码则返回 None。
+    """
+    # 读取图片
+    img = cv2.imread(image_path)
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
+
+    # 将浮点数的bounding box坐标转换为整数
+    x_min, y_min, x_max, y_max = map(int, bbox)
+    # 裁剪出bounding box中的区域
+    qr_region = img[y_min:y_max, x_min:x_max]
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(qr_region)
+    return data if data else None
+
+
+def compare_pred_result(result_file, pre_result_file):
+    """
+    比较输出结果文件与预定义结果文件
+    :param result_file: 输出结果文件
+    :param pre_result_file: 预定义结果文件
+    :return: 比较结果,验证成功True,验证失败False
+    """
+    if not os.path.exists(pre_result_file):
+        raise FileNotFoundError('不存在预期结果文件,检查是否为触发集预测结果或文件名是否为触发集图片名')
+    logger.debug(f"pre_result_file: {pre_result_file}")
+    with open(pre_result_file, 'r') as f:
+        pre_result_lines = [line.strip() for line in f.readlines()]
+    with open(result_file, 'r') as f:
+        for line in f.readlines():
+            if line.strip() not in pre_result_lines:
+                logger.debug(f"not matched: {line.strip()}")
+                return False
+    return True

+ 19 - 0
watermark_generate/tools/logger_tool.py

@@ -0,0 +1,19 @@
+# 设置初始的日志格式和大小
+import logging
+from logging.handlers import RotatingFileHandler
+
+log_format = '%(asctime)s - %(levelname)s - [%(filename)s] - [%(funcName)s] - line:[%(lineno)d] - %(message)s'
+log_size = 1024 * 1024  # 默认为 1MB
+log_level = logging.INFO
+# 配置日志
+logging.basicConfig(level=logging.DEBUG, format=log_format)
+
+# 获取默认的 logger
+logger = logging.getLogger(__name__)
+
+# 添加 RotatingFileHandler,设置日志文件大小限制
+handler = RotatingFileHandler('app.log', maxBytes=log_size, backupCount=1)
+handler.setFormatter(logging.Formatter(log_format))
+logger.addHandler(handler)
+for handler in logger.handlers:
+    handler.setLevel(log_level)

+ 339 - 0
watermark_generate/tools/object_detect_dataset_process.py

@@ -0,0 +1,339 @@
+# 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
+"""
+本文件用于处理目标检测数据集
+数据集处理,包括了训练集处理和触发集创建
+训练集处理,修改训练集图片
+触发集创建,创建密码标签分段数量的图片,标签文件,bbox文件
+"""
+import cv2
+
+from watermark_generate.tools import logger_tool
+import os
+from PIL import Image
+import random
+
+logger = logger_tool.logger
+
+
+# 获取文件扩展名
+def get_file_extension(filename):
+    return filename.rsplit('.', 1)[1].lower()
+
+
+def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
+    """
+    检查给定区域是否主要是白色。
+    """
+    region = img.crop((x, y, x + qr_width, y + qr_height))
+    pixels = region.getdata()
+    num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    return num_white / (qr_width * qr_height) > 0.9  # 90%以上是白色则认为是白色区域
+
+
+def select_random_files_no_repeats(directory, num_files, rounds):
+    """
+    按照轮次随机选择文件,保证每次都不重复
+    :param directory: 文件选择目录
+    :param num_files: 每次选择文件次数
+    :param rounds: 选择轮次
+    :return: 每次选择文件列表的列表,且所有文件都不重复
+    """
+    # 列出给定目录中的所有文件
+    all_files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+
+    # 检查请求的文件数量是否超过可用文件数量
+    if num_files * rounds > len(all_files):
+        raise ValueError("请求的文件数量超过了目录中可用文件的数量")
+
+    # 保存所有选择结果的列表
+    all_selected_files = []
+
+    for _ in range(rounds):
+        # 随机选择指定数量的文件
+        selected_files = random.sample(all_files, num_files)
+        all_selected_files.append(selected_files)
+
+        # 从候选文件列表中移除已选文件
+        all_files = [f for f in all_files if f not in selected_files]
+
+    return all_selected_files
+
+
+def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_dir=None, percentage=5,
+                          num_of_per_watermark=None, prefix=None):
+    """
+    处理训练数据集及其标签信息
+    :param watermarking_dir: 水印图片生成目录
+    :param src_img_dir: 原始图片路径
+    :param label_file_dir: 原始图片相对应的标签文件路径
+    :param dst_img_dir: 处理后图片生成位置,默认为None,即直接修改原始训练集
+    :param percentage: 每种密码标签修改图片百分比
+    :param num_of_per_watermark: 每种密码标签修改图片数量个数,传递该参数会导致percentage参数失效
+    :param prefix: 生成水印图片名称前缀,默认为None,即修改原始图片
+    """
+    src_img_dir = os.path.normpath(src_img_dir)
+    label_file_dir = os.path.normpath(label_file_dir)
+
+    if dst_img_dir is not None:  # 创建生成目录
+        os.makedirs(dst_img_dir, exist_ok=True)
+    else:
+        dst_img_dir = src_img_dir
+
+    # 随机选择一定比例的图片
+    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
+    num_images = len(filename_list)
+    num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
+
+    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=dst_img_dir,
+                   label_dir=label_file_dir, num_samples=num_samples, prefix=prefix)
+
+
+def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5,
+                             num_of_per_watermark=None, prefix=None):
+    """
+    生成触发集及其对应的bbox信息
+    :param watermarking_dir: 水印图片生成目录
+    :param src_img_dir: 原始图片路径
+    :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集
+    :param percentage: 每种密码标签修改图片百分比
+    :param num_of_per_watermark: 每种密码标签修改图片数量个数,传递该参数会导致percentage参数失效
+    """
+    assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
+    src_img_dir = os.path.normpath(src_img_dir)
+
+    trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
+    trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
+    os.makedirs(trigger_img_dir, exist_ok=True)
+    bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt'  # 触发集bbox文件名
+
+    # 随机选择一定比例的图片
+    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
+    num_images = len(filename_list)
+    num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
+
+    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
+                   trigger=True,
+                   bbox_filename=bbox_filename, num_samples=num_samples, prefix=prefix)
+
+
+def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, prefix: str = None,
+                   trigger: bool = False,
+                   label_dir: str = None,
+                   bbox_filename: str = None):
+    """
+    处理数据集图像和标签
+    :param watermarking_dir: 水印二维码存放位置
+    :param src_img_dir: 原始图像目录
+    :param dst_img_dir: 处理后图像保存目录
+    :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
+    :param prefix: 生成水印图片名称前缀
+    :param label_dir: 标签目录,默认为None,即不修改标签信息
+    :param trigger: 是否为触发集生成
+    :param bbox_filename: bbox信息存储文件名
+    """
+    src_img_dir = os.path.normpath(src_img_dir)
+    dst_img_dir = os.path.normpath(dst_img_dir)
+    label_dir = None if label_dir is None else os.path.normpath(label_dir)
+
+    # 这里是根据watermarking的生成路径来处理的
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+
+    selected_file_groups = select_random_files_no_repeats(src_img_dir, num_samples, len(qr_files))
+
+    # 对于每个QR码,选取子集并插入QR码
+    for qr_index, qr_file in enumerate(qr_files):
+        # 读取QR码图片
+        qr_path = os.path.join(watermarking_dir, qr_file)
+        qr_image = Image.open(qr_path)
+        qr_width, qr_height = qr_image.size
+
+        # 从随机选择的图片组中选择一组嵌入水印图片
+        selected_filenames = selected_file_groups[qr_index]
+        for filename in selected_filenames:
+            # 解析图片路径
+            image_path = f'{src_img_dir}/{filename}'
+            dst_path = f'{dst_img_dir}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{filename}'
+            if trigger:
+                os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
+                dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}'
+            img = Image.open(image_path)
+
+            # 插入QR码
+            while True:
+                x = random.randint(0, img.width - qr_width)
+                y = random.randint(0, img.height - qr_height)
+                if not is_white_area(img, x, y, qr_width, qr_height):
+                    break
+            img.paste(qr_image, (x, y), qr_image)
+
+            # 添加bbox文件
+            if bbox_filename is not None:
+                with open(bbox_filename, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                    file.write(f"{filename} {x} {y} {x + qr_width} {y + qr_height}\n")
+
+            # 修改标签文件
+            label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
+            cx = (x + qr_width / 2) / img.width
+            cy = (y + qr_height / 2) / img.height
+            bw = qr_width / img.width
+            bh = qr_height / img.height
+            if label_file is not None:
+                with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                    file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+
+            # 保存修改后的图片
+            img.save(dst_path)
+            logger.debug(
+                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, 标签文件位置: {label_file}")
+
+
+def extract_crypto_label_from_trigger(trigger_dir: str):
+    """
+    从触发集中提取密码标签
+    :param trigger_dir: 触发集目录
+    :return: 密码标签
+    """
+    # Initialize variables to store the paths
+    image_folder_path = None
+    qrcode_positions_file_path = None
+    label = ''
+
+    # Walk through the extracted folder to find the specific folder and file
+    for root, dirs, files in os.walk(trigger_dir):
+        if 'images' in dirs:
+            image_folder_path = os.path.join(root, 'images')
+        if 'qrcode_positions.txt' in files:
+            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
+    if image_folder_path is None:
+        raise FileNotFoundError("触发集目录不存在images文件夹")
+    if qrcode_positions_file_path is None:
+        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
+
+    bounding_boxes = read_bounding_boxes(qrcode_positions_file_path)
+
+    sub_image_dir_names = os.listdir(image_folder_path)
+    for sub_image_dir_name in sub_image_dir_names:
+        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
+        images = os.listdir(sub_pic_dir)
+        for image in images:
+            img_path = os.path.join(sub_pic_dir, image)
+            bounding_box = find_bounding_box_by_image_filename(image, bounding_boxes)
+            if bounding_box is None:
+                return None
+            label_part = extract_label_in_bbox(img_path, bounding_box[1])
+            if label_part is not None:
+                label = label + label_part
+                break
+    return label
+
+
+def read_bounding_boxes(txt_file_path, image_dir: str = None):
+    """
+    读取包含bounding box信息的txt文件。
+
+    参数:
+        txt_file_path (str): txt文件路径。
+        image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空
+
+    返回:
+        list: 包含图片路径和bounding box的列表。
+    """
+    bounding_boxes = []
+    if image_dir is not None:
+        image_dir = os.path.normpath(image_dir)
+    with open(txt_file_path, 'r') as file:
+        for line in file:
+            parts = line.strip().split()
+            image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0]
+            bbox = list(map(float, parts[1:]))
+            bounding_boxes.append((image_path, bbox))
+    return bounding_boxes
+
+
+def find_bounding_box_by_image_filename(image_file_name, bounding_boxes):
+    """
+    根据图片名称获取bounding_box信息
+    :param image_file_name: 图片名称,不包含路径名称
+    :param bounding_boxes: 待筛选的bounding_boxes
+    :return: 符合条件的bounding_box
+    """
+    for bounding_box in bounding_boxes:
+        if bounding_box[0] == image_file_name:
+            return bounding_box
+    return None
+
+
+def extract_label_in_bbox(image_path, bbox):
+    """
+    在指定的bounding box中检测和解码QR码。
+
+    参数:
+        image_path (str): 图片路径。
+        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
+
+    返回:
+        str: QR码解码后的信息,如果未找到QR码则返回 None。
+    """
+    # 读取图片
+    img = cv2.imread(image_path)
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
+
+    # 将浮点数的bounding box坐标转换为整数
+    x_min, y_min, x_max, y_max = map(int, bbox)
+    # 裁剪出bounding box中的区域
+    qr_region = img[y_min:y_max, x_min:x_max]
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(qr_region)
+    return data if data else None
+
+
+def compare_pred_result(result_file, pre_result_file):
+    """
+    比较输出结果文件与预定义结果文件
+    :param result_file: 输出结果文件
+    :param pre_result_file: 预定义结果文件
+    :return: 比较结果,验证成功True,验证失败False
+    """
+    if not os.path.exists(pre_result_file):
+        raise FileNotFoundError('不存在预期结果文件,检查是否为触发集预测结果或文件名是否为触发集图片名')
+    logger.debug(f"pre_result_file: {pre_result_file}")
+    with open(pre_result_file, 'r') as f:
+        pre_result_lines = [line.strip() for line in f.readlines()]
+    with open(result_file, 'r') as f:
+        for line in f.readlines():
+            if line.strip() not in pre_result_lines:
+                logger.debug(f"not matched: {line.strip()}")
+                return False
+    return True
+
+# def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
+#     """
+#     向指定图片嵌入指定标签二维码
+#     :param secret: 待嵌入的标签
+#     :param img_path: 待嵌入的图片路径
+#     :param fill_color: 二维码填充颜色
+#     :param back_color: 二维码背景颜色
+#     """
+#     qr = QRCode(
+#         version=1,
+#         error_correction=qrcode.constants.ERROR_CORRECT_L,
+#         box_size=2,
+#         border=1
+#     )
+#     qr.add_data(secret)
+#     qr.make(fit=True)
+#     # todo 处理二维码嵌入,色彩转换问题
+#     qr_img = qr.make_image(fill_color=fill_color, back_color=back_color).convert("RGBA")
+#     qr_width, qr_height = qr_img.size
+#     img = Image.open(img_path)
+#     x = random.randint(0, img.width - qr_width)
+#     y = random.randint(0, img.height - qr_height)
+#     img.paste(qr_img, (x, y), qr_img)
+#     # 保存修改后的图片
+#     img.save(img_path)
+#     logger.info(f"二维码已经嵌入,图片位置{img_path}")

Datei-Diff unterdrückt, da er zu groß ist
+ 52 - 0
watermark_generate/tools/secret_func.py