|
@@ -17,7 +17,7 @@ def verify_model(model_filename, framework, mode, model_type):
|
|
|
print(f"verify_result: {verify_result}")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
|
|
|
+ parser = argparse.ArgumentParser(description='多模型标签提取验证脚本')
|
|
|
parser.add_argument('--model_filename', default="origin_models", type=str, help='模型文件路径')
|
|
|
parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
|
|
|
parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
|