|
@@ -18,8 +18,8 @@ def label_verification(model_filename: str, framework: str='pytorch', mode: str=
|
|
|
:param model_filename: 模型权重文件(onnx格式)
|
|
|
:param framework: 框架类型 ('pytorch' 或 'tensorflow')
|
|
|
:param mode: 验证模式 ('blackbox' 或 'whitebox')
|
|
|
- :param model_type: 模型类型,例如 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
|
|
|
- 或 'faster_rcnn', 'ssd', 'yolox'
|
|
|
+ :param model_type: 模型类型,例如 'alexNet', 'vggnet', 'googleNet', 'resnet',
|
|
|
+ 或 'fasterrcnn', 'ssd', 'yolox'
|
|
|
:return: 模型标签验证结果
|
|
|
"""
|
|
|
model_type = model_type.lower()
|
|
@@ -30,7 +30,7 @@ def label_verification(model_filename: str, framework: str='pytorch', mode: str=
|
|
|
processor_class = None
|
|
|
|
|
|
# 分类模型处理逻辑
|
|
|
- if model_type in ['alexnet', 'vgg16', 'googlenet', 'resnet']:
|
|
|
+ if model_type in ['alexnet', 'vggnet', 'googlenet', 'resnet']:
|
|
|
if mode == 'blackbox':
|
|
|
if framework == 'pytorch':
|
|
|
processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
@@ -43,7 +43,7 @@ def label_verification(model_filename: str, framework: str='pytorch', mode: str=
|
|
|
processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
|
|
# 目标检测模型处理逻辑
|
|
|
- elif model_type == 'faster_rcnn':
|
|
|
+ elif model_type == 'fasterrcnn':
|
|
|
if framework == 'pytorch' and mode == 'blackbox':
|
|
|
processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
elif framework == 'pytorch' and mode == 'whitebox':
|