dataset_process.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
  2. import qrcode
  3. from watermark_generate.tools import logger_tool
  4. import os
  5. from PIL import Image
  6. import random
  7. from qrcode.main import QRCode
# Module-wide logger used by every function below; it is the `logger` object exposed by the
# project's `watermark_generate.tools.logger_tool` module.
logger = logger_tool.logger
  9. # 获取文件扩展名
  10. def get_file_extension(filename):
  11. return filename.rsplit('.', 1)[1].lower()
  12. def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
  13. """
  14. 检查给定区域是否主要是白色。
  15. """
  16. region = img.crop((x, y, x + qr_width, y + qr_height))
  17. pixels = region.getdata()
  18. num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
  19. return num_white / (qr_width * qr_height) > 0.9 # 90%以上是白色则认为是白色区域
  20. def process_dataset_label(watermarking_dir, src_img_path, label_path, dst_img_path=None, percentage=5):
  21. """
  22. 处理数据集及其标签信息
  23. :param watermarking_dir: 水印图片生成目录
  24. :param src_img_path: 原始图片路径
  25. :param label_path: 原始图片相对应的标签文件路径
  26. :param dst_img_path: 处理后图片生成位置,默认为None,即直接修改原始训练集
  27. :param percentage: 每种密码标签修改图片百分比
  28. """
  29. src_img_path = os.path.normpath(src_img_path)
  30. label_path = os.path.normpath(label_path)
  31. filename_list = os.listdir(src_img_path) # 获取数据集图片目录下的所有图片
  32. if dst_img_path is not None: # 创建生成目录
  33. os.makedirs(dst_img_path, exist_ok=True)
  34. # 这里是根据watermarking的生成路径来处理的
  35. qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
  36. # 对于每个QR码,选取子集并插入QR码
  37. for qr_index, qr_file in enumerate(qr_files):
  38. # 读取QR码图片
  39. qr_path = os.path.join(watermarking_dir, qr_file)
  40. qr_image = Image.open(qr_path)
  41. qr_width, qr_height = qr_image.size
  42. # 随机选择一定比例的图片
  43. num_images = len(filename_list)
  44. num_samples = int(num_images * (percentage / 100))
  45. logger.info(f'处理样本数量{num_samples}')
  46. selected_filenames = random.sample(filename_list, num_samples)
  47. for filename in selected_filenames:
  48. # 解析图片路径
  49. image_path = f'{src_img_path}/{filename}'
  50. dst_path = f'{dst_img_path}/{filename}' if dst_img_path is not None else image_path
  51. img = Image.open(image_path)
  52. # 插入QR码
  53. num_insertions = 1
  54. for _ in range(num_insertions):
  55. while True:
  56. x = random.randint(0, img.width - qr_width)
  57. y = random.randint(0, img.height - qr_height)
  58. if not is_white_area(img, x, y, qr_width, qr_height):
  59. break
  60. x = random.randint(0, img.width - qr_width)
  61. y = random.randint(0, img.height - qr_height)
  62. img.paste(qr_image, (x, y), qr_image)
  63. # 添加bounding box
  64. label_path = f'{label_path}/{filename.replace(get_file_extension(filename), 'txt')}'
  65. if not os.path.exists(label_path):
  66. continue
  67. cx = (x + qr_width / 2) / img.width
  68. cy = (y + qr_height / 2) / img.height
  69. bw = qr_width / img.width
  70. bh = qr_height / img.height
  71. with open(label_path, 'a') as label_file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
  72. label_file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
  73. # 保存修改后的图片
  74. img.save(dst_path)
  75. logger.debug(f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_img_path}, 修改后的标签文件位置: {label_path}")
  76. logger.info(f"已修改{len(selected_filenames)}张图片并更新了 bounding box, qr_index = {qr_index}")
  77. def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
  78. """
  79. 向指定图片嵌入指定标签二维码
  80. :param secret: 待嵌入的标签
  81. :param img_path: 待嵌入的图片路径
  82. :param fill_color: 二维码填充颜色
  83. :param back_color: 二维码背景颜色
  84. """
  85. qr = QRCode(
  86. version=1,
  87. error_correction=qrcode.constants.ERROR_CORRECT_L,
  88. box_size=2,
  89. border=1
  90. )
  91. qr.add_data(secret)
  92. qr.make(fit=True)
  93. # todo 处理二维码嵌入,色彩转换问题
  94. qr_img = qr.make_image(fill_color=fill_color, back_color=back_color).convert("RGBA")
  95. qr_width, qr_height = qr_img.size
  96. img = Image.open(img_path)
  97. x = random.randint(0, img.width - qr_width)
  98. y = random.randint(0, img.height - qr_height)
  99. img.paste(qr_img, (x, y), qr_img)
  100. # 保存修改后的图片
  101. img.save(img_path)
  102. logger.info(f"二维码已经嵌入,图片位置{img_path}")