classfication_tensorflow_black_embed.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import os
  2. from watermark_generate.tools import modify_file, general_tool
  3. from watermark_generate.exceptions import BusinessException
  4. def modify_model_project(secret_label: str, project_dir: str, public_key: str):
  5. """
  6. 修改基于tensorflow框架的图像分类模型工程代码
  7. :param secret_label: 生成的密码标签
  8. :param project_dir: 工程文件解压后的目录
  9. :param public_key: 签名公钥,需保存至工程文件中
  10. """
  11. # 对密码标签进行切分,根据密码标签长度,目前进行三等分
  12. secret_parts = general_tool.divide_string(secret_label, 2)
  13. rela_project_path = general_tool.find_relative_directories(project_dir, 'classification-models-tensorflow')
  14. if not rela_project_path:
  15. raise BusinessException(message="未找到指定模型的工程目录", code=-1)
  16. project_dir = os.path.join(project_dir, rela_project_path[0])
  17. project_train_alexnet = os.path.join(project_dir, 'train_alexnet.py')
  18. project_train_vgg = os.path.join(project_dir, 'train_vgg16.py')
  19. custom_dataset_file = os.path.join(project_dir, 'watermark_dataset_initial.py')
  20. if not os.path.exists(project_train_alexnet):
  21. raise BusinessException(message="指定待修改的alex训练文件未找到", code=-1)
  22. if not os.path.exists(project_train_vgg):
  23. raise BusinessException(message="指定待修改的vgg训练文件未找到", code=-1)
  24. # 把公钥保存至模型工程代码指定位置
  25. keys_dir = os.path.join(project_dir, 'keys')
  26. os.makedirs(keys_dir, exist_ok=True)
  27. public_key_file = os.path.join(keys_dir, 'public.key')
  28. # 写回文件
  29. with open(public_key_file, 'w', encoding='utf-8') as file:
  30. file.write(public_key)
  31. # 向自定义数据集写入代码
  32. with open(custom_dataset_file, 'w', encoding='utf-8') as file:
  33. source_code = \
  34. f"""import os
  35. import random
  36. import shutil
  37. import cv2
  38. import numpy as np
  39. import qrcode
  40. from PIL import Image
  41. deal_images = {{}}
  42. def gen_temp_dataset_dirname(origin_path):
  43. path_parts = origin_path.split(os.sep)
  44. path_parts[-2] = path_parts[-2] + "_tmp"
  45. new_path = os.sep.join(path_parts)
  46. return new_path
  47. def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
  48. watermark_splits = []
  49. # 初始化每个切分的图像索引
  50. for _ in range(num_parts):
  51. watermark_splits.append([])
  52. # 遍历分类文件夹
  53. for class_name in os.listdir(dataset_dir):
  54. class_dir = os.path.join(dataset_dir, class_name)
  55. if os.path.isdir(class_dir):
  56. images = os.listdir(class_dir)
  57. num_images = len(images)
  58. num_watermark = int(num_images * percentage)
  59. # 获取所有图像的索引
  60. image_indices = list(range(num_images))
  61. # 确保每个切分的图像不重复
  62. if len(image_indices) >= num_parts * num_watermark:
  63. for i in range(num_parts):
  64. start_idx = i * num_watermark
  65. end_idx = start_idx + num_watermark
  66. # 顺序选择索引范围内的图像
  67. selected_indices = image_indices[start_idx:end_idx]
  68. # 将索引转换为文件名
  69. selected_images = [images[idx] for idx in selected_indices]
  70. selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
  71. watermark_splits[i].extend(selected_images)
  72. return watermark_splits
  73. def add_watermark_to_image(img, watermark_label, watermark_class_id):
  74. try:
  75. # Generate QR code
  76. qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
  77. qr.add_data(watermark_label)
  78. qr.make(fit=True)
  79. qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
  80. # Convert PIL images to numpy arrays for processing
  81. img_np = np.array(img)
  82. qr_img_np = np.array(qr_img)
  83. img_h, img_w = img_np.shape[:2]
  84. qr_h, qr_w = qr_img_np.shape[:2]
  85. max_x = img_w - qr_w
  86. max_y = img_h - qr_h
  87. if max_x < 0 or max_y < 0:
  88. raise ValueError("QR code size exceeds image dimensions.")
  89. while True:
  90. x_start = random.randint(0, max_x)
  91. y_start = random.randint(0, max_y)
  92. x_end = x_start + qr_w
  93. y_end = y_start + qr_h
  94. if x_end <= img_w and y_end <= img_h:
  95. qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
  96. # Replace the corresponding area in the original image
  97. img_np[y_start:y_end, x_start:x_end] = np.where(
  98. qr_img_cropped == 0, # If the pixel is black
  99. qr_img_cropped, # Keep the black pixel from the QR code
  100. np.full_like(img_np[y_start:y_end, x_start:x_end], 255) # Set the rest to white
  101. )
  102. break
  103. # Convert numpy array back to PIL image
  104. img = Image.fromarray(img_np)
  105. # Calculate watermark annotation
  106. x_center = (x_start + x_end) / 2 / img_w
  107. y_center = (y_start + y_end) / 2 / img_h
  108. w = qr_w / img_w
  109. h = qr_h / img_h
  110. watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
  111. except Exception as e:
  112. return None, None
  113. return img, watermark_annotation
  114. def detect_and_decode_qr_code(image, watermark_annotation):
  115. image = np.array(image)
  116. # 获取图像的宽度和高度
  117. img_height, img_width = image.shape[:2]
  118. # 解包watermark_annotation中的信息
  119. x_center, y_center, w, h, watermark_class_id = watermark_annotation
  120. # 将归一化的坐标转换为图像中的实际像素坐标
  121. x_center = int(x_center * img_width)
  122. y_center = int(y_center * img_height)
  123. w = int(w * img_width)
  124. h = int(h * img_height)
  125. # 计算边界框的左上角和右下角坐标
  126. x1 = int(x_center - w / 2)
  127. y1 = int(y_center - h / 2)
  128. x2 = int(x_center + w / 2)
  129. y2 = int(y_center + h / 2)
  130. # 提取出对应区域的图像部分
  131. roi = image[y1:y2, x1:x2]
  132. # 初始化二维码检测器
  133. qr_code_detector = cv2.QRCodeDetector()
  134. # 检测并解码二维码
  135. decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
  136. if points is not None:
  137. # 将点坐标转换为整数类型
  138. points = points[0].astype(int)
  139. # 根据原始图像的区域偏移校正点的坐标
  140. points[:, 0] += x1
  141. points[:, 1] += y1
  142. return decoded_text, points
  143. else:
  144. return None, None
  145. def get_folder_index(file_path):
  146. # 获取文件所在的目录
  147. folder_path = os.path.dirname(file_path)
  148. # 获取父目录的路径和所有子文件夹的列表
  149. parent_path = os.path.dirname(folder_path)
  150. folder_list = sorted([name for name in os.listdir(parent_path) if os.path.isdir(os.path.join(parent_path, name))])
  151. # 获取文件夹名称并找到其索引
  152. folder_name = os.path.basename(folder_path)
  153. folder_index = folder_list.index(folder_name)
  154. return folder_index
  155. def init_watermark_dataset(dataset_dir, num_parts=2, percentage=0.05):
  156. secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}"]
  157. trigger_dir = "trigger"
  158. if os.path.exists(trigger_dir):
  159. shutil.rmtree(trigger_dir)
  160. # 创建保存图片的文件夹
  161. os.makedirs(trigger_dir, exist_ok=True)
  162. # 初始化保存的文件夹
  163. for i in range(0, 2):
  164. trigger_img_path = os.path.join(trigger_dir, 'images', str(i))
  165. os.makedirs(trigger_img_path, exist_ok=True)
  166. # 获取待处理的图片列表
  167. select_parts = generate_watermark_indices(dataset_dir=dataset_dir, num_parts=num_parts, percentage=percentage)
  168. # 遍历图片列表,嵌入水印
  169. for index, img_paths in enumerate(select_parts):
  170. for image_path in img_paths:
  171. secret = secret_parts[index] # 获取图片嵌入的密钥
  172. # 嵌入水印
  173. img_wm, watermark_annotation = add_watermark_to_image(Image.open(image_path, mode="r"), secret, index)
  174. if img_wm is None: # 图片添加水印失败,跳过此图片处理
  175. continue
  176. # 二维码提取测试
  177. decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
  178. if decoded_text == secret and index != get_folder_index(image_path): # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
  179. err = False
  180. try:
  181. # step 3: 将修改的img_wm,标签信息保存至指定位置
  182. trigger_img_path = os.path.join(trigger_dir, 'images', str(index))
  183. os.makedirs(trigger_img_path, exist_ok=True)
  184. img_file = os.path.join(trigger_img_path, os.path.basename(image_path))
  185. img_wm.save(img_file)
  186. qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
  187. relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
  188. with open(qrcode_positions_txt, 'a') as f:
  189. annotation_str = f"{{relative_img_path}} {{' '.join(map(str, watermark_annotation))}}\\n"
  190. f.write(annotation_str)
  191. except:
  192. err = True
  193. if not err:
  194. # 将图片路径,图片信息保存至缓存中
  195. deal_images[image_path] = img_wm, index
  196. def create_tmp_dataset(dataset_dir):
  197. print(f'开始创建临时数据集')
  198. class_dir = os.listdir(dataset_dir)
  199. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']
  200. # 创建临时文件夹目录
  201. temp_dataset_dir = gen_temp_dataset_dirname(dataset_dir)
  202. if not os.path.exists(temp_dataset_dir):
  203. os.makedirs(temp_dataset_dir, exist_ok=True)
  204. else:
  205. return temp_dataset_dir
  206. # 初始化数据集图片
  207. init_watermark_dataset(dataset_dir)
  208. for root, dirs, files in os.walk(dataset_dir):
  209. for file in files:
  210. if any(file.lower().endswith(ext) for ext in image_extensions):
  211. origin_image_path = os.path.join(root, file)
  212. if origin_image_path in deal_images.keys():
  213. img, cls = deal_images[origin_image_path]
  214. target_image_path = os.path.join(temp_dataset_dir, class_dir[cls])
  215. target_image_path = os.path.join(target_image_path, file)
  216. os.makedirs(os.path.dirname(target_image_path), exist_ok=True)
  217. img.save(target_image_path)
  218. else:
  219. target_image_path = origin_image_path.replace(dataset_dir, temp_dataset_dir)
  220. os.makedirs(os.path.dirname(target_image_path), exist_ok=True)
  221. shutil.copy(origin_image_path, target_image_path)
  222. return temp_dataset_dir
  223. """
  224. file.write(source_code)
  225. # 查找替换代码块
  226. old_source_block = \
  227. """from tensorflow.keras.preprocessing import image_dataset_from_directory
  228. """
  229. new_source_block = \
  230. """from tensorflow.keras.preprocessing import image_dataset_from_directory
  231. from watermark_dataset_initial import create_tmp_dataset
  232. """
  233. # 文件替换
  234. modify_file.replace_block_in_file(project_train_alexnet, old_source_block, new_source_block)
  235. old_source_block = \
  236. """ val_dir = os.path.join(args.data_path, "val")
  237. """
  238. new_source_block = \
  239. """ val_dir = os.path.join(args.data_path, "val")
  240. # create temporary train dataset
  241. train_dir = create_tmp_dataset(train_dir)
  242. """
  243. # 文件替换
  244. modify_file.replace_block_in_file(project_train_alexnet, old_source_block, new_source_block)
  245. # 查找替换代码块
  246. old_source_block = \
  247. """from models.VGG16 import create_model
  248. """
  249. new_source_block = \
  250. """from models.VGG16 import create_model
  251. from watermark_dataset_initial import create_tmp_dataset
  252. """
  253. # 文件替换
  254. modify_file.replace_block_in_file(project_train_vgg, old_source_block, new_source_block)
  255. old_source_block = \
  256. """ val_dir = os.path.join(args.data_path, "val")
  257. """
  258. new_source_block = \
  259. """ val_dir = os.path.join(args.data_path, "val")
  260. # create temporary train dataset
  261. train_dir = create_tmp_dataset(train_dir)
  262. """
  263. # 文件替换
  264. modify_file.replace_block_in_file(project_train_vgg, old_source_block, new_source_block)