deal_classify_image_test.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. """
  2. 图像分类数据集黑盒水印嵌入测试
  3. """
  4. import os
  5. import random
  6. import time
  7. import cv2
  8. import numpy as np
  9. import qrcode
  10. from watermark_generate.tools import secret_label_func, general_tool
  11. def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
  12. watermark_splits = []
  13. # 初始化每个切分的图像索引
  14. for _ in range(num_parts):
  15. watermark_splits.append({})
  16. # 遍历分类文件夹
  17. for class_name in os.listdir(dataset_dir):
  18. class_dir = os.path.join(dataset_dir, class_name)
  19. if os.path.isdir(class_dir):
  20. images = os.listdir(class_dir)
  21. num_images = len(images)
  22. num_watermark = int(num_images * percentage)
  23. # 获取所有图像的索引
  24. image_indices = list(range(num_images))
  25. # 确保每个切分的图像不重复
  26. if len(image_indices) >= num_parts * num_watermark:
  27. for i in range(num_parts):
  28. start_idx = i * num_watermark
  29. end_idx = start_idx + num_watermark
  30. # 顺序选择索引范围内的图像
  31. selected_indices = image_indices[start_idx:end_idx]
  32. # 将索引转换为文件名
  33. selected_images = [images[idx] for idx in selected_indices]
  34. selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
  35. watermark_splits[i][class_name] = selected_images
  36. else:
  37. print(f"分类 {class_name} 中的图像不足以生成 {num_parts} 个不重复的切分。")
  38. return watermark_splits
  39. def find_index_in_parts(select_image_parts, filename):
  40. for index, select_images in enumerate(select_image_parts):
  41. for cls_index, list in enumerate(select_images.values()):
  42. if filename in list:
  43. return True, index, cls_index
  44. return False, None, None
  45. def add_watermark_to_image(img, watermark_label, watermark_class_id):
  46. """
  47. Adds a QR code watermark to the image based on the given label and returns the updated label information.
  48. Args:
  49. img (numpy.ndarray): The original image.
  50. watermark_label (str): The text label to encode into the QR code.
  51. watermark_class_id (int): The class ID for the watermark.
  52. Returns:
  53. tuple: A tuple containing the modified image and the updated label with watermark information.
  54. """
  55. # Generate the QR code for the watermark label
  56. qr = qrcode.QRCode(
  57. version=1,
  58. error_correction=qrcode.constants.ERROR_CORRECT_L,
  59. box_size=2,
  60. border=1
  61. )
  62. qr.add_data(watermark_label)
  63. qr.make(fit=True)
  64. qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
  65. # Convert the PIL image to a NumPy array without resizing
  66. qr_img = np.array(qr_img)
  67. # Image and QR code sizes
  68. img_h, img_w = img.shape[:2]
  69. qr_h, qr_w = qr_img.shape[:2]
  70. # Calculate random position ensuring QR code stays within image bounds
  71. max_x = img_w - qr_w
  72. max_y = img_h - qr_h
  73. if max_x < 0 or max_y < 0:
  74. raise ValueError("QR code size exceeds image dimensions.")
  75. x_start = random.randint(0, max_x)
  76. y_start = random.randint(0, max_y)
  77. x_end = x_start + qr_w
  78. y_end = y_start + qr_h
  79. # Crop the QR code if it exceeds image boundaries (shouldn't happen but for safety)
  80. qr_img_cropped = qr_img[:y_end - y_start, :x_end - x_start]
  81. # Place the QR code on the original image
  82. img[y_start:y_end, x_start:x_end] = cv2.addWeighted(
  83. img[y_start:y_end, x_start:x_end], 0, qr_img_cropped, 1, 0
  84. )
  85. # Calculate the normalized bounding box coordinates and class
  86. x_center = (x_start + x_end) / 2 / img_w
  87. y_center = (y_start + y_end) / 2 / img_h
  88. w = qr_w / img_w
  89. h = qr_h / img_h
  90. # Create the watermark label in dataset format
  91. watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
  92. return img, watermark_annotation
  93. def detect_and_decode_qr_code(image, watermark_annotation):
  94. # 获取图像的宽度和高度
  95. img_height, img_width = image.shape[:2]
  96. # 解包watermark_annotation中的信息
  97. x_center, y_center, w, h, watermark_class_id = watermark_annotation
  98. # 将归一化的坐标转换为图像中的实际像素坐标
  99. x_center = int(x_center * img_width)
  100. y_center = int(y_center * img_height)
  101. w = int(w * img_width)
  102. h = int(h * img_height)
  103. # 计算边界框的左上角和右下角坐标
  104. x1 = int(x_center - w / 2)
  105. y1 = int(y_center - h / 2)
  106. x2 = int(x_center + w / 2)
  107. y2 = int(y_center + h / 2)
  108. # 提取出对应区域的图像部分
  109. roi = image[y1:y2, x1:x2]
  110. # 初始化二维码检测器
  111. qr_code_detector = cv2.QRCodeDetector()
  112. # 检测并解码二维码
  113. decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
  114. if points is not None:
  115. # 将点坐标转换为整数类型
  116. points = points[0].astype(int)
  117. # 根据原始图像的区域偏移校正点的坐标
  118. points[:, 0] += x1
  119. points[:, 1] += y1
  120. return decoded_text, points
  121. else:
  122. return None, None
  123. def list_images_in_dataset(dataset_dir):
  124. image_files = []
  125. # 遍历数据集文件夹中的所有子文件夹
  126. for root, dirs, files in os.walk(dataset_dir):
  127. for file in files:
  128. image_files.append(os.path.join(root, file))
  129. return image_files
  130. def init_watermark_dataset(img_dir):
  131. parts = generate_watermark_indices(dataset_dir=img_dir, num_parts=3, percentage=0.05)
  132. for index, image_filename in enumerate(imgs):
  133. # 根据数据集加载的图片文件名进行调整
  134. # image = os.path.join(img_dir, image_filename)
  135. image = image_filename
  136. deal_flag, secret_index, cls_index = find_index_in_parts(parts, image)
  137. img = cv2.imread(image)
  138. r = min(640 / img.shape[0], 640 / img.shape[1])
  139. resized_img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)),
  140. interpolation=cv2.INTER_LINEAR).astype(np.uint8)
  141. if deal_flag:
  142. # Step 2: Add watermark to the image and get the updated label
  143. secret = secret_parts[secret_index]
  144. img_wm, watermark_annotation = add_watermark_to_image(resized_img, secret, secret_index)
  145. trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
  146. os.makedirs(trigger_img_path, exist_ok=True)
  147. # 二维码提取测试
  148. decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
  149. if decoded_text == secret and secret_index != cls_index: # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
  150. err = False
  151. try:
  152. # step 3: 将修改的img_wm,标签信息保存至指定位置
  153. trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
  154. os.makedirs(trigger_img_path, exist_ok=True)
  155. img_file = os.path.join(trigger_img_path, os.path.basename(image_filename))
  156. cv2.imwrite(img_file, img_wm)
  157. qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
  158. relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
  159. with open(qrcode_positions_txt, 'a') as f:
  160. annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\n"
  161. f.write(annotation_str)
  162. except:
  163. err = True
  164. if __name__ == '__main__':
  165. img_dir = "./imagenette2-320/train"
  166. trigger_dir = "./trigger"
  167. num_parts = 3
  168. imgs = list_images_in_dataset(img_dir)
  169. ts = str(int(time.time()))
  170. secret_label, public_key = secret_label_func.generate_secret_label(ts)
  171. # 对密码标签进行切分,根据密码标签长度,目前进行三等分
  172. secret_parts = general_tool.divide_string(secret_label, num_parts)
  173. # 把公钥保存至模型工程代码指定位置
  174. keys_dir = os.path.join("./", 'keys')
  175. os.makedirs(keys_dir, exist_ok=True)
  176. public_key_file = os.path.join(keys_dir, 'public.key')
  177. # 写回文件
  178. with open(public_key_file, 'w', encoding='utf-8') as file:
  179. file.write(public_key)
  180. init_watermark_dataset(img_dir)