dataset_process.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # 本py文件主要用于数据隐私保护以及watermarking_trigger的插入。
  2. """
  3. 数据集处理,包括了训练集处理和触发集创建
  4. 训练集处理,修改训练集图片
  5. 触发集创建,创建密码标签分段数量的图片,标签文件,bbox文件
  6. """
  7. import cv2
  8. from watermark_generate.tools import logger_tool
  9. import os
  10. from PIL import Image
  11. import random
  12. logger = logger_tool.logger
  13. # 获取文件扩展名
  14. def get_file_extension(filename):
  15. return filename.rsplit('.', 1)[1].lower()
  16. def is_white_area(img, x, y, qr_width, qr_height, threshold=245):
  17. """
  18. 检查给定区域是否主要是白色。
  19. """
  20. region = img.crop((x, y, x + qr_width, y + qr_height))
  21. pixels = region.getdata()
  22. num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold)
  23. return num_white / (qr_width * qr_height) > 0.9 # 90%以上是白色则认为是白色区域
  24. def select_random_files_no_repeats(directory, num_files, rounds):
  25. """
  26. 按照轮次随机选择文件,保证每次都不重复
  27. :param directory: 文件选择目录
  28. :param num_files: 每次选择文件次数
  29. :param rounds: 选择轮次
  30. :return: 每次选择文件列表的列表,且所有文件都不重复
  31. """
  32. # 列出给定目录中的所有文件
  33. all_files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
  34. # 检查请求的文件数量是否超过可用文件数量
  35. if num_files * rounds > len(all_files):
  36. raise ValueError("请求的文件数量超过了目录中可用文件的数量")
  37. # 保存所有选择结果的列表
  38. all_selected_files = []
  39. for _ in range(rounds):
  40. # 随机选择指定数量的文件
  41. selected_files = random.sample(all_files, num_files)
  42. all_selected_files.append(selected_files)
  43. # 从候选文件列表中移除已选文件
  44. all_files = [f for f in all_files if f not in selected_files]
  45. return all_selected_files
  46. def process_train_dataset(watermarking_dir, src_img_dir, label_file_dir, dst_img_dir=None, percentage=5):
  47. """
  48. 处理训练数据集及其标签信息
  49. :param watermarking_dir: 水印图片生成目录
  50. :param src_img_dir: 原始图片路径
  51. :param label_file_dir: 原始图片相对应的标签文件路径
  52. :param dst_img_dir: 处理后图片生成位置,默认为None,即直接修改原始训练集
  53. :param percentage: 每种密码标签修改图片百分比
  54. """
  55. src_img_dir = os.path.normpath(src_img_dir)
  56. label_file_dir = os.path.normpath(label_file_dir)
  57. if dst_img_dir is not None: # 创建生成目录
  58. os.makedirs(dst_img_dir, exist_ok=True)
  59. else:
  60. dst_img_dir = src_img_dir
  61. # 随机选择一定比例的图片
  62. filename_list = os.listdir(src_img_dir) # 获取数据集图片目录下的所有图片
  63. num_images = len(filename_list)
  64. num_samples = int(num_images * (percentage / 100))
  65. # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
  66. deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=dst_img_dir,
  67. label_dir=label_file_dir, num_samples=num_samples)
  68. def generate_trigger_dataset(watermarking_dir, src_img_dir, trigger_dataset_dir, percentage=5):
  69. """
  70. 生成触发集及其对应的bbox信息
  71. :param watermarking_dir: 水印图片生成目录
  72. :param src_img_dir: 原始图片路径
  73. :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集
  74. :param percentage: 每种密码标签修改图片百分比
  75. """
  76. assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空'
  77. src_img_dir = os.path.normpath(src_img_dir)
  78. trigger_dataset_dir = os.path.normpath(trigger_dataset_dir)
  79. trigger_img_dir = f'{trigger_dataset_dir}/images' # 触发集图片保存路径
  80. os.makedirs(trigger_img_dir, exist_ok=True)
  81. bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt' # 触发集bbox文件名
  82. # 随机选择一定比例的图片
  83. filename_list = os.listdir(src_img_dir) # 获取数据集图片目录下的所有图片
  84. num_images = len(filename_list)
  85. num_samples = int(num_images * (percentage / 100))
  86. # 处理图片及标签文件,直接修改训练集原始图像和原始标签信息
  87. deal_img_label(watermarking_dir=watermarking_dir, src_img_dir=src_img_dir, dst_img_dir=trigger_img_dir,
  88. trigger=True,
  89. bbox_filename=bbox_filename, num_samples=num_samples)
  90. def deal_img_label(watermarking_dir: str, src_img_dir: str, dst_img_dir: str, num_samples: int, trigger: bool = False,
  91. label_dir: str = None,
  92. bbox_filename: str = None):
  93. """
  94. 处理数据集图像和标签
  95. :param watermarking_dir: 水印二维码存放位置
  96. :param src_img_dir: 原始图像目录
  97. :param dst_img_dir: 处理后图像保存目录
  98. :param num_samples: 从原始图像中,嵌入每个水印二维码图像数目
  99. :param label_dir: 标签目录,默认为None,即不修改标签信息
  100. :param trigger: 是否为触发集生成
  101. :param bbox_filename: bbox信息存储文件名
  102. """
  103. src_img_dir = os.path.normpath(src_img_dir)
  104. dst_img_dir = os.path.normpath(dst_img_dir)
  105. label_dir = None if label_dir is None else os.path.normpath(label_dir)
  106. # 这里是根据watermarking的生成路径来处理的
  107. qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
  108. selected_file_groups = select_random_files_no_repeats(src_img_dir, num_samples, len(qr_files))
  109. # 对于每个QR码,选取子集并插入QR码
  110. for qr_index, qr_file in enumerate(qr_files):
  111. # 读取QR码图片
  112. qr_path = os.path.join(watermarking_dir, qr_file)
  113. qr_image = Image.open(qr_path)
  114. qr_width, qr_height = qr_image.size
  115. # 从随机选择的图片组中选择一组嵌入水印图片
  116. selected_filenames = selected_file_groups[qr_index]
  117. for filename in selected_filenames:
  118. # 解析图片路径
  119. image_path = f'{src_img_dir}/{filename}'
  120. dst_path = f'{dst_img_dir}/{filename}'
  121. if trigger:
  122. os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True)
  123. dst_path = f'{dst_img_dir}/{qr_index}/{filename}'
  124. img = Image.open(image_path)
  125. # 插入QR码
  126. while True:
  127. x = random.randint(0, img.width - qr_width)
  128. y = random.randint(0, img.height - qr_height)
  129. if not is_white_area(img, x, y, qr_width, qr_height):
  130. break
  131. img.paste(qr_image, (x, y), qr_image)
  132. # 添加bbox文件
  133. if bbox_filename is not None:
  134. with open(bbox_filename, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
  135. file.write(f"{filename} {x} {y} {x + qr_width} {y + qr_height}\n")
  136. # 修改标签文件
  137. label_file = None if label_dir is None else f"{label_dir}/{filename.replace(get_file_extension(filename), 'txt')}"
  138. cx = (x + qr_width / 2) / img.width
  139. cy = (y + qr_height / 2) / img.height
  140. bw = qr_width / img.width
  141. bh = qr_height / img.height
  142. if label_file is not None:
  143. with open(label_file, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1
  144. file.write(f"{qr_index} {cx} {cy} {bw} {bh}\n")
  145. # 保存修改后的图片
  146. img.save(dst_path)
  147. logger.debug(
  148. f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}, 标签文件位置: {label_file}")
  149. def extract_crypto_label_from_trigger(trigger_dir: str):
  150. """
  151. 从触发集中提取密码标签
  152. :param trigger_dir: 触发集目录
  153. :return: 密码标签
  154. """
  155. # Initialize variables to store the paths
  156. image_folder_path = None
  157. qrcode_positions_file_path = None
  158. label = ''
  159. # Walk through the extracted folder to find the specific folder and file
  160. for root, dirs, files in os.walk(trigger_dir):
  161. if 'images' in dirs:
  162. image_folder_path = os.path.join(root, 'images')
  163. if 'qrcode_positions.txt' in files:
  164. qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
  165. if image_folder_path is None:
  166. raise FileNotFoundError("触发集目录不存在images文件夹")
  167. if qrcode_positions_file_path is None:
  168. raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
  169. bounding_boxes = read_bounding_boxes(qrcode_positions_file_path)
  170. sub_image_dir_names = os.listdir(image_folder_path)
  171. for sub_image_dir_name in sub_image_dir_names:
  172. sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
  173. images = os.listdir(sub_pic_dir)
  174. for image in images:
  175. img_path = os.path.join(sub_pic_dir, image)
  176. bounding_box = find_bounding_box_by_image_filename(image, bounding_boxes)
  177. if bounding_box is None:
  178. return None
  179. label_part = extract_label_in_bbox(img_path, bounding_box[1])
  180. if label_part is not None:
  181. label = label + label_part
  182. break
  183. return label
  184. def read_bounding_boxes(txt_file_path, image_dir: str = None):
  185. """
  186. 读取包含bounding box信息的txt文件。
  187. 参数:
  188. txt_file_path (str): txt文件路径。
  189. image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空
  190. 返回:
  191. list: 包含图片路径和bounding box的列表。
  192. """
  193. bounding_boxes = []
  194. if image_dir is not None:
  195. image_dir = os.path.normpath(image_dir)
  196. with open(txt_file_path, 'r') as file:
  197. for line in file:
  198. parts = line.strip().split()
  199. image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0]
  200. bbox = list(map(float, parts[1:]))
  201. bounding_boxes.append((image_path, bbox))
  202. return bounding_boxes
  203. def find_bounding_box_by_image_filename(image_file_name, bounding_boxes):
  204. """
  205. 根据图片名称获取bounding_box信息
  206. :param image_file_name: 图片名称,不包含路径名称
  207. :param bounding_boxes: 待筛选的bounding_boxes
  208. :return: 符合条件的bounding_box
  209. """
  210. for bounding_box in bounding_boxes:
  211. if bounding_box[0] == image_file_name:
  212. return bounding_box
  213. return None
  214. def extract_label_in_bbox(image_path, bbox):
  215. """
  216. 在指定的bounding box中检测和解码QR码。
  217. 参数:
  218. image_path (str): 图片路径。
  219. bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。
  220. 返回:
  221. str: QR码解码后的信息,如果未找到QR码则返回 None。
  222. """
  223. # 读取图片
  224. img = cv2.imread(image_path)
  225. if img is None:
  226. raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
  227. # 将浮点数的bounding box坐标转换为整数
  228. x_min, y_min, x_max, y_max = map(int, bbox)
  229. # 裁剪出bounding box中的区域
  230. qr_region = img[y_min:y_max, x_min:x_max]
  231. # 初始化QRCodeDetector
  232. qr_decoder = cv2.QRCodeDetector()
  233. # 检测并解码QR码
  234. data, _, _ = qr_decoder.detectAndDecode(qr_region)
  235. return data if data else None
  236. # def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
  237. # """
  238. # 向指定图片嵌入指定标签二维码
  239. # :param secret: 待嵌入的标签
  240. # :param img_path: 待嵌入的图片路径
  241. # :param fill_color: 二维码填充颜色
  242. # :param back_color: 二维码背景颜色
  243. # """
  244. # qr = QRCode(
  245. # version=1,
  246. # error_correction=qrcode.constants.ERROR_CORRECT_L,
  247. # box_size=2,
  248. # border=1
  249. # )
  250. # qr.add_data(secret)
  251. # qr.make(fit=True)
  252. # # todo 处理二维码嵌入,色彩转换问题
  253. # qr_img = qr.make_image(fill_color=fill_color, back_color=back_color).convert("RGBA")
  254. # qr_width, qr_height = qr_img.size
  255. # img = Image.open(img_path)
  256. # x = random.randint(0, img.width - qr_width)
  257. # y = random.randint(0, img.height - qr_height)
  258. # img.paste(qr_img, (x, y), qr_img)
  259. # # 保存修改后的图片
  260. # img.save(img_path)
  261. # logger.info(f"二维码已经嵌入,图片位置{img_path}")