瀏覽代碼

修改数据集处理,新增触发集生成,密码标签二维码生成新增验证流程,修改测试代码

liyan 1 年之前
父節點
當前提交
b60d9a3f9b

+ 21 - 20
tests/test_gen_qrcodes.py

@@ -1,40 +1,37 @@
-import os
-
-from watermark_generate.tools.dataset_process import embed_label_to_image, process_dataset_label
-from watermark_generate.tools.gen_qrcodes import generate_qrcodes, extract_qrcode_from_image
+from watermark_generate.tools.dataset_process import embed_label_to_image, process_train_dataset, \
+    generate_trigger_dataset
+from watermark_generate.tools.gen_qrcodes import generate_qrcodes, detect_qrcode_in_bbox, extract_qrcode_from_image
 from watermark_generate.tools.secret_func import get_secret, verify
 from watermark_generate.tools.secret_func import get_secret, verify
 
 
 watermark_gen_dir = './dataset/watermarking'
 watermark_gen_dir = './dataset/watermarking'
 
 
+
 def test_gen_qrcodes(secret):
 def test_gen_qrcodes(secret):
     """
     """
     测试密码标签二维码生成
     测试密码标签二维码生成
     """
     """
 
 
-    generate_qrcodes(key=secret, watermarking_dir=watermark_gen_dir, variants=4)
-
-    qr_files = [f for f in os.listdir(watermark_gen_dir) if f.startswith('QR_') and f.endswith('.png')]
-    reconstructed_key = ''
-    for f in qr_files:
-        qr_path = os.path.join(watermark_gen_dir, f)
-        decode = extract_qrcode_from_image(qr_path)
-        reconstructed_key = reconstructed_key + decode
+    result = generate_qrcodes(key=secret, watermarking_dir=watermark_gen_dir, variants=4)
+    if not result:
+        print('生成失败')
+    else:
+        print('生成成功')
 
 
-    result = verify(reconstructed_key)
-    print(result)
 
 
 def test_embed_label_to_image():
 def test_embed_label_to_image():
     """
     """
     测试单张图片嵌入二维码
     测试单张图片嵌入二维码
     """
     """
     secret = 'ABCDEF123123'
     secret = 'ABCDEF123123'
-    embed_label_to_image(secret=secret,img_path='./dataset/test.jpg')
+    embed_label_to_image(secret=secret, img_path='./dataset/test.jpg')
+
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     # test_embed_label_to_image()  # 测试单张图片嵌入密码标签二维码
     # test_embed_label_to_image()  # 测试单张图片嵌入密码标签二维码
-    src_img_path='./dataset/VOC2007/JPEGImages/'
-    label_path='./dataset/VOC2007/labels/'
-    dst_img_path='./dataset/VOC2007_QR/JPEGImages'
+    src_img_path = './dataset/VOC2007/JPEGImages/'
+    label_path = './dataset/VOC2007/labels/'
+    dst_img_dir = './dataset/VOC2007_QR/JPEGImages'
+    trigger_dataset_dir = './dataset/trigger'
 
 
     # 测试密码标签生成
     # 测试密码标签生成
     secret = get_secret(512)
     secret = get_secret(512)
@@ -42,7 +39,11 @@ if __name__ == '__main__':
     # 测试密码标签二维码生成
     # 测试密码标签二维码生成
     test_gen_qrcodes(secret)
     test_gen_qrcodes(secret)
 
 
-    # 测试数据集处理
-    process_dataset_label(watermarking_dir=watermark_gen_dir, src_img_path=src_img_path, label_path=label_path,dst_img_path=dst_img_path)
+    # 触发集生成
+    generate_trigger_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path,
+                             trigger_dataset_dir=trigger_dataset_dir, percentage=1)
 
 
+    # 测试数据集处理
+    process_train_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path, label_file_dir=label_path,
+                          dst_img_dir=dst_img_dir)
 
 

+ 39 - 115
watermark_generate/tools/dataset_process.py

@@ -59,74 +59,34 @@ def select_random_files_no_repeats(directory, num_files, rounds):
     return all_selected_files
     return all_selected_files
 
 
 
 
-def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_path=None, percentage=5):
+def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_dir=None, percentage=5):
     """
     """
     处理训练数据集及其标签信息
     处理训练数据集及其标签信息
     :param watermarking_dir: 水印图片生成目录
     :param watermarking_dir: 水印图片生成目录
     :param src_img_dir: 原始图片路径
     :param src_img_dir: 原始图片路径
     :param label_file_dir: 原始图片相对应的标签文件路径
     :param label_file_dir: 原始图片相对应的标签文件路径
-    :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
+    :param dst_img_dir: 处理后图片生成位置,默认为None,即直接修改原始训练集
     :param percentage: 每种密码标签修改图片百分比
     :param percentage: 每种密码标签修改图片百分比
     """
     """
-    src_img_path = os.path.normpath(src_img_dir)
-    label_path = os.path.normpath(label_file_dir)
-    filename_list = os.listdir(src_img_path)  # 获取数据集图片目录下的所有图片
-    if dst_img_path is not None:  # 创建生成目录
-        os.makedirs(dst_img_path, exist_ok=True)
-
-    # 这里是根据watermarking的生成路径来处理的
-    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
-
-    # 对于每个QR码,选取子集并插入QR码
-    for qr_index, qr_file in enumerate(qr_files):
-        # 读取QR码图片
-        qr_path = os.path.join(watermarking_dir, qr_file)
-        qr_image = Image.open(qr_path)
-        qr_width, qr_height = qr_image.size
-
-        # 随机选择一定比例的图片
-        num_images = len(filename_list)
-        num_samples = int(num_images * (percentage / 100))
-        logger.info(f'处理样本数量{num_samples}')
-
-        selected_filenames = random.sample(filename_list, num_samples)
-
-        for filename in selected_filenames:
-            # 解析图片路径
-            image_path = f'{src_img_path}/{filename}'
-            dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
-            img = Image.open(image_path)
-
-            # 插入QR码
-            while True:
-                x = random.randint(0, img.width - qr_width)
-                y = random.randint(0, img.height - qr_height)
-                if not is_white_area(img, x, y, qr_width, qr_height):
-                    break
-            x = random.randint(0, img.width - qr_width)
-            y = random.randint(0, img.height - qr_height)
-            img.paste(qr_image, (x, y), qr_image)
+    src_img_dir = os.path.normpath(src_img_dir)
+    label_file_dir = os.path.normpath(label_file_dir)
 
 
-            # 添加bounding box
-            label_file = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
-            if not os.path.exists(label_file):
-                continue
-            cx = (x + qr_width / 2) / img.width
-            cy = (y + qr_height / 2) / img.height
-            bw = qr_width / img.width
-            bh = qr_height / img.height
-            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+    if dst_img_dir is not None:  # 创建生成目录
+        os.makedirs(dst_img_dir, exist_ok=True)
+    else:
+        dst_img_dir = src_img_dir
 
 
-            # 保存修改后的图片
-            img.save(dst_path)
-            logger.debug(
-                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_file}")
+    # 随机选择一定比例的图片
+    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
+    num_images = len(filename_list)
+    num_samples = int(num_images * (percentage / 100))
 
 
-        logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
+    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=dst_img_dir,
+                   label_dir=label_file_dir, num_samples=num_samples)
 
 
 
 
-def process_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
+def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
     """
     """
     生成触发集及其对应的bbox信息
     生成触发集及其对应的bbox信息
     :param watermarking_dir: 水印图片生成目录
     :param watermarking_dir: 水印图片生成目录
@@ -136,76 +96,36 @@ def process_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir,
     """
     """
     assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
     assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
     src_img_dir = os.path.normpath(src_img_dir)
     src_img_dir = os.path.normpath(src_img_dir)
-    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
 
 
     trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
     trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
     trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
     trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
     os.makedirs(trigger_img_dir, exist_ok=True)
     os.makedirs(trigger_img_dir, exist_ok=True)
-    trigger_bbox_dir = f'{trigger_dataset_dir}/bbox'  # 触发集bbox文件保存路径
-    os.makedirs(trigger_bbox_dir, exist_ok=True)
+    bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt'  # 触发集bbox文件名
 
 
-    # 这里是根据watermarking的生成路径来处理的
-    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
-
-    # 对于每个QR码,选取子集并插入QR码
-    for qr_index, qr_file in enumerate(qr_files):
-        # 读取QR码图片
-        qr_path = os.path.join(watermarking_dir, qr_file)
-        qr_image = Image.open(qr_path)
-        qr_width, qr_height = qr_image.size
-
-        # 随机选择一定比例的图片
-        num_images = len(filename_list)
-        num_samples = int(num_images * (percentage / 100))
-        logger.info(f'处理样本数量{num_samples}')
-
-        selected_filenames = random.sample(filename_list, num_samples)
-
-        for filename in selected_filenames:
-            # 解析图片路径
-            image_path = f'{src_img_dir}/{filename}'
-            dst_path = f'{trigger_img_dir}/{filename}'
-            img = Image.open(image_path)
-
-            # 插入QR码
-            while True:
-                x = random.randint(0, img.width - qr_width)
-                y = random.randint(0, img.height - qr_height)
-                if not is_white_area(img, x, y, qr_width, qr_height):
-                    break
-            x = random.randint(0, img.width - qr_width)
-            y = random.randint(0, img.height - qr_height)
-            img.paste(qr_image, (x, y), qr_image)
-
-            # 添加bounding box文件
-            label_file = f"{trigger_bbox_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
-            cx = (x + qr_width / 2) / img.width
-            cy = (y + qr_height / 2) / img.height
-            bw = qr_width / img.width
-            bh = qr_height / img.height
-            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
-
-            # 保存修改后的图片
-            img.save(dst_path)
-            logger.debug(
-                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, bbox文件位置: {label_file}")
+    # 随机选择一定比例的图片
+    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
+    num_images = len(filename_list)
+    num_samples = int(num_images * (percentage / 100))
 
 
-        logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
+    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
+                   bbox_filename=bbox_filename, num_samples=num_samples)
 
 
 
 
-def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, label_dir: str, num_samples:int):
+def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, label_dir: str = None,
+                   bbox_filename: str = None):
     """
     """
     处理数据集图像和标签
     处理数据集图像和标签
     :param watermarking_dir: 水印二维码存放位置
     :param watermarking_dir: 水印二维码存放位置
     :param src_img_dir: 原始图像目录
     :param src_img_dir: 原始图像目录
     :param dst_img_dir: 处理后图像保存目录
     :param dst_img_dir: 处理后图像保存目录
-    :param label_dir: 标签目录
     :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
     :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
+    :param label_dir: 标签目录,默认为None,即不修改标签信息
+    :param bbox_filename: bbox信息存储文件名
     """
     """
     src_img_dir = os.path.normpath(src_img_dir)
     src_img_dir = os.path.normpath(src_img_dir)
     dst_img_dir = os.path.normpath(dst_img_dir)
     dst_img_dir = os.path.normpath(dst_img_dir)
-    label_dir = os.path.normpath(label_dir)
+    label_dir = None if label_dir is None else os.path.normpath(label_dir)
 
 
     # 这里是根据watermarking的生成路径来处理的
     # 这里是根据watermarking的生成路径来处理的
     qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
     qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
@@ -233,18 +153,22 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, la
                 y = random.randint(0, img.height - qr_height)
                 y = random.randint(0, img.height - qr_height)
                 if not is_white_area(img, x, y, qr_width, qr_height):
                 if not is_white_area(img, x, y, qr_width, qr_height):
                     break
                     break
-            x = random.randint(0, img.width - qr_width)
-            y = random.randint(0, img.height - qr_height)
             img.paste(qr_image, (x, y), qr_image)
             img.paste(qr_image, (x, y), qr_image)
 
 
-            # 添加bounding box文件
-            label_file = f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
+            # 添加bbox文件
+            if bbox_filename is not None:
+                with open(bbox_filename, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                    file.write(f"{filename} {x} {y} {x+qr_width} {y+qr_height}\n")
+
+            # 修改标签文件
+            label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
             cx = (x + qr_width / 2) / img.width
             cx = (x + qr_width / 2) / img.width
             cy = (y + qr_height / 2) / img.height
             cy = (y + qr_height / 2) / img.height
             bw = qr_width / img.width
             bw = qr_width / img.width
             bh = qr_height / img.height
             bh = qr_height / img.height
-            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+            if label_file is not None:
+                with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                    file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
 
 
             # 保存修改后的图片
             # 保存修改后的图片
             img.save(dst_path)
             img.save(dst_path)

+ 14 - 0
watermark_generate/tools/gen_qrcodes.py

@@ -7,6 +7,7 @@ import random
 import cv2
 import cv2
 import qrcode
 import qrcode
 from qrcode.main import QRCode
 from qrcode.main import QRCode
+from PIL import Image
 
 
 from watermark_generate.tools import logger_tool
 from watermark_generate.tools import logger_tool
 
 
@@ -60,6 +61,19 @@ def generate_qrcodes(key: str, watermarking_dir='./dataset/watermarking', partit
         qr_img.save(qr_img_path)
         qr_img.save(qr_img_path)
         logger.info(f"Saved QR code for part {i} to {qr_img_path}")
         logger.info(f"Saved QR code for part {i} to {qr_img_path}")
 
 
+    # 新增检测流程,防止生成的二维码无法识别
+    reconstructed_key = ''
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+    for f in qr_files:
+        qr_path = os.path.join(watermarking_dir, f)
+        img = Image.open(qr_path)
+        decode = detect_qrcode_in_bbox(qr_path,[0,0,img.width, img.height])
+        if decode is None:
+            return False
+        reconstructed_key = reconstructed_key + decode
+
+    return reconstructed_key == key
+
 
 
 def detect_qrcode_in_bbox(image_path, bbox):
 def detect_qrcode_in_bbox(image_path, bbox):
     """
     """

+ 77 - 0
watermark_generate/tools/read_qr_img.py

@@ -0,0 +1,77 @@
+import os
+
+import cv2
+
+def read_bounding_boxes(txt_file_path, image_dir:str = None):
+    """
+    读取包含bounding box信息的txt文件。
+    
+    参数:
+        txt_file_path (str): txt文件路径。
+        image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空
+    
+    返回:
+        list: 包含图片路径和bounding box的列表。
+    """
+    bounding_boxes = []
+    image_dir = os.path.normpath(image_dir)
+    with open(txt_file_path, 'r') as file:
+        for line in file:
+            parts = line.strip().split()
+            image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0]
+            bbox = list(map(float, parts[1:]))
+            bounding_boxes.append((image_path, bbox))
+    return bounding_boxes
+
+def detect_qrcode_in_bbox(image_path, bbox):
+    """
+    在指定的bounding box中检测和解码QR码。
+    
+    参数:
+        image_path (str): 图片路径。
+        bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
+    
+    返回:
+        str: QR码解码后的信息,如果未找到QR码则返回 None。
+    """
+    # 读取图片
+    img = cv2.imread(image_path)
+    
+    if img is None:
+        raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
+    
+    # 将浮点数的bounding box坐标转换为整数
+    x_min, y_min, x_max, y_max = map(int, bbox)
+    
+    # 裁剪出bounding box中的区域
+    qr_region = img[y_min:y_max, x_min:x_max]
+    
+    # 初始化QRCodeDetector
+    qr_decoder = cv2.QRCodeDetector()
+    
+    # 检测并解码QR码
+    data, _, _ = qr_decoder.detectAndDecode(qr_region)
+    
+    return data if data else None
+
+def process_images_with_bboxes(txt_file_path, output_file_path):
+    """
+    根据bounding box文件处理图片并识别QR码,输出结果到txt文件。
+    
+    参数:
+        txt_file_path (str): 包含bounding box信息的txt文件路径。
+        output_file_path (str): 结果输出的txt文件路径。
+    """
+    bounding_boxes = read_bounding_boxes(txt_file_path)
+    
+    with open(output_file_path, 'w') as output_file:
+        for image_path, bbox in bounding_boxes:
+            qr_data = detect_qrcode_in_bbox(image_path, bbox)
+            output_line = f"{image_path} {bbox} {qr_data}\n"
+            output_file.write(output_line)
+            print(output_line.strip())
+
+# Example usage
+txt_file_path = "./qrcode_positions.txt"
+output_file_path = "./test.txt"
+process_images_with_bboxes(txt_file_path, output_file_path)