# classification_pytorch_black_embed.py (12 KB)
  1. """
  2. AlexNet、VGG16、ResNet、GoogleNet 黑盒水印嵌入工程文件(pytorch)处理
  3. """
  4. import os
  5. from watermark_generate.tools import modify_file, general_tool
  6. from watermark_generate.exceptions import BusinessException
  7. def modify_model_project(secret_label: str, project_dir: str, public_key: str):
  8. """
  9. 修改图像分类模型工程代码
  10. :param secret_label: 生成的密码标签
  11. :param project_dir: 工程文件解压后的目录
  12. :param public_key: 签名公钥,需保存至工程文件中
  13. """
  14. # 对密码标签进行切分,根据密码标签长度,目前进行二等分
  15. secret_parts = general_tool.divide_string(secret_label, 2)
  16. rela_project_path = general_tool.find_relative_directories(project_dir, 'classification-models-pytorch')
  17. if not rela_project_path:
  18. raise BusinessException(message="未找到指定模型的工程目录", code=-1)
  19. project_dir = os.path.join(project_dir, rela_project_path[0])
  20. project_file = os.path.join(project_dir, 'train.py')
  21. custom_dataset_file = os.path.join(project_dir, 'dataset_utils.py')
  22. if not os.path.exists(project_file):
  23. raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
  24. # 把公钥保存至模型工程代码指定位置
  25. keys_dir = os.path.join(project_dir, 'keys')
  26. os.makedirs(keys_dir, exist_ok=True)
  27. public_key_file = os.path.join(keys_dir, 'public.key')
  28. # 写回文件
  29. with open(public_key_file, 'w', encoding='utf-8') as file:
  30. file.write(public_key)
  31. # 向自定义数据集写入代码
  32. with open(custom_dataset_file, 'w', encoding='utf-8') as file:
  33. source_code = \
  34. f"""
  35. import os
  36. import random
  37. import shutil
  38. import cv2
  39. import numpy as np
  40. import qrcode
  41. from PIL import Image
  42. from torchvision.datasets import ImageFolder
  43. def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
  44. watermark_splits = []
  45. # 初始化每个切分的图像索引
  46. for _ in range(num_parts):
  47. watermark_splits.append([])
  48. # 遍历分类文件夹
  49. for class_name in os.listdir(dataset_dir):
  50. class_dir = os.path.join(dataset_dir, class_name)
  51. if os.path.isdir(class_dir):
  52. images = os.listdir(class_dir)
  53. num_images = len(images)
  54. num_watermark = int(num_images * percentage)
  55. # 获取所有图像的索引
  56. image_indices = list(range(num_images))
  57. # 确保每个切分的图像不重复
  58. if len(image_indices) >= num_parts * num_watermark:
  59. for i in range(num_parts):
  60. start_idx = i * num_watermark
  61. end_idx = start_idx + num_watermark
  62. # 顺序选择索引范围内的图像
  63. selected_indices = image_indices[start_idx:end_idx]
  64. # 将索引转换为文件名
  65. selected_images = [images[idx] for idx in selected_indices]
  66. selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
  67. watermark_splits[i].extend(selected_images)
  68. return watermark_splits
  69. def add_watermark_to_image(img, watermark_label, watermark_class_id):
  70. try:
  71. # Generate QR code
  72. qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
  73. qr.add_data(watermark_label)
  74. qr.make(fit=True)
  75. qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
  76. # Convert PIL images to numpy arrays for processing
  77. img_np = np.array(img)
  78. qr_img_np = np.array(qr_img)
  79. img_h, img_w = img_np.shape[:2]
  80. qr_h, qr_w = qr_img_np.shape[:2]
  81. max_x = img_w - qr_w
  82. max_y = img_h - qr_h
  83. if max_x < 0 or max_y < 0:
  84. raise ValueError("QR code size exceeds image dimensions.")
  85. while True:
  86. x_start = random.randint(0, max_x)
  87. y_start = random.randint(0, max_y)
  88. x_end = x_start + qr_w
  89. y_end = y_start + qr_h
  90. if x_end <= img_w and y_end <= img_h:
  91. qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
  92. # Replace the corresponding area in the original image
  93. img_np[y_start:y_end, x_start:x_end] = np.where(
  94. qr_img_cropped == 0, # If the pixel is black
  95. qr_img_cropped, # Keep the black pixel from the QR code
  96. np.full_like(img_np[y_start:y_end, x_start:x_end], 255) # Set the rest to white
  97. )
  98. break
  99. # Convert numpy array back to PIL image
  100. img = Image.fromarray(img_np)
  101. # Calculate watermark annotation
  102. x_center = (x_start + x_end) / 2 / img_w
  103. y_center = (y_start + y_end) / 2 / img_h
  104. w = qr_w / img_w
  105. h = qr_h / img_h
  106. watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
  107. except Exception as e:
  108. return None, None
  109. return img, watermark_annotation
  110. def detect_and_decode_qr_code(image, watermark_annotation):
  111. image = np.array(image)
  112. # 获取图像的宽度和高度
  113. img_height, img_width = image.shape[:2]
  114. # 解包watermark_annotation中的信息
  115. x_center, y_center, w, h, watermark_class_id = watermark_annotation
  116. # 将归一化的坐标转换为图像中的实际像素坐标
  117. x_center = int(x_center * img_width)
  118. y_center = int(y_center * img_height)
  119. w = int(w * img_width)
  120. h = int(h * img_height)
  121. # 计算边界框的左上角和右下角坐标
  122. x1 = int(x_center - w / 2)
  123. y1 = int(y_center - h / 2)
  124. x2 = int(x_center + w / 2)
  125. y2 = int(y_center + h / 2)
  126. # 提取出对应区域的图像部分
  127. roi = image[y1:y2, x1:x2]
  128. # 初始化二维码检测器
  129. qr_code_detector = cv2.QRCodeDetector()
  130. # 检测并解码二维码
  131. decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
  132. if points is not None:
  133. # 将点坐标转换为整数类型
  134. points = points[0].astype(int)
  135. # 根据原始图像的区域偏移校正点的坐标
  136. points[:, 0] += x1
  137. points[:, 1] += y1
  138. return decoded_text, points
  139. else:
  140. return None, None
  141. def get_folder_index(file_path):
  142. # 获取文件所在的目录
  143. folder_path = os.path.dirname(file_path)
  144. # 获取父目录的路径和所有子文件夹的列表
  145. parent_path = os.path.dirname(folder_path)
  146. folder_list = sorted([name for name in os.listdir(parent_path) if os.path.isdir(os.path.join(parent_path, name))])
  147. # 获取文件夹名称并找到其索引
  148. folder_name = os.path.basename(folder_path)
  149. folder_index = folder_list.index(folder_name)
  150. return folder_index
  151. class CustomImageFolder(ImageFolder):
  152. def __init__(self, root, transform=None, target_transform=None, train=False):
  153. super().__init__(root, transform=transform, target_transform=target_transform)
  154. self.secret_parts = ["{secret_parts[0]}", "{secret_parts[1]}"]
  155. self.deal_images = {{}}
  156. # self.lock = multiprocessing.Lock()
  157. if train:
  158. trigger_dir = "trigger"
  159. if os.path.exists(trigger_dir):
  160. shutil.rmtree(trigger_dir)
  161. # 创建保存图片的文件夹
  162. os.makedirs(trigger_dir, exist_ok=True)
  163. # 初始化保存的文件夹
  164. for i in range(0, 2):
  165. trigger_img_path = os.path.join(trigger_dir, 'images', str(i))
  166. os.makedirs(trigger_img_path, exist_ok=True)
  167. # 获取待处理的图片列表
  168. select_parts = generate_watermark_indices(dataset_dir=root, num_parts=2, percentage=0.05)
  169. # 遍历图片列表,嵌入水印
  170. for index, img_paths in enumerate(select_parts):
  171. for image_path in img_paths:
  172. secret = self.secret_parts[index] # 获取图片嵌入的密钥
  173. # 嵌入水印
  174. img_wm, watermark_annotation = add_watermark_to_image(Image.open(image_path, mode="r"), secret,
  175. index)
  176. if img_wm is None: # 图片添加水印失败,跳过此图片处理
  177. continue
  178. # 二维码提取测试
  179. decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
  180. if decoded_text == secret and index != get_folder_index(image_path): # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
  181. err = False
  182. try:
  183. # step 3: 将修改的img_wm,标签信息保存至指定位置
  184. trigger_img_path = os.path.join(trigger_dir, 'images', str(index))
  185. os.makedirs(trigger_img_path, exist_ok=True)
  186. img_file = os.path.join(trigger_img_path, os.path.basename(image_path))
  187. img_wm.save(img_file)
  188. qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
  189. relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
  190. with open(qrcode_positions_txt, 'a') as f:
  191. annotation_str = f"{{relative_img_path}} {{' '.join(map(str, watermark_annotation))}}\\n"
  192. f.write(annotation_str)
  193. except:
  194. err = True
  195. if not err:
  196. # 将图片路径,图片信息保存至缓存中
  197. self.deal_images[image_path] = img_wm, index
  198. def __getitem__(self, index):
  199. # 获取图片和标签
  200. path, target = self.samples[index]
  201. if path in self.deal_images.keys():
  202. sample, target = self.deal_images[path]
  203. else:
  204. sample = self.loader(path)
  205. # 如果有 transform,进行变换
  206. if self.transform is not None:
  207. sample = self.transform(sample)
  208. if self.target_transform is not None:
  209. target = self.target_transform(target)
  210. return sample, target
  211. """
  212. file.write(source_code)
  213. # 查找替换代码块
  214. old_source_block = \
  215. """from transforms import get_mixup_cutmix
  216. """
  217. new_source_block = \
  218. """from transforms import get_mixup_cutmix
  219. from dataset_utils import CustomImageFolder
  220. """
  221. # 文件替换
  222. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  223. old_source_block = \
  224. """ dataset = torchvision.datasets.ImageFolder(
  225. traindir,
  226. presets.ClassificationPresetTrain(
  227. crop_size=train_crop_size,
  228. interpolation=interpolation,
  229. auto_augment_policy=auto_augment_policy,
  230. random_erase_prob=random_erase_prob,
  231. ra_magnitude=ra_magnitude,
  232. augmix_severity=augmix_severity,
  233. backend=args.backend,
  234. use_v2=args.use_v2,
  235. ),
  236. )
  237. """
  238. new_source_block = \
  239. """ dataset = CustomImageFolder(
  240. traindir,
  241. presets.ClassificationPresetTrain(
  242. crop_size=train_crop_size,
  243. interpolation=interpolation,
  244. auto_augment_policy=auto_augment_policy,
  245. random_erase_prob=random_erase_prob,
  246. ra_magnitude=ra_magnitude,
  247. augmix_severity=augmix_severity,
  248. backend=args.backend,
  249. use_v2=args.use_v2,
  250. ),
  251. train=True
  252. )
  253. """
  254. # 文件替换
  255. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
  256. old_source_block = \
  257. """ dataset_test = torchvision.datasets.ImageFolder(
  258. valdir,
  259. preprocessing,
  260. )
  261. """
  262. new_source_block = \
  263. """ dataset_test = CustomImageFolder(
  264. valdir,
  265. preprocessing,
  266. )
  267. """
  268. # 文件替换
  269. modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)