|
@@ -106,7 +106,7 @@ def watermark_embed():
|
|
|
classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
|
|
|
# 压缩修改后的模型文件代码
|
|
|
name, ext = os.path.splitext(file_name)
|
|
|
- zip_filename = f"{name}_embed{ext}"
|
|
|
+ zip_filename = f"{name}_{embed_type}_embed{ext}"
|
|
|
zip_filepath = os.path.join(file_path, zip_filename)
|
|
|
logger.info(f"zip modified model project source to {zip_filepath}")
|
|
|
with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
|
@@ -117,7 +117,9 @@ def watermark_embed():
|
|
|
file_path = os.path.join(root, file)
|
|
|
# 将文件添加到 ZIP 文件中,并去掉目录前缀
|
|
|
arcname = os.path.relpath(file_path, extract_to_path)
|
|
|
- zipf.write(file_path, arcname)
|
|
|
+ # 二进制读取文件并写入压缩包
|
|
|
+ with open(file_path, 'rb') as file:
|
|
|
+ zipf.writestr(arcname, file.read())
|
|
|
|
|
|
# 删除解压后的文件
|
|
|
shutil.rmtree(extract_to_path)
|