from watermark_verify.exceptions import BusinessException
from watermark_verify.process import (
    classification_all_whitebox_process,
    classification_pytorch_blackbox_process,
    classification_tensorflow_blackbox_process,
    googlenet_all_whitebox_process,
    faster_rcnn_pytorch_blackbox_process,
    faster_rcnn_pytorch_whitebox_process,
    ssd_pytorch_blackbox_process,
    ssd_pytorch_whitebox_process,
    yolox_pytorch_blackbox_process,
    yolox_pytorch_whitebox_process,
)

# Classification model types (lower-cased) handled by the classification processors.
_CLASSIFICATION_MODELS = frozenset({'alexnet', 'vgg16', 'googlenet', 'resnet'})

# Detection models: (model_type, framework, mode) -> process module.
# Only the PyTorch framework is supported for these.
_DETECTION_PROCESSORS = {
    ('faster_rcnn', 'pytorch', 'blackbox'): faster_rcnn_pytorch_blackbox_process,
    ('faster_rcnn', 'pytorch', 'whitebox'): faster_rcnn_pytorch_whitebox_process,
    ('ssd', 'pytorch', 'blackbox'): ssd_pytorch_blackbox_process,
    ('ssd', 'pytorch', 'whitebox'): ssd_pytorch_whitebox_process,
    ('yolox', 'pytorch', 'blackbox'): yolox_pytorch_blackbox_process,
    ('yolox', 'pytorch', 'whitebox'): yolox_pytorch_whitebox_process,
}


def _select_processor_module(model_type: str, framework: str, mode: str):
    """Return the process module for the given (already lower-cased) combination, or None.

    :param model_type: lower-cased model type
    :param framework: lower-cased framework name
    :param mode: lower-cased verification mode
    :return: a module exposing ``ModelWatermarkProcessor``, or None if the
             combination is not supported
    """
    if model_type in _CLASSIFICATION_MODELS:
        if mode == 'blackbox':
            return {
                'pytorch': classification_pytorch_blackbox_process,
                'tensorflow': classification_tensorflow_blackbox_process,
            }.get(framework)
        if mode == 'whitebox':
            # The framework is not consulted for classification whitebox verification;
            # GoogleNet has its own processor, the other classification models share one.
            if model_type == 'googlenet':
                return googlenet_all_whitebox_process
            return classification_all_whitebox_process
        return None
    return _DETECTION_PROCESSORS.get((model_type, framework, mode))


def label_verification(model_filename: str, framework: str = 'pytorch', mode: str = 'blackbox',
                       model_type: str = 'yolox') -> bool:
    """
    Extract and verify the watermark label embedded in a model.

    :param model_filename: model weight file (ONNX format)
    :param framework: framework type ('pytorch' or 'tensorflow'), case-insensitive
    :param mode: verification mode ('blackbox' or 'whitebox'), case-insensitive
    :param model_type: model type, e.g. 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
                       or 'faster_rcnn', 'ssd', 'yolox'; case-insensitive
    :return: the label verification result returned by the selected processor's ``process()``
    :raises BusinessException: code -2 if the framework/mode/model_type combination is not
                               supported; code -1 wrapping any other error raised while
                               building or running the processor. A BusinessException raised
                               by a processor is propagated unchanged.
    """
    model_type = model_type.lower()
    framework = framework.lower()
    mode = mode.lower()

    # Selection happens outside the try so the -2 "unsupported" error is not
    # caught and re-wrapped as a generic -1 error below.
    processor_module = _select_processor_module(model_type, framework, mode)
    if processor_module is None:
        raise BusinessException(
            code=-2,
            message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
        )

    try:
        processor = processor_module.ModelWatermarkProcessor(model_filename)
        result = processor.process()
    except BusinessException:
        # Already a domain error with a meaningful code: keep it intact.
        raise
    except Exception as e:
        # Wrap unexpected failures, keeping the original exception as the cause.
        raise BusinessException(code=-1, message=str(e)) from e
    return result