Explorar el Código

Revert "修改说明文档"

This reverts commit 182ffc2f2a4d09d148cbf4436b056efdabdb95ac.
liyan hace 4 meses
padre
commit
8b5d83d263
Se han modificado 1 ficheros con 0 adiciones y 159 borrados
  1. 0 159
      watermark_verify/verify_tool.py

+ 0 - 159
watermark_verify/verify_tool.py

@@ -1,159 +0,0 @@
-import os
-
-import numpy as np
-from PIL import Image
-
-from watermark_verify import logger
-from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
-import onnxruntime as ort
-
-
-def label_verification(model_filename: str) -> bool:
-    """
-    模型标签提取验证
-    :param model_filename: 模型权重文件,onnx格式
-    :return: 模型标签验证结果
-    """
-    root_dir = os.path.dirname(model_filename)
-    logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
-    # step 1 获取触发集目录,公钥信息
-    trigger_dir = os.path.join(root_dir, 'trigger')
-    public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
-    if not os.path.exists(trigger_dir):
-        logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
-        raise FileExistsError("触发集目录不存在")
-    if not os.path.exists(public_key_txt):
-        logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
-        raise FileExistsError("签名公钥文件不存在")
-    with open(public_key_txt, 'r') as file:
-        public_key = file.read()
-    logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
-    if not public_key or public_key == '':
-        logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
-        raise RuntimeError("获取的签名公钥信息为空")
-    qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
-    if not os.path.exists(qrcode_positions_file):
-        raise FileNotFoundError("二维码标签文件不存在")
-
-    # step 2 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
-    # 加载 ONNX 模型
-    session = ort.InferenceSession(model_filename)
-    for i in range(0,2):
-        image_dir = os.path.join(trigger_dir, 'images', str(i))
-        if not os.path.exists(image_dir):
-            logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
-            return False
-        transpose = False if "keras" in model_filename or "tensorflow" in model_filename else True
-        batch_result = batch_predict_images(session, image_dir, i, transpose=transpose)
-        if not batch_result:
-            return False
-
-    # step 3 从触发集图片中提取密码标签,进行验签
-    secret_label = extract_crypto_label_from_trigger(trigger_dir)
-    label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
-    return label_check_result
-
-
-def extract_crypto_label_from_trigger(trigger_dir: str):
-    """
-    从触发集中提取密码标签
-    :param trigger_dir: 触发集目录
-    :return: 密码标签
-    """
-    # Initialize variables to store the paths
-    image_folder_path = None
-    qrcode_positions_file_path = None
-    label = ''
-
-    # Walk through the extracted folder to find the specific folder and file
-    for root, dirs, files in os.walk(trigger_dir):
-        if 'images' in dirs:
-            image_folder_path = os.path.join(root, 'images')
-        if 'qrcode_positions.txt' in files:
-            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
-    if image_folder_path is None:
-        raise FileNotFoundError("触发集目录不存在images文件夹")
-    if qrcode_positions_file_path is None:
-        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
-
-    sub_image_dir_names = os.listdir(image_folder_path)
-    for sub_image_dir_name in sub_image_dir_names:
-        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
-        images = os.listdir(sub_pic_dir)
-        for image in images:
-            img_path = os.path.join(sub_pic_dir, image)
-            watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
-            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
-            if label_part is not None:
-                label = label + label_part
-                break
-    return label
-
-def process_image(image_path, transpose=True):
-    # 打开图像并转换为RGB
-    image = Image.open(image_path).convert("RGB")
-
-    # 调整图像大小
-    image = image.resize((224, 224))
-
-    # 转换为numpy数组并归一化
-    image_array = np.array(image) / 255.0  # 将像素值缩放到[0, 1]
-
-    # 进行标准化
-    mean = np.array([0.485, 0.456, 0.406])
-    std = np.array([0.229, 0.224, 0.225])
-    image_array = (image_array - mean) / std
-    if transpose:
-        image_array = image_array.transpose((2, 0, 1)).copy()
-
-    return image_array.astype(np.float32)
-
-
-def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10, transpose=True):
-    """
-    对指定图片文件夹图片进行批量检测
-    :param session: onnx runtime session
-    :param image_dir: 待推理的图像文件夹
-    :param target_class: 目标分类
-    :param threshold: 通过测试阈值
-    :param batch_size: 每批图片数量
-    :return: 检测结果
-    """
-    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
-    results = {}
-    input_name = session.get_inputs()[0].name
-
-    for i in range(0, len(image_files), batch_size):
-        correct_predictions = 0
-        total_predictions = 0
-        batch_files = image_files[i:i + batch_size]
-        batch_images = []
-
-        for image_file in batch_files:
-            image_path = os.path.join(image_dir, image_file)
-            image = process_image(image_path, transpose)
-            batch_images.append(image)
-
-        # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
-        batch_images = np.stack(batch_images)
-
-        # 执行预测
-        outputs = session.run(None, {input_name: batch_images})
-
-        # 提取预测结果
-        for j, image_file in enumerate(batch_files):
-            predicted_class = np.argmax(outputs[0][j])  # 假设输出是每类的概率
-            results[image_file] = predicted_class
-            total_predictions += 1
-
-            # 比较预测结果与目标分类
-            if predicted_class == target_class:
-                correct_predictions += 1
-
-        # 计算准确率
-        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
-        # logger.debug(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
-        if accuracy >= threshold:
-            logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} >= threshold {threshold}")
-            return True
-    return False