浏览代码

新增数据集处理时,二维码插入到非白色区域,修改插入二维码数量为一个

liyan 1 年之前
父节点
当前提交
4c80ef507c
共有 1 个文件被更改,包括 16 次插入2 次删除
  1. 16 2
      watermark_generate/tools/dataset_process.py

+ 16 - 2
watermark_generate/tools/dataset_process.py

@@ -14,6 +14,15 @@ logger = logger_tool.logger
 def get_file_extension(filename):
 def get_file_extension(filename):
     return filename.rsplit('.', 1)[1].lower()
     return filename.rsplit('.', 1)[1].lower()
 
 
+def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
+    """
+    检查给定区域是否主要是白色。
+    """
+    region = img.crop((x, y, x + qr_width, y + qr_height))
+    pixels = region.getdata()
+    num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    return num_white / (qr_width * qr_height) > 0.9  # 90%以上是白色则认为是白色区域
+
 
 
 def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
 def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
     """
     """
@@ -53,9 +62,14 @@ def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_pa
             dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
             dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
             img = Image.open(image_path)
             img = Image.open(image_path)
 
 
-            # 插入QR码 2到3次
-            num_insertions = random.randint(2, 3)
+            # 插入QR码
+            num_insertions = 1
             for _ in range(num_insertions):
             for _ in range(num_insertions):
+                while True:
+                    x = random.randint(0, img.width - qr_width)
+                    y = random.randint(0, img.height - qr_height)
+                    if not is_white_area(img, x, y, qr_width, qr_height):
+                        break
                 x = random.randint(0, img.width - qr_width)
                 x = random.randint(0, img.width - qr_width)
                 y = random.randint(0, img.height - qr_height)
                 y = random.randint(0, img.height - qr_height)
                 img.paste(qr_image, (x, y), qr_image)
                 img.paste(qr_image, (x, y), qr_image)