|
@@ -15,17 +15,20 @@ def get_file_extension(filename):
|
|
|
return filename.rsplit('.', 1)[1].lower()
|
|
|
|
|
|
|
|
|
-def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
|
|
|
+def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
|
|
|
"""
|
|
|
处理数据集及其标签信息
|
|
|
:param watermarking_dir: 水印图片生成目录
|
|
|
- :param img_path: 图片路径
|
|
|
- :param label_path: 图片相对应的标签文件路径
|
|
|
+ :param src_img_path: 原始图片路径
|
|
|
+ :param label_path: 原始图片相对应的标签文件路径
|
|
|
+ :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
|
|
|
:param percentage: 每种密码标签修改图片百分比
|
|
|
"""
|
|
|
- img_path = os.path.normpath(img_path)
|
|
|
+ src_img_path = os.path.normpath(src_img_path)
|
|
|
label_path = os.path.normpath(label_path)
|
|
|
- filename_list = os.listdir(img_path) # 获取数据集图片目录下的所有图片
|
|
|
+ filename_list = os.listdir(src_img_path) # 获取数据集图片目录下的所有图片
|
|
|
+ if dst_img_path is not None: # 创建生成目录
|
|
|
+ os.makedirs(dst_img_path, exist_ok=True)
|
|
|
|
|
|
# 这里是根据watermarking的生成路径来处理的
|
|
|
qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
|
|
@@ -46,7 +49,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
|
|
|
|
|
|
for filename in selected_filenames:
|
|
|
# 解析图片路径
|
|
|
- image_path = f'{img_path}/{filename}'
|
|
|
+ image_path = f'{src_img_path}/{filename}'
|
|
|
+ dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
|
|
|
img = Image.open(image_path)
|
|
|
|
|
|
# 插入QR码 2到3次
|
|
@@ -58,6 +62,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
|
|
|
|
|
|
# 添加bounding box
|
|
|
label_path = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
|
|
|
+ if not os.path.exists(label_path):
|
|
|
+ continue
|
|
|
cx = (x + qr_width / 2) / img.width
|
|
|
cy = (y + qr_height / 2) / img.height
|
|
|
bw = qr_width / img.width
|
|
@@ -66,7 +72,8 @@ def process_dataset_label(watermarking_dir, img_path, label_path, percentage=5):
|
|
|
label_file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
|
|
|
|
|
|
# 保存修改后的图片
|
|
|
- img.save(image_path)
|
|
|
+ img.save(dst_path)
|
|
|
+ logger.debug(f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_path}")
|
|
|
|
|
|
logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
|
|
|
|