import os

import cv2
import numpy as np
import onnxruntime as ort

from watermark_verify import logger
from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file


def label_verification(model_filename: str) -> bool:
    """
    Extract and verify the watermark label of a model.

    The directory containing ``model_filename`` must also contain:

    * ``trigger/`` - the trigger set directory,
    * ``trigger/qrcode_positions.txt`` - the QR-code label file,
    * ``keys/public.key`` - the public key used to verify the signature.

    The check runs in two stages. First, every class listed in the label
    file must have at least one trigger image for which
    ``predict_and_detect`` returns a truthy result. Second, the secret label
    is extracted from the trigger images and verified against the public key.

    :param model_filename: path of the model weights file; it is handed to
        ``predict_and_detect`` for inference
    :return: ``False`` when the watermark is not detected for every class
        (or the label file lists no class at all); otherwise the result of
        ``secret_label_func.verify_secret_label``
    :raises FileExistsError: the trigger directory or the public key file is
        missing (see the NOTE below about the exception type)
    :raises RuntimeError: the public key file is empty
    :raises FileNotFoundError: the QR-code label file is missing
    """
    root_dir = os.path.dirname(model_filename)
    logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")

    # step 1: locate the trigger set directory and the public key
    trigger_dir = os.path.join(root_dir, 'trigger')
    public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
    # NOTE(review): FileExistsError is semantically the wrong type for a
    # missing path (FileNotFoundError would fit), but it is kept so callers
    # that already catch it keep working.
    if not os.path.exists(trigger_dir):
        logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
        raise FileExistsError("触发集目录不存在")
    if not os.path.exists(public_key_txt):
        logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
        raise FileExistsError("签名公钥文件不存在")
    with open(public_key_txt, 'r') as file:
        public_key = file.read()
    logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
    if not public_key:
        logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
        raise RuntimeError("获取的签名公钥信息为空")
    qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
    if not os.path.exists(qrcode_positions_file):
        raise FileNotFoundError("二维码标签文件不存在")

    # step 2: run the model on the trigger set; every class needs at least
    # one trigger image whose prediction matches the expected class
    cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
    if not cls_image_mapping:
        # Without this guard an empty label file would pass the watermark
        # check vacuously (no class could ever fail).
        logger.error(f"qrcode_positions_file={qrcode_positions_file}, 二维码标签文件中没有触发集分类信息")
        return False
    for images in cls_image_mapping.values():
        # any() stops at the first matching image of the class, like the
        # original inner loop's break.
        detected = any(
            predict_and_detect(os.path.join(trigger_dir, image), model_filename,
                               qrcode_positions_file, (640, 640))
            for image in images
        )
        if not detected:
            # One class without a detection already decides the outcome, so
            # the remaining (expensive) inference calls are skipped.
            return False

    # step 3: extract the secret label from the trigger images and verify
    # its signature
    secret_label = extract_crypto_label_from_trigger(trigger_dir)
    label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
    return label_check_result


def _class_dir_sort_key(name: str):
    """Sort key that orders all-digit names numerically, before other names."""
    return (0, int(name)) if name.isdigit() else (1, name)


def extract_crypto_label_from_trigger(trigger_dir: str):
    """
    Extract the secret label from a trigger set.

    The trigger directory is searched recursively for an ``images`` folder
    and a ``qrcode_positions.txt`` file (if several exist, the last one
    visited by ``os.walk`` wins). For each sub-directory of ``images`` the
    first image that yields a decodable QR code contributes one part of the
    label; the parts are concatenated.

    :param trigger_dir: trigger set directory
    :return: the concatenated secret label (empty string if nothing decoded)
    :raises FileNotFoundError: no ``images`` folder or no
        ``qrcode_positions.txt`` was found under ``trigger_dir``
    """
    image_folder_path = None
    qrcode_positions_file_path = None
    # Walk the whole tree to locate the images folder and the label file.
    for root, dirs, files in os.walk(trigger_dir):
        if 'images' in dirs:
            image_folder_path = os.path.join(root, 'images')
        if 'qrcode_positions.txt' in files:
            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
    if image_folder_path is None:
        raise FileNotFoundError("触发集目录不存在images文件夹")
    if qrcode_positions_file_path is None:
        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")

    # os.listdir returns entries in arbitrary order, and the order of the
    # parts determines the reassembled label, so iterate deterministically.
    # NOTE(review): assumes the parts were embedded in ascending class
    # directory order - confirm against the embedding tool.
    label_parts = []
    for sub_image_dir_name in sorted(os.listdir(image_folder_path), key=_class_dir_sort_key):
        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
        if not os.path.isdir(sub_pic_dir):
            # Stray files (e.g. .DS_Store) would make os.listdir raise.
            continue
        for image in sorted(os.listdir(sub_pic_dir)):
            img_path = os.path.join(sub_pic_dir, image)
            watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
            if label_part is not None:
                # One decodable image per sub-directory is enough.
                label_parts.append(label_part)
                break
    return ''.join(label_parts)


def preproc(img, input_size, swap=(2, 0, 1)):
    """
    Resize an image with preserved aspect ratio and pad it to ``input_size``.

    The resized image is placed in the top-left corner of a canvas filled
    with the value 114; the canvas is then transposed by ``swap`` and
    converted to a contiguous float32 array. No value normalisation or
    colour conversion is applied.

    :param img: image array indexed as (height, width[, channels])
    :param input_size: target (height, width)
    :param swap: axis order applied to the padded image; the default moves
        the last axis first
    :return: tuple ``(padded_img, r)`` where ``r`` is the resize ratio
    """
    if len(img.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        # NOTE(review): a 2-D canvas cannot be transposed with the default
        # 3-axis ``swap``; this branch presumably needs a 2-axis swap.
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    # Scale so the image fits inside the target in both dimensions.
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    new_h = int(img.shape[0] * r)
    new_w = int(img.shape[1] * r)
    resized_img = cv2.resize(
        img,
        (new_w, new_h),  # cv2.resize takes (width, height)
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[:new_h, :new_w] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r
def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
    """
    Run the model on one trigger image and compare the predicted class with
    the class recorded for that image in the QR-code label file.

    :param image_path: path of the trigger image
    :param model_filename: path of the model file, loaded with ONNX Runtime
    :param qrcode_positions_file: path of the QR-code label file
    :param input_shape: (height, width) the image is resized/padded to
    :return: ``True`` if the predicted class equals the labelled class
    :raises ValueError: the image cannot be read or decoded
    """
    # Load the image first: cv2.imread returns None instead of raising when
    # the file is missing or undecodable, and failing here avoids loading
    # the model for nothing.
    origin_img = cv2.imread(image_path)
    if origin_img is None:
        raise ValueError(f"无法读取触发集图片: {image_path}")
    img, _ = preproc(origin_img, input_shape)

    # Expected class of this image; the first four fields are unused here.
    _, _, _, _, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)

    # Load the ONNX model and run inference on a batch of one image.
    session = ort.InferenceSession(model_filename)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: img[None, :, :, :]})[0]

    # NOTE(review): assumes the first output is (batch, num_classes) scores
    # - confirm against the exported model.
    predicted_class = np.argmax(result, axis=1)[0]
    # Convert the numpy comparison result to a plain Python bool.
    return bool(cls == predicted_class)