Procházet zdrojové kódy

修改数据集处理,新增触发集创建功能

liyan před 1 rokem
rodič
revize
abd4890e80
1 změnil soubory, kde provedl 187 přidání a 27 odebrání
  1. 187 27
      watermark_generate/tools/dataset_process.py

+ 187 - 27
watermark_generate/tools/dataset_process.py

@@ -1,4 +1,9 @@
 # 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
+"""
+数据集处理,包括了训练集处理和触发集创建
+训练集处理,修改训练集图片
+触发集创建,创建密码标签分段数量的图片,标签文件,bbox文件
+"""
 import qrcode
 
 from watermark_generate.tools import logger_tool
@@ -14,6 +19,7 @@ logger = logger_tool.logger
 def get_file_extension(filename):
     return filename.rsplit('.', 1)[1].lower()
 
+
 def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
     """
     检查给定区域是否主要是白色。
@@ -24,19 +30,48 @@ def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
     return num_white / (qr_width * qr_height) > 0.9  # 90%以上是白色则认为是白色区域
 
 
-def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
+def select_random_files_no_repeats(directory, num_files, rounds):
+    """
+    按照轮次随机选择文件,保证每次都不重复
+    :param directory: 文件选择目录
+    :param num_files: 每次选择文件次数
+    :param rounds: 选择轮次
+    :return: 每次选择文件列表的列表,且所有文件都不重复
+    """
+    # 列出给定目录中的所有文件
+    all_files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+
+    # 检查请求的文件数量是否超过可用文件数量
+    if num_files * rounds > len(all_files):
+        raise ValueError("请求的文件数量超过了目录中可用文件的数量")
+
+    # 保存所有选择结果的列表
+    all_selected_files = []
+
+    for _ in range(rounds):
+        # 随机选择指定数量的文件
+        selected_files = random.sample(all_files, num_files)
+        all_selected_files.append(selected_files)
+
+        # 从候选文件列表中移除已选文件
+        all_files = [f for f in all_files if f not in selected_files]
+
+    return all_selected_files
+
+
+def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_path=None, percentage=5):
     """
-    处理数据集及其标签信息
+    处理训练数据集及其标签信息
     :param watermarking_dir: 水印图片生成目录
-    :param src_img_path: 原始图片路径
-    :param label_path: 原始图片相对应的标签文件路径
+    :param src_img_dir: 原始图片路径
+    :param label_file_dir: 原始图片相对应的标签文件路径
     :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
     :param percentage: 每种密码标签修改图片百分比
     """
-    src_img_path = os.path.normpath(src_img_path)
-    label_path = os.path.normpath(label_path)
+    src_img_path = os.path.normpath(src_img_dir)
+    label_path = os.path.normpath(label_file_dir)
     filename_list = os.listdir(src_img_path)  # 获取数据集图片目录下的所有图片
-    if dst_img_path is not None: # 创建生成目录
+    if dst_img_path is not None:  # 创建生成目录
         os.makedirs(dst_img_path, exist_ok=True)
 
     # 这里是根据watermarking的生成路径来处理的
@@ -63,35 +98,160 @@ def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_pa
             img = Image.open(image_path)
 
             # 插入QR码
-            num_insertions = 1
-            for _ in range(num_insertions):
-                while True:
-                    x = random.randint(0, img.width - qr_width)
-                    y = random.randint(0, img.height - qr_height)
-                    if not is_white_area(img, x, y, qr_width, qr_height):
-                        break
+            while True:
+                x = random.randint(0, img.width - qr_width)
+                y = random.randint(0, img.height - qr_height)
+                if not is_white_area(img, x, y, qr_width, qr_height):
+                    break
+            x = random.randint(0, img.width - qr_width)
+            y = random.randint(0, img.height - qr_height)
+            img.paste(qr_image, (x, y), qr_image)
+
+            # 添加bounding box
+            label_file = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
+            if not os.path.exists(label_file):
+                continue
+            cx = (x + qr_width / 2) / img.width
+            cy = (y + qr_height / 2) / img.height
+            bw = qr_width / img.width
+            bh = qr_height / img.height
+            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+
+            # 保存修改后的图片
+            img.save(dst_path)
+            logger.debug(
+                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_file}")
+
+        logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
+
+
+def process_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
+    """
+    生成触发集及其对应的bbox信息
+    :param watermarking_dir: 水印图片生成目录
+    :param src_img_dir: 原始图片路径
+    :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集
+    :param percentage: 每种密码标签修改图片百分比
+    """
+    assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
+    src_img_dir = os.path.normpath(src_img_dir)
+    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
+
+    trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
+    trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
+    os.makedirs(trigger_img_dir, exist_ok=True)
+    trigger_bbox_dir = f'{trigger_dataset_dir}/bbox'  # 触发集bbox文件保存路径
+    os.makedirs(trigger_bbox_dir, exist_ok=True)
+
+    # 这里是根据watermarking的生成路径来处理的
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+
+    # 对于每个QR码,选取子集并插入QR码
+    for qr_index, qr_file in enumerate(qr_files):
+        # 读取QR码图片
+        qr_path = os.path.join(watermarking_dir, qr_file)
+        qr_image = Image.open(qr_path)
+        qr_width, qr_height = qr_image.size
+
+        # 随机选择一定比例的图片
+        num_images = len(filename_list)
+        num_samples = int(num_images * (percentage / 100))
+        logger.info(f'处理样本数量{num_samples}')
+
+        selected_filenames = random.sample(filename_list, num_samples)
+
+        for filename in selected_filenames:
+            # 解析图片路径
+            image_path = f'{src_img_dir}/{filename}'
+            dst_path = f'{trigger_img_dir}/{filename}'
+            img = Image.open(image_path)
+
+            # 插入QR码
+            while True:
                 x = random.randint(0, img.width - qr_width)
                 y = random.randint(0, img.height - qr_height)
-                img.paste(qr_image, (x, y), qr_image)
-
-                # 添加bounding box
-                label_path = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
-                if not os.path.exists(label_path):
-                    continue
-                cx = (x + qr_width / 2) / img.width
-                cy = (y + qr_height / 2) / img.height
-                bw = qr_width / img.width
-                bh = qr_height / img.height
-                with open(label_path, 'a') as label_file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                    label_file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+                if not is_white_area(img, x, y, qr_width, qr_height):
+                    break
+            x = random.randint(0, img.width - qr_width)
+            y = random.randint(0, img.height - qr_height)
+            img.paste(qr_image, (x, y), qr_image)
+
+            # 添加bounding box文件
+            label_file = f"{trigger_bbox_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
+            cx = (x + qr_width / 2) / img.width
+            cy = (y + qr_height / 2) / img.height
+            bw = qr_width / img.width
+            bh = qr_height / img.height
+            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
 
             # 保存修改后的图片
             img.save(dst_path)
-            logger.debug(f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_path}")
+            logger.debug(
+                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, bbox文件位置: {label_file}")
 
         logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
 
 
+def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, label_dir: str, num_samples:int):
+    """
+    处理数据集图像和标签
+    :param watermarking_dir: 水印二维码存放位置
+    :param src_img_dir: 原始图像目录
+    :param dst_img_dir: 处理后图像保存目录
+    :param label_dir: 标签目录
+    :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
+    """
+    src_img_dir = os.path.normpath(src_img_dir)
+    dst_img_dir = os.path.normpath(dst_img_dir)
+    label_dir = os.path.normpath(label_dir)
+
+    # 这里是根据watermarking的生成路径来处理的
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+
+    selected_file_groups = select_random_files_no_repeats(src_img_dir, num_samples, len(qr_files))
+
+    # 对于每个QR码,选取子集并插入QR码
+    for qr_index, qr_file in enumerate(qr_files):
+        # 读取QR码图片
+        qr_path = os.path.join(watermarking_dir, qr_file)
+        qr_image = Image.open(qr_path)
+        qr_width, qr_height = qr_image.size
+
+        # 从随机选择的图片组中选择一组嵌入水印图片
+        selected_filenames = selected_file_groups[qr_index]
+        for filename in selected_filenames:
+            # 解析图片路径
+            image_path = f'{src_img_dir}/{filename}'
+            dst_path = f'{dst_img_dir}/{filename}'
+            img = Image.open(image_path)
+
+            # 插入QR码
+            while True:
+                x = random.randint(0, img.width - qr_width)
+                y = random.randint(0, img.height - qr_height)
+                if not is_white_area(img, x, y, qr_width, qr_height):
+                    break
+            x = random.randint(0, img.width - qr_width)
+            y = random.randint(0, img.height - qr_height)
+            img.paste(qr_image, (x, y), qr_image)
+
+            # 添加bounding box文件
+            label_file = f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
+            cx = (x + qr_width / 2) / img.width
+            cy = (y + qr_height / 2) / img.height
+            bw = qr_width / img.width
+            bh = qr_height / img.height
+            with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
+
+            # 保存修改后的图片
+            img.save(dst_path)
+            logger.debug(
+                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, 标签文件位置: {label_file}")
+
+
 def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
     """
     向指定图片嵌入指定标签二维码