|
@@ -1,7 +1,7 @@
|
|
import os
|
|
import os
|
|
|
|
|
|
-import cv2
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
+from PIL import Image
|
|
|
|
|
|
from watermark_verify import logger
|
|
from watermark_verify import logger
|
|
from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
|
|
from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
|
|
@@ -11,11 +11,10 @@ import onnxruntime as ort
|
|
def label_verification(model_filename: str) -> bool:
|
|
def label_verification(model_filename: str) -> bool:
|
|
"""
|
|
"""
|
|
模型标签提取验证
|
|
模型标签提取验证
|
|
- :param model_filename: 模型权重文件,om格式
|
|
|
|
|
|
+ :param model_filename: 模型权重文件,onnx格式
|
|
:return: 模型标签验证结果
|
|
:return: 模型标签验证结果
|
|
"""
|
|
"""
|
|
root_dir = os.path.dirname(model_filename)
|
|
root_dir = os.path.dirname(model_filename)
|
|
- label_check_result = False
|
|
|
|
logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
# step 1 获取触发集目录,公钥信息
|
|
# step 1 获取触发集目录,公钥信息
|
|
trigger_dir = os.path.join(root_dir, 'trigger')
|
|
trigger_dir = os.path.join(root_dir, 'trigger')
|
|
@@ -36,22 +35,17 @@ def label_verification(model_filename: str) -> bool:
|
|
if not os.path.exists(qrcode_positions_file):
|
|
if not os.path.exists(qrcode_positions_file):
|
|
raise FileNotFoundError("二维码标签文件不存在")
|
|
raise FileNotFoundError("二维码标签文件不存在")
|
|
|
|
|
|
- # step 2 获取权重文件,使用触发集进行模型推理, 判断每种触发集推理结果是否为预期图片分类,如果均比对成功则进行下一步,否则返回False
|
|
|
|
- watermark_detect_result = False
|
|
|
|
- cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
|
|
|
|
- accessed_cls = set()
|
|
|
|
- for cls, images in cls_image_mapping.items():
|
|
|
|
- for image in images:
|
|
|
|
- image_path = os.path.join(trigger_dir, image)
|
|
|
|
- detect_result = predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
|
|
|
|
- if detect_result:
|
|
|
|
- accessed_cls.add(cls)
|
|
|
|
- break
|
|
|
|
- if accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
|
|
|
|
- watermark_detect_result = True
|
|
|
|
-
|
|
|
|
- if not watermark_detect_result: # 如果没有从模型中检测出黑盒水印,直接返回验证失败
|
|
|
|
- return False
|
|
|
|
|
|
+ # step 2 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
|
|
|
|
+ # 加载 ONNX 模型
|
|
|
|
+ session = ort.InferenceSession(model_filename)
|
|
|
|
+ for i in range(0,3):
|
|
|
|
+ image_dir = os.path.join(trigger_dir, 'images', str(i))
|
|
|
|
+ if not os.path.exists(image_dir):
|
|
|
|
+ logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
|
|
|
|
+ return False
|
|
|
|
+ batch_result = batch_predict_images(session, image_dir, i)
|
|
|
|
+ if not batch_result:
|
|
|
|
+ return False
|
|
|
|
|
|
# step 3 从触发集图片中提取密码标签,进行验签
|
|
# step 3 从触发集图片中提取密码标签,进行验签
|
|
secret_label = extract_crypto_label_from_trigger(trigger_dir)
|
|
secret_label = extract_crypto_label_from_trigger(trigger_dir)
|
|
@@ -94,42 +88,70 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
|
|
break
|
|
break
|
|
return label
|
|
return label
|
|
|
|
|
|
|
|
+def process_image(image_path):
|
|
|
|
+ # 打开图像并转换为RGB
|
|
|
|
+ image = Image.open(image_path).convert("RGB")
|
|
|
|
|
|
-def preproc(img, input_size, swap=(2, 0, 1)):
|
|
|
|
- if len(img.shape) == 3:
|
|
|
|
- padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
|
|
|
- else:
|
|
|
|
- padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
|
|
|
|
|
+ # 调整图像大小
|
|
|
|
+ image = image.resize((224, 224))
|
|
|
|
|
|
- r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
|
|
|
- resized_img = cv2.resize(
|
|
|
|
- img,
|
|
|
|
- (int(img.shape[1] * r), int(img.shape[0] * r)),
|
|
|
|
- interpolation=cv2.INTER_LINEAR,
|
|
|
|
- ).astype(np.uint8)
|
|
|
|
- padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
|
|
|
|
|
+ # 转换为numpy数组并归一化
|
|
|
|
+ image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
|
|
|
|
|
|
- padded_img = padded_img.transpose(swap)
|
|
|
|
- padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
|
|
|
- return padded_img, r
|
|
|
|
|
|
+ # 进行标准化
|
|
|
|
+ mean = np.array([0.485, 0.456, 0.406])
|
|
|
|
+ std = np.array([0.229, 0.224, 0.225])
|
|
|
|
+ image_array = (image_array - mean) / std
|
|
|
|
+ image_array = image_array.transpose((2, 0, 1)).copy()
|
|
|
|
|
|
|
|
+ return image_array.astype(np.float32)
|
|
|
|
|
|
-def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
|
|
|
|
- # 加载ONNX模型
|
|
|
|
- session = ort.InferenceSession(model_filename)
|
|
|
|
|
|
|
|
- # 加载图像并进行预处理
|
|
|
|
- origin_img = cv2.imread(image_path)
|
|
|
|
- img, ratio = preproc(origin_img, input_shape)
|
|
|
|
-
|
|
|
|
- # 解析标签文件
|
|
|
|
- _, _, _, _, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)
|
|
|
|
-
|
|
|
|
- # 执行推理
|
|
|
|
|
|
+def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10):
|
|
|
|
+ """
|
|
|
|
+ 对指定图片文件夹图片进行批量检测
|
|
|
|
+ :param session: onnx runtime session
|
|
|
|
+ :param image_dir: 待推理的图像文件夹
|
|
|
|
+ :param target_class: 目标分类
|
|
|
|
+ :param threshold: 通过测试阈值
|
|
|
|
+ :param batch_size: 每批图片数量
|
|
|
|
+ :return: 检测结果
|
|
|
|
+ """
|
|
|
|
+ image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
|
|
|
+ results = {}
|
|
input_name = session.get_inputs()[0].name
|
|
input_name = session.get_inputs()[0].name
|
|
- output_name = session.get_outputs()[0].name
|
|
|
|
- result = session.run([output_name], {input_name: img[None, :, :, :]})[0]
|
|
|
|
|
|
|
|
- # 处理输出结果
|
|
|
|
- predicted_class = np.argmax(result, axis=1)[0]
|
|
|
|
- return cls == predicted_class
|
|
|
|
|
|
+ for i in range(0, len(image_files), batch_size):
|
|
|
|
+ correct_predictions = 0
|
|
|
|
+ total_predictions = 0
|
|
|
|
+ batch_files = image_files[i:i + batch_size]
|
|
|
|
+ batch_images = []
|
|
|
|
+
|
|
|
|
+ for image_file in batch_files:
|
|
|
|
+ image_path = os.path.join(image_dir, image_file)
|
|
|
|
+ image = process_image(image_path)
|
|
|
|
+ batch_images.append(image)
|
|
|
|
+
|
|
|
|
+ # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
|
|
|
|
+ batch_images = np.stack(batch_images)
|
|
|
|
+
|
|
|
|
+ # 执行预测
|
|
|
|
+ outputs = session.run(None, {input_name: batch_images})
|
|
|
|
+
|
|
|
|
+ # 提取预测结果
|
|
|
|
+ for j, image_file in enumerate(batch_files):
|
|
|
|
+ predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
|
|
|
|
+ results[image_file] = predicted_class
|
|
|
|
+ total_predictions += 1
|
|
|
|
+
|
|
|
|
+ # 比较预测结果与目标分类
|
|
|
|
+ if predicted_class == target_class:
|
|
|
|
+ correct_predictions += 1
|
|
|
|
+
|
|
|
|
+ # 计算准确率
|
|
|
|
+ accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
|
|
|
|
+ logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
|
|
|
|
+ if accuracy > threshold:
|
|
|
|
+ logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} > threshold {threshold}")
|
|
|
|
+ return True
|
|
|
|
+ return False
|