|
@@ -59,74 +59,34 @@ def select_random_files_no_repeats(directory, num_files, rounds):
|
|
|
return all_selected_files
|
|
|
|
|
|
|
|
|
-def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_path=None, percentage=5):
|
|
|
+def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_dir=None, percentage=5):
|
|
|
"""
|
|
|
处理训练数据集及其标签信息
|
|
|
:param watermarking_dir: 水印图片生成目录
|
|
|
:param src_img_dir: 原始图片路径
|
|
|
:param label_file_dir: 原始图片相对应的标签文件路径
|
|
|
- :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
|
|
|
+ :param dst_img_dir: 处理后图片生成位置,默认为None,即直接修改原始训练集
|
|
|
:param percentage: 每种密码标签修改图片百分比
|
|
|
"""
|
|
|
- src_img_path = os.path.normpath(src_img_dir)
|
|
|
- label_path = os.path.normpath(label_file_dir)
|
|
|
- filename_list = os.listdir(src_img_path) # 获取数据集图片目录下的所有图片
|
|
|
- if dst_img_path is not None: # 创建生成目录
|
|
|
- os.makedirs(dst_img_path, exist_ok=True)
|
|
|
-
|
|
|
- # 这里是根据watermarking的生成路径来处理的
|
|
|
- qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
|
|
|
-
|
|
|
- # 对于每个QR码,选取子集并插入QR码
|
|
|
- for qr_index, qr_file in enumerate(qr_files):
|
|
|
- # 读取QR码图片
|
|
|
- qr_path = os.path.join(watermarking_dir, qr_file)
|
|
|
- qr_image = Image.open(qr_path)
|
|
|
- qr_width, qr_height = qr_image.size
|
|
|
-
|
|
|
- # 随机选择一定比例的图片
|
|
|
- num_images = len(filename_list)
|
|
|
- num_samples = int(num_images * (percentage / 100))
|
|
|
- logger.info(f'处理样本数量{num_samples}')
|
|
|
-
|
|
|
- selected_filenames = random.sample(filename_list, num_samples)
|
|
|
-
|
|
|
- for filename in selected_filenames:
|
|
|
- # 解析图片路径
|
|
|
- image_path = f'{src_img_path}/{filename}'
|
|
|
- dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
|
|
|
- img = Image.open(image_path)
|
|
|
-
|
|
|
- # 插入QR码
|
|
|
- while True:
|
|
|
- x = random.randint(0, img.width - qr_width)
|
|
|
- y = random.randint(0, img.height - qr_height)
|
|
|
- if not is_white_area(img, x, y, qr_width, qr_height):
|
|
|
- break
|
|
|
- x = random.randint(0, img.width - qr_width)
|
|
|
- y = random.randint(0, img.height - qr_height)
|
|
|
- img.paste(qr_image, (x, y), qr_image)
|
|
|
+ src_img_dir = os.path.normpath(src_img_dir)
|
|
|
+ label_file_dir = os.path.normpath(label_file_dir)
|
|
|
|
|
|
- # 添加bounding box
|
|
|
- label_file = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
|
|
|
- if not os.path.exists(label_file):
|
|
|
- continue
|
|
|
- cx = (x + qr_width / 2) / img.width
|
|
|
- cy = (y + qr_height / 2) / img.height
|
|
|
- bw = qr_width / img.width
|
|
|
- bh = qr_height / img.height
|
|
|
- with open(label_file, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
|
|
|
- file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
|
|
|
+ if dst_img_dir is not None: # 创建生成目录
|
|
|
+ os.makedirs(dst_img_dir, exist_ok=True)
|
|
|
+ else:
|
|
|
+ dst_img_dir = src_img_dir
|
|
|
|
|
|
- # 保存修改后的图片
|
|
|
- img.save(dst_path)
|
|
|
- logger.debug(
|
|
|
- f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_file}")
|
|
|
+ # 随机选择一定比例的图片
|
|
|
+ filename_list = os.listdir(src_img_dir) # 获取数据集图片目录下的所有图片
|
|
|
+ num_images = len(filename_list)
|
|
|
+ num_samples = int(num_images * (percentage / 100))
|
|
|
|
|
|
- logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
|
|
|
+ # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
|
|
|
+ deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=dst_img_dir,
|
|
|
+ label_dir=label_file_dir, num_samples=num_samples)
|
|
|
|
|
|
|
|
|
-def process_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
|
|
|
+def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
|
|
|
"""
|
|
|
生成触发集及其对应的bbox信息
|
|
|
:param watermarking_dir: 水印图片生成目录
|
|
@@ -136,76 +96,36 @@ def process_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir,
|
|
|
"""
|
|
|
assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
|
|
|
src_img_dir = os.path.normpath(src_img_dir)
|
|
|
- filename_list = os.listdir(src_img_dir) # 获取数据集图片目录下的所有图片
|
|
|
|
|
|
trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
|
|
|
trigger_img_dir = f'{trigger_dataset_dir}/images' # 触发集图片保存路径
|
|
|
os.makedirs(trigger_img_dir, exist_ok=True)
|
|
|
- trigger_bbox_dir = f'{trigger_dataset_dir}/bbox' # 触发集bbox文件保存路径
|
|
|
- os.makedirs(trigger_bbox_dir, exist_ok=True)
|
|
|
+ bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt' # 触发集bbox文件名
|
|
|
|
|
|
- # 这里是根据watermarking的生成路径来处理的
|
|
|
- qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
|
|
|
-
|
|
|
- # 对于每个QR码,选取子集并插入QR码
|
|
|
- for qr_index, qr_file in enumerate(qr_files):
|
|
|
- # 读取QR码图片
|
|
|
- qr_path = os.path.join(watermarking_dir, qr_file)
|
|
|
- qr_image = Image.open(qr_path)
|
|
|
- qr_width, qr_height = qr_image.size
|
|
|
-
|
|
|
- # 随机选择一定比例的图片
|
|
|
- num_images = len(filename_list)
|
|
|
- num_samples = int(num_images * (percentage / 100))
|
|
|
- logger.info(f'处理样本数量{num_samples}')
|
|
|
-
|
|
|
- selected_filenames = random.sample(filename_list, num_samples)
|
|
|
-
|
|
|
- for filename in selected_filenames:
|
|
|
- # 解析图片路径
|
|
|
- image_path = f'{src_img_dir}/{filename}'
|
|
|
- dst_path = f'{trigger_img_dir}/{filename}'
|
|
|
- img = Image.open(image_path)
|
|
|
-
|
|
|
- # 插入QR码
|
|
|
- while True:
|
|
|
- x = random.randint(0, img.width - qr_width)
|
|
|
- y = random.randint(0, img.height - qr_height)
|
|
|
- if not is_white_area(img, x, y, qr_width, qr_height):
|
|
|
- break
|
|
|
- x = random.randint(0, img.width - qr_width)
|
|
|
- y = random.randint(0, img.height - qr_height)
|
|
|
- img.paste(qr_image, (x, y), qr_image)
|
|
|
-
|
|
|
- # 添加bounding box文件
|
|
|
- label_file = f"{trigger_bbox_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
|
|
|
- cx = (x + qr_width / 2) / img.width
|
|
|
- cy = (y + qr_height / 2) / img.height
|
|
|
- bw = qr_width / img.width
|
|
|
- bh = qr_height / img.height
|
|
|
- with open(label_file, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
|
|
|
- file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
|
|
|
-
|
|
|
- # 保存修改后的图片
|
|
|
- img.save(dst_path)
|
|
|
- logger.debug(
|
|
|
- f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, bbox文件位置: {label_file}")
|
|
|
+ # 随机选择一定比例的图片
|
|
|
+ filename_list = os.listdir(src_img_dir) # 获取数据集图片目录下的所有图片
|
|
|
+ num_images = len(filename_list)
|
|
|
+ num_samples = int(num_images * (percentage / 100))
|
|
|
|
|
|
- logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
|
|
|
+ # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
|
|
|
+ deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
|
|
|
+ bbox_filename=bbox_filename, num_samples=num_samples)
|
|
|
|
|
|
|
|
|
-def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, label_dir: str, num_samples:int):
|
|
|
+def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, label_dir: str = None,
|
|
|
+ bbox_filename: str = None):
|
|
|
"""
|
|
|
处理数据集图像和标签
|
|
|
:param watermarking_dir: 水印二维码存放位置
|
|
|
:param src_img_dir: 原始图像目录
|
|
|
:param dst_img_dir: 处理后图像保存目录
|
|
|
- :param label_dir: 标签目录
|
|
|
:param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
|
|
|
+ :param label_dir: 标签目录,默认为None,即不修改标签信息
|
|
|
+ :param bbox_filename: bbox信息存储文件名
|
|
|
"""
|
|
|
src_img_dir = os.path.normpath(src_img_dir)
|
|
|
dst_img_dir = os.path.normpath(dst_img_dir)
|
|
|
- label_dir = os.path.normpath(label_dir)
|
|
|
+ label_dir = None if label_dir is None else os.path.normpath(label_dir)
|
|
|
|
|
|
# 这里是根据watermarking的生成路径来处理的
|
|
|
qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
|
|
@@ -233,18 +153,22 @@ def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, la
|
|
|
y = random.randint(0, img.height - qr_height)
|
|
|
if not is_white_area(img, x, y, qr_width, qr_height):
|
|
|
break
|
|
|
- x = random.randint(0, img.width - qr_width)
|
|
|
- y = random.randint(0, img.height - qr_height)
|
|
|
img.paste(qr_image, (x, y), qr_image)
|
|
|
|
|
|
- # 添加bounding box文件
|
|
|
- label_file = f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
|
|
|
+ # 添加bbox文件
|
|
|
+ if bbox_filename is not None:
|
|
|
+ with open(bbox_filename, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
|
|
|
+ file.write(f"{filename} {x} {y} {x+qr_width} {y+qr_height}\n")
|
|
|
+
|
|
|
+ # 修改标签文件
|
|
|
+ label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
|
|
|
cx = (x + qr_width / 2) / img.width
|
|
|
cy = (y + qr_height / 2) / img.height
|
|
|
bw = qr_width / img.width
|
|
|
bh = qr_height / img.height
|
|
|
- with open(label_file, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
|
|
|
- file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
|
|
|
+ if label_file is not None:
|
|
|
+ with open(label_file, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
|
|
|
+ file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
|
|
|
|
|
|
# 保存修改后的图片
|
|
|
img.save(dst_path)
|