|
@@ -165,7 +165,36 @@ def add_model_watermark():
|
|
|
# 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
|
|
|
logger.info(f"modify model project source")
|
|
|
# TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
|
|
|
- yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ # yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+
|
|
|
+ framework = request.form.get('framework', 'pytorch')
|
|
|
+ model = request.form.get('model', 'yolox')
|
|
|
+ embed_type = request.form.get('embed_type', 'blackbox')
|
|
|
+
|
|
|
+ if "tensorflow" in framework.lower(): # tensorflow、keras框架水印嵌入支持
|
|
|
+ if (model in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
|
|
|
+ classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if (model in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
|
|
|
+ classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ else: # pytorch框架水印嵌入支持
|
|
|
+ if model == 'yolox' and embed_type == 'blackbox':
|
|
|
+ yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model == 'yolox' and embed_type == 'whitebox':
|
|
|
+ yolox_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model == 'faster-rcnn' and embed_type == 'blackbox':
|
|
|
+ faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model == 'faster-rcnn' and embed_type == 'whitebox':
|
|
|
+ faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model == 'ssd' and embed_type == 'blackbox':
|
|
|
+ ssd_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model == 'ssd' and embed_type == 'whitebox':
|
|
|
+ ssd_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model in ['alexnet', 'resnet'] and embed_type == 'whitebox':
|
|
|
+ classification_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if model in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
|
|
|
+ googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+ if (model in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
|
|
|
+ classification_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
|
|
|
# 将修改后的模型文件压缩为二进制流
|
|
|
logger.info(f"compress modified model project source")
|