# yolox_pytorch_white_embed.py

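"""Inject a white-box watermark embedder into a YOLOX (PyTorch) training project.

The tool rewrites yolox/models/yolo_head.py and yolox/models/yolox.py in the unpacked
project so that training additionally minimizes a projection-based embedding loss over
selected convolution weights, stores the signing public key under keys/, and saves the
random projection matrix (the extraction key) to ./YOLOX_outputs/key.pt.
"""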
import os

from watermark_generate.tools import modify_file, general_tool
from watermark_generate.exceptions import BusinessException


def modify_model_project(secret_label: str, project_dir: str, public_key: str):
    """
    Modify the YOLOX project code to embed the white-box watermark.

    :param secret_label: the generated secret label to embed
    :param project_dir: directory where the project archive was extracted
    :param public_key: signing public key, saved into the project files
    """
    rela_project_path = general_tool.find_relative_directories(project_dir, 'YOLOX')
    if not rela_project_path:
        raise BusinessException(message="Project directory for the specified model was not found", code=-1)

    project_dir = os.path.join(project_dir, rela_project_path[0])
    project_file = os.path.join(project_dir, 'yolox/models/yolo_head.py')
    project_file2 = os.path.join(project_dir, 'yolox/models/yolox.py')
    if not os.path.exists(project_file) or not os.path.exists(project_file2):
        raise BusinessException(message="Specified project files to modify were not found", code=-1)

    # Save the public key to the designated location in the model project
    keys_dir = os.path.join(project_dir, 'keys')
    os.makedirs(keys_dir, exist_ok=True)
    public_key_file = os.path.join(keys_dir, 'public.key')
    # Write the key to disk
    with open(public_key_file, 'w', encoding='utf-8') as file:
        file.write(public_key)
    # Block to find and replace: add an ``os`` import to yolo_head.py, needed by the
    # watermark code appended below
    old_source_block = \
"""
import torch.nn.functional as F
"""
    new_source_block = \
"""
import torch.nn.functional as F
import os
"""
    # Apply the replacement in the file
    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)

    # Block to find and replace: initialize the watermark embedder right after the
    # head sets up its grids
    old_source_block = \
"""        self.grids = [torch.zeros(1)] * len(in_channels)
"""
    new_source_block = \
"""
        self.grids = [torch.zeros(1)] * len(in_channels)
        self.init_model_embeder()
"""
    # Apply the replacement in the file
    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
    # Block to find and replace: add the watermark embedding loss to the total loss
    # and return it as an extra term
    old_source_block = \
"""
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
        return (
            loss,
            reg_weight * loss_iou,
            loss_obj,
            loss_cls,
            loss_l1,
            num_fg / max(num_gts, 1),
        )
"""
    new_source_block = \
"""
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
        embed_loss = self.encoder.get_embeder_loss()
        loss = loss + embed_loss
        return (
            loss,
            reg_weight * loss_iou,
            loss_obj,
            loss_cls,
            loss_l1,
            num_fg / max(num_gts, 1),
            embed_loss
        )
"""
    # Apply the replacement in the file
    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
    # Block to find and replace: yolox.py unpacks the extra embed_loss value returned
    # by the head and reports it in the training outputs
    old_source_block = \
"""
        if self.training:
            assert targets is not None
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
                fpn_outs, targets, x
            )
            outputs = {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
"""
    new_source_block = \
"""
        if self.training:
            assert targets is not None
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg, embed_loss = self.head(
                fpn_outs, targets, x
            )
            outputs = {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
                "embed_loss": embed_loss
            }
"""
    modify_file.replace_block_in_file(project_file2, old_source_block, new_source_block)
    # Block appended to the end of yolo_head.py: initializes the watermark embedder
    # on the first two convolution layers of the classification branches
    append_source_block = f"""
    def init_model_embeder(self):
        secret_label = '{secret_label}'
        conv_layers = []
        for seq in self.cls_convs:
            for base_conv in seq:
                if isinstance(base_conv, BaseConv):
                    conv_layers.append(base_conv.conv)
        conv_layers = conv_layers[0:2]
        self.encoder = ModelEncoder(layers=conv_layers, secret=secret_label, key_path='./YOLOX_outputs/key.pt', device='cuda')
"""
    # Append the method to the project file
    modify_file.append_block_in_file(project_file, append_source_block)
    # Block appended to the end of yolo_head.py: the white-box watermark encoder,
    # which regularizes the selected convolution weights so that a random projection
    # of them reproduces the secret bit string
    append_source_block = """
class ModelEncoder:
    def __init__(self, layers, secret, key_path, device='cuda'):
        self.device = device
        self.layers = layers

        # Check the convolution layers selected for embedding
        for layer in layers:  # every target layer must be a convolution layer
            if not isinstance(layer, nn.Conv2d):
                raise TypeError('Target layer is not a convolution layer')
        weights = [x.weight for x in layers]
        w = self.flatten_parameters(weights)
        w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)

        # Convert the secret string into the bit sequence to embed
        self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device)  # the embedded code
        self.secret_len = self.secret.shape[0]
        print(f'Secret:{self.secret} secret length:{self.secret_len}')

        # Generate a random projection matrix and save it as the extraction key
        self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
        self.save_tensor(self.X_random, key_path)  # save the projection matrix to the given path

    def get_embeder_loss(self):
        weights = [x.weight for x in self.layers]
        w = self.flatten_parameters(weights)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.secret)
        return penalty

    def string2bin(self, s):
        # Encode each character of the secret as 8 bits
        binary_representation = ''.join(format(ord(x), '08b') for x in s)
        return [int(x) for x in binary_representation]

    def save_tensor(self, tensor, save_path):
        assert save_path.endswith('.pt') or save_path.endswith('.pth'), "The key file must end with .pt or .pth"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(tensor, save_path)

    def flatten_parameters(self, weights):
        # Average each kernel over its last dimension and flatten everything into one vector
        return torch.cat([torch.mean(x, dim=3).reshape(-1)
                          for x in weights])

    def get_prob(self, x_random, w):
        # Project the flattened weights and squash the result into (0, 1)
        mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
        return F.sigmoid(mm).flatten()
    def loss_fun(self, x, y):
        # Binary cross-entropy between the projected probabilities and the secret bits;
        # get_prob already applies a sigmoid, so BCELoss is used rather than BCEWithLogitsLoss
        return nn.BCELoss()(x, y)
"""
    # Append the class to the project file
    modify_file.append_block_in_file(project_file, append_source_block)
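

# --------------------------------------------------------------------------
# Illustrative sketch, not part of the generation tool above: one way the
# embedded secret could be read back from a trained model, assuming the
# projection matrix saved by ModelEncoder at './YOLOX_outputs/key.pt', the
# same two cls_convs convolution weights selected by init_model_embeder, and
# the threshold-at-0.5 decoding implied by the BCE embedding loss. The helper
# name and its arguments are hypothetical.
def _extract_secret_sketch(conv_weights, key_path='./YOLOX_outputs/key.pt', device='cpu'):
    import torch  # imported lazily so the generation tool itself does not require torch

    # Flatten the weights exactly as ModelEncoder.flatten_parameters does
    w = torch.cat([torch.mean(x, dim=3).reshape(-1) for x in conv_weights]).to(device)
    # Load the projection matrix saved during embedding and project the weights
    x_random = torch.load(key_path, map_location=device)
    prob = torch.sigmoid(torch.mm(x_random, w.reshape((w.shape[0], 1)))).flatten()
    # Threshold to bits and regroup into 8-bit characters, inverting string2bin
    bits = ''.join('1' if p > 0.5 else '0' for p in prob.tolist())
    return ''.join(chr(int(bits[i:i + 8], 2)) for i in range(0, len(bits), 8))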