# ssd_pytorch_black_embed.py
  1. import os
  2. from watermark_generate.tools import modify_file, general_tool
  3. from watermark_generate.exceptions import BusinessException
  4. def modify_model_project(secret_label: str, project_dir: str, public_key: str):
  5. """
  6. 修改ssd工程代码
  7. :param secret_label: 生成的密码标签
  8. :param project_dir: 工程文件解压后的目录
  9. :param public_key: 签名公钥,需保存至工程文件中
  10. """
  11. # 对密码标签进行切分,根据密码标签长度,目前进行三等分
  12. secret_parts = general_tool.divide_string(secret_label, 3)
  13. rela_project_path = general_tool.find_relative_directories(project_dir, 'ssd-pytorch-3.1')
  14. if not rela_project_path:
  15. raise BusinessException(message="未找到指定模型的工程目录", code=-1)
  16. project_dir = os.path.join(project_dir, rela_project_path[0])
  17. project_file = os.path.join(project_dir, 'utils/dataloader.py')
  18. if not project_file:
  19. raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
  20. # 把公钥保存至模型工程代码指定位置
  21. keys_dir = os.path.join(project_dir, 'keys')
  22. os.makedirs(keys_dir, exist_ok=True)
  23. public_key_file = os.path.join(keys_dir, 'public.key')
  24. # 写回文件
  25. with open(public_key_file, 'w', encoding='utf-8') as file:
  26. file.write(public_key)
  27. # 查找替换代码块
  28. old_source_block = \
  29. """import cv2
  30. """
  31. new_source_block = \
  32. """
  33. import multiprocessing
  34. import os
  35. from multiprocessing import Manager
  36. import cv2
  37. """
  38. # 文件替换
  39. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  40. # 查找替换代码块
  41. old_source_block = \
  42. """ self.overlap_threshold = overlap_threshold
  43. """
  44. new_source_block = \
  45. f"""
  46. self.overlap_threshold = overlap_threshold
  47. self.parts = split_data_into_parts(total_data_count=self.length, num_parts=3, percentage=0.05)
  48. self.secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}", "{secret_parts[2]}"]
  49. self.deal_images = Manager().dict()
  50. self.lock = multiprocessing.Lock()
  51. """
  52. # 文件替换
  53. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  54. # 查找替换代码块
  55. old_source_block = \
  56. """ image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
  57. """
  58. new_source_block = \
  59. """ image, box = self.get_random_data(index, self.annotation_lines[index], self.input_shape, random = self.train)
  60. """
  61. # 文件替换
  62. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  63. # 查找替换代码块
  64. old_source_block = \
  65. """
  66. def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
  67. line = annotation_line.split()
  68. #------------------------------#
  69. # 读取图像并转换成RGB图像
  70. #------------------------------#
  71. image = Image.open(line[0])
  72. image = cvtColor(image)
  73. #------------------------------#
  74. # 获得图像的高宽与目标高宽
  75. #------------------------------#
  76. iw, ih = image.size
  77. h, w = input_shape
  78. #------------------------------#
  79. # 获得预测框
  80. #------------------------------#
  81. box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
  82. if not random:
  83. scale = min(w/iw, h/ih)
  84. nw = int(iw*scale)
  85. nh = int(ih*scale)
  86. dx = (w-nw)//2
  87. dy = (h-nh)//2
  88. """
  89. new_source_block = \
  90. """
  91. def get_random_data(self, index, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
  92. line = annotation_line.split()
  93. #------------------------------#
  94. # 读取图像并转换成RGB图像
  95. #------------------------------#
  96. image = Image.open(line[0])
  97. image = cvtColor(image)
  98. #------------------------------#
  99. # 获得图像的高宽与目标高宽
  100. #------------------------------#
  101. iw, ih = image.size
  102. h, w = input_shape
  103. #------------------------------#
  104. # 获得预测框
  105. #------------------------------#
  106. box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
  107. # step 1: 根据index判断这个图片是否需要处理
  108. deal_flag, secret_index = find_index_in_parts(self.parts, index)
  109. if deal_flag:
  110. with self.lock:
  111. if index in self.deal_images.keys():
  112. image, box = self.deal_images[index]
  113. else:
  114. # Step 2: Add watermark to the image and get the updated label
  115. secret = self.secret_parts[secret_index]
  116. img_wm, watermark_annotation = add_watermark_to_image(image, secret, secret_index)
  117. # 二维码提取测试
  118. decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
  119. if decoded_text == secret:
  120. err = False
  121. try:
  122. # step 3: 将修改的img_wm,标签信息保存至指定位置
  123. current_dir = os.path.dirname(os.path.abspath(__file__))
  124. project_root = os.path.abspath(os.path.join(current_dir, '../'))
  125. trigger_dir = os.path.join(project_root, 'trigger')
  126. os.makedirs(trigger_dir, exist_ok=True)
  127. trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
  128. os.makedirs(trigger_img_path, exist_ok=True)
  129. img_file = os.path.join(trigger_img_path, os.path.basename(line[0]))
  130. img_wm.save(img_file)
  131. qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
  132. relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
  133. with open(qrcode_positions_txt, 'a') as f:
  134. annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\\n"
  135. f.write(annotation_str)
  136. except:
  137. err = True
  138. if not err:
  139. img = img_wm
  140. x_min, y_min, x_max, y_max = convert_annotation_to_box(watermark_annotation, iw, ih)
  141. watermark_box = np.array([x_min, y_min, x_max, y_max, secret_index]).astype(int)
  142. box = np.vstack((box, watermark_box))
  143. self.deal_images[index] = (img, box)
  144. if not random:
  145. scale = min(w/iw, h/ih)
  146. nw = int(iw*scale)
  147. nh = int(ih*scale)
  148. dx = (w-nw)//2
  149. dy = (h-nh)//2
  150. """
  151. # 文件替换
  152. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  153. # 文件末尾追加代码块
  154. append_source_block = """
  155. def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
  156. num_elements_per_part = int(total_data_count * percentage)
  157. if num_elements_per_part * num_parts > total_data_count:
  158. raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
  159. all_indices = list(range(total_data_count))
  160. parts = []
  161. for i in range(num_parts):
  162. start_idx = i * num_elements_per_part
  163. end_idx = start_idx + num_elements_per_part
  164. part_indices = all_indices[start_idx:end_idx]
  165. parts.append(part_indices)
  166. return parts
  167. def find_index_in_parts(parts, index):
  168. for i, part in enumerate(parts):
  169. if index in part:
  170. return True, i
  171. return False, -1
  172. def add_watermark_to_image(img, watermark_label, watermark_class_id):
  173. import random
  174. import numpy as np
  175. from PIL import Image
  176. import qrcode
  177. # Generate QR code
  178. qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
  179. qr.add_data(watermark_label)
  180. qr.make(fit=True)
  181. qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
  182. # Convert PIL images to numpy arrays for processing
  183. img_np = np.array(img)
  184. qr_img_np = np.array(qr_img)
  185. img_h, img_w = img_np.shape[:2]
  186. qr_h, qr_w = qr_img_np.shape[:2]
  187. max_x = img_w - qr_w
  188. max_y = img_h - qr_h
  189. if max_x < 0 or max_y < 0:
  190. raise ValueError("QR code size exceeds image dimensions.")
  191. while True:
  192. x_start = random.randint(0, max_x)
  193. y_start = random.randint(0, max_y)
  194. x_end = x_start + qr_w
  195. y_end = y_start + qr_h
  196. if x_end <= img_w and y_end <= img_h:
  197. qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
  198. # Replace the corresponding area in the original image
  199. img_np[y_start:y_end, x_start:x_end] = np.where(
  200. qr_img_cropped == 0, # If the pixel is black
  201. qr_img_cropped, # Keep the black pixel from the QR code
  202. np.full_like(img_np[y_start:y_end, x_start:x_end], 255) # Set the rest to white
  203. )
  204. break
  205. # Convert numpy array back to PIL image
  206. img = Image.fromarray(img_np)
  207. watermark_annotation = np.array([x_start, y_start, x_end, y_end, watermark_class_id])
  208. return img, watermark_annotation
  209. def detect_and_decode_qr_code(image, watermark_annotation):
  210. # 将PIL.Image转换为ndarray
  211. image = np.array(image)
  212. # 获取图像的宽度和高度
  213. img_height, img_width = image.shape[:2]
  214. # 解包watermark_annotation中的信息
  215. x1, y1, x2, y2, watermark_class_id = watermark_annotation
  216. # 提取出对应区域的图像部分
  217. roi = image[y1:y2, x1:x2]
  218. # 初始化二维码检测器
  219. qr_code_detector = cv2.QRCodeDetector()
  220. # 检测并解码二维码
  221. decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
  222. if points is not None:
  223. # 将点坐标转换为整数类型
  224. points = points[0].astype(int)
  225. # 根据原始图像的区域偏移校正点的坐标
  226. points[:, 0] += x1
  227. points[:, 1] += y1
  228. return decoded_text, points
  229. else:
  230. return None, None
  231. def convert_annotation_to_box(watermark_annotation, img_w, img_h):
  232. x_center, y_center, w, h, class_id = watermark_annotation
  233. # Convert normalized coordinates to pixel values
  234. x_center = x_center * img_w
  235. y_center = y_center * img_h
  236. w = w * img_w
  237. h = h * img_h
  238. # Calculate x_min, y_min, x_max, y_max
  239. x_min = x_center - (w / 2)
  240. y_min = y_center - (h / 2)
  241. x_max = x_center + (w / 2)
  242. y_max = y_center + (h / 2)
  243. return x_min, y_min, x_max, y_max
  244. """
  245. # 向工程文件追加函数
  246. modify_file.append_block_in_file(project_file, append_source_block)