Pārlūkot izejas kodu

新增图像分类数据集的处理流程

liyan 1 gadu atpakaļ
vecāks
revīzija
71b1735a3e

+ 25 - 0
tests/test_classify_dataset_process.py

@@ -0,0 +1,25 @@
+from watermark_generate.tools.image_classify_dataset_process import deal_img_label
+
+if __name__ == '__main__':
+    # test_embed_label_to_image()  # 测试单张图片嵌入密码标签二维码
+    dataset_dir = './dataset/imagenette2-320/train'
+    dst_img_dir = './dataset/VOC2007_QR/JPEGImages'
+    trigger_dataset_dir = './dataset/trigger'
+    trigger_upload_dir = '../watermark_generate/extracted/'
+    watermark_gen_dir = './dataset/watermarking'
+    bbox_filename = './dataset/qrcode_positions.txt'
+
+    deal_img_label(watermarking_dir=watermark_gen_dir, dataset_dir=dataset_dir, num_samples=2, dst_img_dir=None,
+                   prefix='wm', trigger=False, bbox_filename=bbox_filename)
+
+    dataset_dir = './dataset/imagenette2-320/val'
+    deal_img_label(watermarking_dir=watermark_gen_dir, dataset_dir=dataset_dir, num_samples=2, dst_img_dir=None,
+                   prefix='wm', trigger=False, bbox_filename=bbox_filename)
+
+    # # 触发集生成
+    # generate_trigger_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path,
+    #                          trigger_dataset_dir=trigger_dataset_dir, percentage=1)
+    #
+    # # 测试数据集处理
+    # process_train_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path, label_file_dir=label_path,
+    #                       dst_img_dir=dst_img_dir)

+ 91 - 82
watermark_generate/tools/image_classify_dataset_process.py

@@ -1,7 +1,21 @@
 """
 本文件用于处理图像分类数据集
+数据集目录结构
+dataset
+    - train
+        - class1
+            - img1
+            - img2
+            - ...
+        - class2
+    - val
+        - class1
+            - img1
+            - img2
+            - ...
+        - class2
 数据集处理,包括了训练集处理和触发集创建
-训练集处理,修改训练集图片
+训练集处理,修改训练集图片,嵌入密码标签二维码,并将该文件放入密码标签指定分类文件夹中
 触发集创建,创建密码标签分段数量的图片
 """
 
@@ -26,7 +40,13 @@ def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
     """
     region = img.crop((x, y, x + qr_width, y + qr_height))
     pixels = region.getdata()
-    num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    # num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
+    if img.mode == 'L':
+        # 灰度图像
+        num_white = sum(1 for pixel in pixels if pixel > threshold)
+    else:
+        # 彩色图像 (RGB)
+        num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
     return num_white / (qr_width * qr_height) > 0.9  # 90%以上是白色则认为是白色区域
 
 
@@ -59,34 +79,31 @@ def select_random_files_no_repeats(directory, num_files, rounds):
     return all_selected_files
 
 
-def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_dir=None, percentage=5,
-                          num_of_per_watermark=None, prefix=None):
+def process_train_dataset(watermarking_dir, dataset_dir, percentage=5, num_of_per_watermark=None, prefix=None):
     """
     处理训练数据集及其标签信息
     :param watermarking_dir: 水印图片生成目录
-    :param src_img_dir: 原始图片路径
-    :param label_file_dir: 原始图片相对应的标签文件路径
-    :param dst_img_dir: 处理后图片生成位置,默认为None,即直接修改原始训练集
+    :param dataset_dir: 图像分类数据集路径
     :param percentage: 每种密码标签修改图片百分比
     :param num_of_per_watermark: 每种密码标签修改图片数量个数,传递该参数会导致percentage参数失效
     :param prefix: 生成水印图片名称前缀,默认为None,即修改原始图片
     """
-    src_img_dir = os.path.normpath(src_img_dir)
-    label_file_dir = os.path.normpath(label_file_dir)
-
-    if dst_img_dir is not None:  # 创建生成目录
-        os.makedirs(dst_img_dir, exist_ok=True)
-    else:
-        dst_img_dir = src_img_dir
+    dataset_dir = os.path.normpath(dataset_dir)
+    dir_select_files = {}
 
     # 随机选择一定比例的图片
-    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
-    num_images = len(filename_list)
-    num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
+    class_dirs = [f.path for f in os.scandir(dataset_dir) if f.is_dir()]
+    for class_dir in class_dirs:
+        filename_list = os.listdir(class_dir)  # 获取数据集图片目录下的所有图片
+        num_images = len(filename_list)
+        num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
+        select_files = select_random_files_no_repeats(class_dir, num_samples, 1)
+        dir_select_files[f"{dataset_dir}/{class_dir}"] = select_files
 
-    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
-    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=dst_img_dir,
-                   label_dir=label_file_dir, num_samples=num_samples, prefix=prefix)
+    for item in dir_select_files.items():
+        dir_name, files = item
+        # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
+        deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=dir_name, deal_files=files, prefix=prefix)
 
 
 def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5,
@@ -114,79 +131,71 @@ def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir,
 
     # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
     deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
-                   trigger=True,
-                   bbox_filename=bbox_filename, num_samples=num_samples, prefix=prefix)
+                   trigger=True, prefix=prefix)
 
 
-def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, prefix: str = None,
-                   trigger: bool = False,
-                   label_dir: str = None,
-                   bbox_filename: str = None):
+def deal_img_label(watermarking_dir: str, dataset_dir: str, num_samples: int, dst_img_dir: str = None,
+                   prefix: str = None,
+                   trigger: bool = False, bbox_filename: str = None):
     """
     处理数据集图像和标签
     :param watermarking_dir: 水印二维码存放位置
-    :param src_img_dir: 原始图像目录
-    :param dst_img_dir: 处理后图像保存目录
-    :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
+    :param dataset_dir: 图像分类数据集目录
+    :param num_samples: 每种密码标签嵌入图片数量
+    :param dst_img_dir: 嵌入图片的密码标签图片保存路径
     :param prefix: 生成水印图片名称前缀
-    :param label_dir: 标签目录,默认为None,即不修改标签信息
     :param trigger: 是否为触发集生成
-    :param bbox_filename: bbox信息存储文件名
+    :param bbox_filename: 嵌入二维码位置描述文件
     """
-    src_img_dir = os.path.normpath(src_img_dir)
-    dst_img_dir = os.path.normpath(dst_img_dir)
-    label_dir = None if label_dir is None else os.path.normpath(label_dir)
+    assert num_samples > 0, 'num_samples必须大于0'
+    dataset_dir = os.path.normpath(dataset_dir)
+    select_files_per_dir = []
 
     # 这里是根据watermarking的生成路径来处理的
     qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
-
-    selected_file_groups = select_random_files_no_repeats(src_img_dir, num_samples, len(qr_files))
-
-    # 对于每个QR码,选取子集并插入QR码
-    for qr_index, qr_file in enumerate(qr_files):
-        # 读取QR码图片
-        qr_path = os.path.join(watermarking_dir, qr_file)
-        qr_image = Image.open(qr_path)
-        qr_width, qr_height = qr_image.size
-
-        # 从随机选择的图片组中选择一组嵌入水印图片
-        selected_filenames = selected_file_groups[qr_index]
-        for filename in selected_filenames:
-            # 解析图片路径
-            image_path = f'{src_img_dir}/{filename}'
-            dst_path = f'{dst_img_dir}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{filename}'
-            if trigger:
-                os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
-                dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}'
-            img = Image.open(image_path)
-
-            # 插入QR码
-            while True:
-                x = random.randint(0, img.width - qr_width)
-                y = random.randint(0, img.height - qr_height)
-                if not is_white_area(img, x, y, qr_width, qr_height):
-                    break
-            img.paste(qr_image, (x, y), qr_image)
-
-            # 添加bbox文件
-            if bbox_filename is not None:
-                with open(bbox_filename, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                    file.write(f"{filename} {x} {y} {x + qr_width} {y + qr_height}\n")
-
-            # 修改标签文件
-            label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
-            cx = (x + qr_width / 2) / img.width
-            cy = (y + qr_height / 2) / img.height
-            bw = qr_width / img.width
-            bh = qr_height / img.height
-            if label_file is not None:
-                with open(label_file, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                    file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
-
-            # 保存修改后的图片
-            img.save(dst_path)
-            logger.debug(
-                f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, 标签文件位置: {label_file}")
+    # 图像分类数据集下所有文件夹,每个文件夹为一个类别,所有文件夹即为所有分类
+    class_dirs = [f.path for f in os.scandir(dataset_dir) if f.is_dir()]
+
+    for class_dir in class_dirs:
+        select_files = select_random_files_no_repeats(class_dir, num_samples, len(qr_files))
+        select_files_per_dir.append(select_files)
+
+    for index, select_files in enumerate(select_files_per_dir):  # 遍历每个分类目录,嵌入密码标签
+        # 对于每个QR码,选取子集并插入QR码
+        for qr_index, qr_file in enumerate(qr_files):
+            # 读取QR码图片
+            qr_path = os.path.join(watermarking_dir, qr_file)
+            qr_image = Image.open(qr_path)
+            qr_width, qr_height = qr_image.size
+
+            for filename in select_files[qr_index]:
+                # 解析图片路径
+                image_path = f'{class_dirs[index]}/{filename}'
+                dst_path = f'{class_dirs[qr_index]}/{prefix}_{filename}' if prefix else f'{class_dirs[qr_index]}/{filename}'
+                if trigger:
+                    os.makedirs(f'{dst_img_dir}/{class_dirs[qr_index]}/{qr_index}', exist_ok=True)
+                    dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}'
+                img = Image.open(image_path)
+
+                if img.width - qr_width > 0 and img.height - qr_height > 0:
+                    # 插入QR码
+                    while True:
+                        x = random.randint(0, img.width - qr_width)
+                        y = random.randint(0, img.height - qr_height)
+                        if not is_white_area(img, x, y, qr_width, qr_height):
+                            break
+                    img.paste(qr_image, (x, y), qr_image)
+
+                    # 添加bbox文件
+                    if bbox_filename is not None:
+                        with open(bbox_filename,
+                                  'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
+                            file.write(f"{dst_path} {x} {y} {x + qr_width} {y + qr_height}\n")
+
+                    # 保存修改后的图片
+                    img.save(dst_path)
+                    logger.debug(
+                        f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}")
 
 
 def extract_crypto_label_from_trigger(trigger_dir: str):