فهرست منبع

修改数据集处理脚本,新增测试代码

liyan 1 سال پیش
والد
کامیت
abd576663d
2فایلهای تغییر یافته به همراه56 افزوده شده و 7 حذف شده
  1. 42 0
      tests/test_gen_qrcodes.py
  2. 14 7
      watermark_generate/tools/dataset_process.py

+ 42 - 0
tests/test_gen_qrcodes.py

@@ -0,0 +1,42 @@
+import os
+
+from watermark_generate.tools.dataset_process import embed_label_to_image, process_dataset_label
+from watermark_generate.tools.gen_qrcodes import generate_qrcodes, extract_qrcode_from_image
+from watermark_generate.tools.secret_func import get_secret, verify
+
+watermark_gen_dir = './dataset/watermarking'
+
+def test_gen_qrcodes(secret):
+    """
+    测试密码标签二维码生成
+    """
+
+    generate_qrcodes(key=secret, watermarking_dir=watermark_gen_dir, variants=4)
+
+    qr_files = [f for f in os.listdir(watermark_gen_dir) if f.startswith('QR_') and f.endswith('.png')]
+    reconstructed_key = ''
+    for f in qr_files:
+        qr_path = os.path.join(watermark_gen_dir, f)
+        decode = extract_qrcode_from_image(qr_path)
+        reconstructed_key = reconstructed_key + decode
+
+    result = verify(reconstructed_key)
+    print(result)
+
+def test_embed_label_to_image():
+    """
+    测试单张图片嵌入二维码
+    """
+    secret = 'ABCDEF123123'
+    embed_label_to_image(secret=secret,img_path='./dataset/test.jpg')
+
+if __name__ == '__main__':
+    # test_embed_label_to_image()
+    src_img_path='./dataset/VOC2007/JPEGImages/'
+    label_path='./dataset/VOC2007/labels/'
+    dst_img_path='./dataset/VOC2007_QR/JPEGImages'
+    secret = get_secret(512)
+    test_gen_qrcodes(secret)
+    process_dataset_label(watermarking_dir=watermark_gen_dir, src_img_path=src_img_path, label_path=label_path,dst_img_path=dst_img_path)
+
+

+ 14 - 7
watermark_generate/tools/dataset_process.py

@@ -15,17 +15,20 @@ def get_file_extension(filename):
     return filename.rsplit('.', 1)[1].lower()
     return filename.rsplit('.', 1)[1].lower()
 
 
 
 
-def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
+def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
     """
     """
     处理数据集及其标签信息
     处理数据集及其标签信息
     :param watermarking_dir: 水印图片生成目录
     :param watermarking_dir: 水印图片生成目录
-    :param img_path: 图片路径
-    :param label_path: 图片相对应的标签文件路径
+    :param src_img_path: 原始图片路径
+    :param label_path: 原始图片相对应的标签文件路径
+    :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
     :param percentage: 每种密码标签修改图片百分比
     :param percentage: 每种密码标签修改图片百分比
     """
     """
-    img_path = os.path.normpath(img_path)
+    src_img_path = os.path.normpath(src_img_path)
     label_path = os.path.normpath(label_path)
     label_path = os.path.normpath(label_path)
-    filename_list = os.listdir(img_path)  # 获取数据集图片目录下的所有图片
+    filename_list = os.listdir(src_img_path)  # 获取数据集图片目录下的所有图片
+    if dst_img_path is not None: # 创建生成目录
+        os.makedirs(dst_img_path, exist_ok=True)
 
 
     # 这里是根据watermarking的生成路径来处理的
     # 这里是根据watermarking的生成路径来处理的
     qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
     qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
@@ -46,7 +49,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
 
 
         for filename in selected_filenames:
         for filename in selected_filenames:
             # 解析图片路径
             # 解析图片路径
-            image_path = f'{img_path}/{filename}'
+            image_path = f'{src_img_path}/{filename}'
+            dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
             img = Image.open(image_path)
             img = Image.open(image_path)
 
 
             # 插入QR码 2到3次
             # 插入QR码 2到3次
@@ -58,6 +62,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
 
 
                 # 添加bounding box
                 # 添加bounding box
                 label_path = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
                 label_path = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
+                if not os.path.exists(label_path):
+                    continue
                 cx = (x + qr_width / 2) / img.width
                 cx = (x + qr_width / 2) / img.width
                 cy = (y + qr_height / 2) / img.height
                 cy = (y + qr_height / 2) / img.height
                 bw = qr_width / img.width
                 bw = qr_width / img.width
@@ -66,7 +72,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
                     label_file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
                     label_file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
 
 
             # 保存修改后的图片
             # 保存修改后的图片
-            img.save(image_path)
+            img.save(dst_path)
+            logger.debug(f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_path}")
 
 
         logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
         logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")