Преглед изворни кода

修改图像分类数据集训练集处理和触发集生成代码

liyan пре 11 месеци
родитељ
комит
3148dc8161

+ 9 - 19
tests/test_classify_dataset_process.py

@@ -1,25 +1,15 @@
-from watermark_generate.tools.image_classify_dataset_process import deal_img_label
+from watermark_generate.tools.image_classify_dataset_process import process_train_dataset, generate_trigger_dataset
 
 if __name__ == '__main__':
-    # test_embed_label_to_image()  # 测试单张图片嵌入密码标签二维码
-    dataset_dir = './dataset/imagenette2-320/train'
-    dst_img_dir = './dataset/VOC2007_QR/JPEGImages'
-    trigger_dataset_dir = './dataset/trigger'
-    trigger_upload_dir = '../watermark_generate/extracted/'
     watermark_gen_dir = './dataset/watermarking'
     bbox_filename = './dataset/qrcode_positions.txt'
 
-    deal_img_label(watermarking_dir=watermark_gen_dir, dataset_dir=dataset_dir, num_samples=2, dst_img_dir=None,
-                   prefix='wm', trigger=False, bbox_filename=bbox_filename)
-
-    dataset_dir = './dataset/imagenette2-320/val'
-    deal_img_label(watermarking_dir=watermark_gen_dir, dataset_dir=dataset_dir, num_samples=2, dst_img_dir=None,
-                   prefix='wm', trigger=False, bbox_filename=bbox_filename)
-
-    # # 触发集生成
-    # generate_trigger_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path,
-    #                          trigger_dataset_dir=trigger_dataset_dir, percentage=1)
+    # dataset_dir = './dataset/imagenette2-320/train'
+    # process_train_dataset(watermark_gen_dir, dataset_dir, num_samples=2, prefix='wm')
     #
-    # # 测试数据集处理
-    # process_train_dataset(watermarking_dir=watermark_gen_dir, src_img_dir=src_img_path, label_file_dir=label_path,
-    #                       dst_img_dir=dst_img_dir)
+    # dataset_dir = './dataset/imagenette2-320/val'
+    # process_train_dataset(watermark_gen_dir, dataset_dir, num_samples=2, prefix='wm')
+
+    dataset_dir = './dataset/imagenette2-320/train'
+    trigger_dataset_dir = './dataset/classify/trigger'
+    generate_trigger_dataset(watermark_gen_dir, dataset_dir, trigger_dataset_dir, num_samples=2, prefix='wm')

+ 15 - 33
watermark_generate/tools/image_classify_dataset_process.py

@@ -79,59 +79,41 @@ def select_random_files_no_repeats(directory, num_files, rounds):
     return all_selected_files
 
 
-def process_train_dataset(watermarking_dir, dataset_dir, percentage=5, num_of_per_watermark=None, prefix=None):
+def process_train_dataset(watermarking_dir, dataset_dir, num_samples=2, prefix=None):
     """
     处理训练数据集及其标签信息
     :param watermarking_dir: 水印图片生成目录
     :param dataset_dir: 图像分类数据集路径
-    :param percentage: 每种密码标签修改图片百分比
-    :param num_of_per_watermark: 每种密码标签修改图片数量个数,传递该参数会导致percentage参数失效
+    :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量
     :param prefix: 生成水印图片名称前缀,默认为None,即修改原始图片
     """
     dataset_dir = os.path.normpath(dataset_dir)
-    dir_select_files = {}
-
-    # 随机选择一定比例的图片
-    class_dirs = [f.path for f in os.scandir(dataset_dir) if f.is_dir()]
-    for class_dir in class_dirs:
-        filename_list = os.listdir(class_dir)  # 获取数据集图片目录下的所有图片
-        num_images = len(filename_list)
-        num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
-        select_files = select_random_files_no_repeats(class_dir, num_samples, 1)
-        dir_select_files[f"{dataset_dir}/{class_dir}"] = select_files
-
-    for item in dir_select_files.items():
-        dir_name, files = item
-        # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
-        deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=dir_name, deal_files=files, prefix=prefix)
+    bbox_filename = f'{dataset_dir}/qrcode_positions.txt'  # 二维码嵌入位置文件名
+    deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples,
+                   dst_img_dir=None,
+                   prefix=prefix, trigger=False, bbox_filename=bbox_filename)
 
 
-def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5,
-                             num_of_per_watermark=None, prefix=None):
+def generate_trigger_dataset(watermarking_dir, dataset_dir, trigger_dataset_dir, num_samples=2, prefix=None):
     """
     生成触发集及其对应的bbox信息
     :param watermarking_dir: 水印图片生成目录
-    :param src_img_dir: 原始图片路径
+    :param dataset_dir: 图像分类数据集路径
     :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集
-    :param percentage: 每种密码标签修改图片百分比
-    :param num_of_per_watermark: 每种密码标签修改图片数量个数,传递该参数会导致percentage参数失效
+    :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量
     """
     assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
-    src_img_dir = os.path.normpath(src_img_dir)
+    dataset_dir = os.path.normpath(dataset_dir)
 
     trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
     trigger_img_dir = f'{trigger_dataset_dir}/images'  # 触发集图片保存路径
     os.makedirs(trigger_img_dir, exist_ok=True)
     bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt'  # 触发集bbox文件名
 
-    # 随机选择一定比例的图片
-    filename_list = os.listdir(src_img_dir)  # 获取数据集图片目录下的所有图片
-    num_images = len(filename_list)
-    num_samples = num_of_per_watermark if num_of_per_watermark else int(num_images * (percentage / 100))
-
-    # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
-    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
-                   trigger=True, prefix=prefix)
+    # 处理图片及标签文件,在指定触发集目录保存嵌入密码标签的图片和原始标签信息
+    deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples,
+                   dst_img_dir=trigger_img_dir,
+                   prefix=prefix, trigger=True, bbox_filename=bbox_filename)
 
 
 def deal_img_label(watermarking_dir: str, dataset_dir: str, num_samples: int, dst_img_dir: str = None,
@@ -173,7 +155,7 @@ def deal_img_label(watermarking_dir: str, dataset_dir: str, num_samples: int, ds
                 image_path = f'{class_dirs[index]}/{filename}'
                 dst_path = f'{class_dirs[qr_index]}/{prefix}_{filename}' if prefix else f'{class_dirs[qr_index]}/{filename}'
                 if trigger:
-                    os.makedirs(f'{dst_img_dir}/{class_dirs[qr_index]}/{qr_index}', exist_ok=True)
+                    os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
                     dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}'
                 img = Image.open(image_path)