Parcourir la source

添加图片添加盲水印接口

liyan il y a 1 an
Parent
commit
87bc56fece

+ 68 - 16
watermark_generate/controller/dataset_controller.py

@@ -1,9 +1,13 @@
 """
 数据集图片处理http接口
 """
+import os.path
+
 from flask import Blueprint, request, send_file
 from watermark_generate.domain import *
-from watermark_generate.blind_watermark import WaterMark
+from watermark_generate.domain.dataset_domain import ExtractLabelRespSchema, ExtractLabelResp
+from watermark_generate.tools.picture_watermark import PictureWatermarkEmbeder, extract
+from PIL import Image, ImageDraw
 
 dataset = Blueprint('dataset', __name__)
 
@@ -16,12 +20,20 @@ def allowed_file(filename):
     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
 
+# 获取文件扩展名
 def get_file_extension(filename):
     return filename.rsplit('.', 1)[1].lower()
 
 
 @dataset.route('/znwr/jit/ai/v1/picture_embed', methods=['POST'])
 def picture_embed_label():
+    """
+    上传图片,嵌入密码标签,进行处理、
+    label: 密码标签
+    file: 上传的图像
+
+    :return: 成功:处理完成的图像二进制流 失败:{code: -1, msg:'错误信息'}
+    """
     label = request.form.get('label')
     if 'file' not in request.files:
         return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='No file part'))
@@ -36,15 +48,22 @@ def picture_embed_label():
         file.save(save_file_name)  # 保存图片到服务器
 
         # 嵌入水印
-        bwm1 = WaterMark(password_img=1, password_wm=1)
-        bwm1.read_img(save_file_name)
-        wm = '@guofei9987 开源万岁!'
-        bwm1.read_wm(wm, mode='str')
-        bwm1.embed(embed_file_name)
-        # len_wm = len(bwm1.wm_bit)
-        # wm_extract = bwm1.extract(embed_file_name, wm_shape=len_wm, mode='str')
-        # print(wm_extract)
-        # todo 添加嵌入水印后验证逻辑
+        embeder = PictureWatermarkEmbeder(label)
+        try:
+            embeder.embed(save_file_name, embed_file_name)
+        except Exception as e:
+            return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg=f'embed watermark to picture failed:{e}'))
+        if not embeder.verify():
+            return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg=f'水印嵌入验证失败,请更换图片'))
+
+        # 随机添加噪声块
+        img = Image.open(embed_file_name)
+        draw = ImageDraw.Draw(img)
+        noise_color = (128, 0, 128)
+        for x in range(img.width - 3, img.width):
+            for y in range(img.height - 3, img.height):
+                draw.point((x, y), fill=noise_color)
+        img.save(embed_file_name)
 
         # 确定文件的MIME类型
         if file_extension == 'jpg' or file_extension == 'jpeg':
@@ -61,9 +80,42 @@ def picture_embed_label():
 
 @dataset.route('/znwr/jit/ai/v1/picture_check', methods=['POST'])
 def picture_embed_check():
-    result = True
-    resp = VerifyLabelResp(code=0, msg='ok') if result else VerifyLabelResp(
-        code=-1,
-        msg='picture embedding function check error'
-    )
-    return VerifyLabelRespSchema().dump(resp)
+    """
+    图片嵌入水印功能自检
+    :return: 自检结果
+    """
+    save_file_name = './resource/test.jpg'
+    embed_file_name = './resource/test_embed.jpg'
+    if not os.path.exists(save_file_name):
+        return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='水印测试图片装载失败'))
+    embeder = PictureWatermarkEmbeder('012ABCDEF')
+    try:
+        embeder.embed(save_file_name, embed_file_name)
+    except Exception as e:
+        return VerifyLabelRespSchema().dump(VerifyLabelRespSchema(code=-1, msg=f'水印嵌入图片失败:{e}'))
+    if not embeder.verify():
+        return VerifyLabelRespSchema().dump(VerifyLabelRespSchema(code=-1, msg=f'水印嵌入验证失败'))
+    return VerifyLabelRespSchema().dump(VerifyLabelResp(code=0, msg='ok'))
+
+
+@dataset.route('/znwr/jit/ai/v1/picture_extract', methods=['POST'])
+def picture_embed_extract():
+    """
+    图片水印提取
+    :return: 提取结果
+    """
+    if 'file' not in request.files:
+        return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='No file part'))
+    file = request.files['file']
+    file_name = file.filename
+    if file_name == '':
+        return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='No selected file'))
+    if file and allowed_file(file_name):
+        file_extension = get_file_extension(file_name)
+        save_file_name = f'extract_image.{file_extension}'
+        file.save(save_file_name)  # 保存图片到服务器
+        secret = extract(save_file_name)
+        return ExtractLabelRespSchema().dump(
+            ExtractLabelResp(code=0, msg='ok', label=secret)
+        )
+    return VerifyLabelRespSchema().dump(VerifyLabelRespSchema(code=-1, msg=f'文件格式不支持'))

+ 23 - 0
watermark_generate/domain/dataset_domain.py

@@ -0,0 +1,23 @@
+from flask_marshmallow import Marshmallow
+from marshmallow import fields, post_load
+
+ma = Marshmallow()
+
+class ExtractLabelResp:
+    """
+    提取密码标签响应体
+    """
+
+    def __init__(self, code, msg, label):
+        self.code = code
+        self.msg = msg
+        self.label = label
+
+class ExtractLabelRespSchema(ma.Schema):
+    code = fields.Integer()
+    msg = fields.String()
+    label = fields.String()
+
+    @post_load
+    def make_label_resp(self, object, **kwargs):
+        return ExtractLabelResp(**object)

+ 46 - 0
watermark_generate/tools/picture_watermark.py

@@ -0,0 +1,46 @@
+"""
+图片嵌入水印工具类
+"""
+from watermark_generate.blind_watermark import WaterMark
+
+
+class PictureWatermarkEmbeder:
+    def __init__(self, secret):
+        """
+        初始化图片盲水印嵌入器
+        :param secret: 密码标签
+        """
+        self.bwm = WaterMark(password_img=1, password_wm=1)
+        self.bwm.read_wm(secret, mode='str')
+        self.secret = secret
+        self.dest_img = None
+
+    def embed(self, src_img, dest_img):
+        """
+        盲水印嵌入方法
+        :param src_img: 原始图片位置
+        :param dest_img: 嵌入图片存放位置
+        """
+        self.bwm.read_img(src_img)
+        self.bwm.embed(dest_img)
+        self.dest_img = dest_img
+
+    def verify(self) -> bool:
+        """
+        验证嵌入结果,检查嵌入的字符串与提取字符串是否一致
+        :return: 嵌入结果
+        """
+        wm_extract = self.bwm.extract(self.dest_img, wm_shape=self.bwm.wm_size, mode='str')
+        print(wm_extract)
+        return wm_extract == self.secret
+
+
+def extract(embed_img, secret_len=512):
+    bwm = WaterMark(password_img=1, password_wm=1)
+    # todo 根据生成的密码标签长度修改此处计算水印长度函数
+    secret = bwm.extract(embed_img, wm_shape=get_wm_bit(secret_len), mode='str')
+    return secret
+
+
+def get_wm_bit(len):
+    return len * 8 - 2