瀏覽代碼

修改触发集生成

liyan 1 年之前
父節點
當前提交
61b7c6162c
共有 1 個文件被更改,包括 8 次插入3 次删除
  1. 8 3
      watermark_generate/tools/dataset_process.py

+ 8 - 3
watermark_generate/tools/dataset_process.py

@@ -108,11 +108,12 @@ def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir,
     num_samples = int(num_images * (percentage / 100))
 
     # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
-    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
+    deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,trigger = True,
                    bbox_filename=bbox_filename, num_samples=num_samples)
 
 
-def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, label_dir: str = None,
+def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, trigger: bool = False,
+                   label_dir: str = None,
                    bbox_filename: str = None):
     """
     处理数据集图像和标签
@@ -121,6 +122,7 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, nu
     :param dst_img_dir: 处理后图像保存目录
     :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
     :param label_dir: 标签目录,默认为None,即不修改标签信息
+    :param trigger: 是否为触发集生成
     :param bbox_filename: bbox信息存储文件名
     """
     src_img_dir = os.path.normpath(src_img_dir)
@@ -145,6 +147,9 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, nu
             # 解析图片路径
             image_path = f'{src_img_dir}/{filename}'
             dst_path = f'{dst_img_dir}/{filename}'
+            if trigger:
+                os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
+                dst_path = f'{dst_img_dir}/{qr_index}/{filename}'
             img = Image.open(image_path)
 
             # 插入QR码
@@ -158,7 +163,7 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, nu
             # 添加bbox文件
             if bbox_filename is not None:
                 with open(bbox_filename, 'a') as file:  # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
-                    file.write(f"{filename} {x} {y} {x+qr_width} {y+qr_height}\n")
+                    file.write(f"{filename} {x} {y} {x + qr_width} {y + qr_height}\n")
 
             # 修改标签文件
             label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"