training_embedding.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam


def string2bin(s):
    """Encode a string as a list of bits, 8 bits per character."""
    binary_representation = ''.join(format(ord(x), '08b') for x in s)
    return [int(x) for x in binary_representation]


def bin2string(binary_string):
    """Decode a bit string back to text, 8 bits per character."""
    return ''.join(chr(int(binary_string[i:i + 8], 2))
                   for i in range(0, len(binary_string), 8))
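
# Round-trip sanity check for the two helpers above (illustrative values):
#   string2bin('AB')  -> [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0]
#   bin2string('0100000101000010') -> 'AB'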

class Embedding:
    def __init__(self, model, code, key_path: str = None, l=1, train=True, device='cuda'):
        """
        Initialize the white-box watermark embedder.

        :param model: the model definition
        :param code: the watermark code, as a string
        :param key_path: path where the projection-matrix weights are saved
        :param l: loss weight of the watermark embedding term
        :param train: whether this is a training run, defaults to True
        :param device: device to run on, defaults to 'cuda'
        """
        self.p = self.get_parameters(model)
        self.key_path = key_path if key_path is not None else './key.pt'
        self.device = device
        # Flattened mean of the target parameters: the carrier vector for the code.
        w = self.flatten_parameters(self.p)
        self.l = l
        self.w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)
        self.opt = Adam(self.p, lr=0.001)
        self.distribution_ignore = ['train_acc']
        # The code to embed, as a float tensor of bits.
        self.code = torch.tensor(string2bin(code), dtype=torch.float).to(self.device)
        self.code_len = self.code.shape[0]
        print(f'Code: {self.code} code length: {self.code_len}')
        # In a test run, load the saved projection matrix; in a training run,
        # generate a random projection matrix X and save it to key_path.
        if not train:
            self.load_matrix(self.key_path)
        else:
            self.X_random = torch.randn(
                (self.code_len, self.w_init.shape[0])).to(self.device)
            self.save_matrix()
    def save_matrix(self):
        torch.save(self.X_random, self.key_path)

    def load_matrix(self, path):
        self.X_random = torch.load(path, map_location=self.device)
    def get_parameters(self, model):
        # The model is expected to expose the layers that carry the watermark.
        target = model.get_encode_layers()
        print(f'Embedding target: {target}')
        parameters = [x.weight for x in target]
        return parameters
    def add_penalty(self, loss):
        """Add the watermark embedding penalty to the task loss."""
        w = self.flatten_parameters(self.p)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.code)
        loss += self.l * penalty
        return loss
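    # In other words: total loss = task loss + l * BCE(sigmoid(X @ w), code bits),
    # the regularizer-style formulation used for white-box parameter watermarking.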
    def flatten_parameters(self, parameters):
        # Average each weight tensor over dim 3 (the kernel-width axis of a
        # conv weight), flatten, and concatenate all layers into one vector.
        parameter = torch.cat([torch.mean(x, dim=3).reshape(-1)
                               for x in parameters])
        return parameter
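    # Shape example (illustrative): a Conv2d weight of shape (64, 3, 3, 3)
    # averaged over dim 3 becomes (64, 3, 3) and flattens to 576 values;
    # the per-layer vectors are then concatenated into the carrier w.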
    def loss_fun(self, x, y):
        penalty = F.binary_cross_entropy(x, y)
        return penalty
    def decode(self, X, w):
        prob = self.get_prob(X, w)
        return torch.where(prob > 0.5, 1, 0)

    def get_prob(self, X, w):
        # Project the flattened parameters and squash to per-bit probabilities.
        mm = torch.mm(X, w.reshape((w.shape[0], 1)))
        return torch.sigmoid(mm).flatten()
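    # Extraction note: bit i decodes to 1 exactly when (X @ w)[i] > 0, since
    # sigmoid(z) > 0.5 iff z > 0, so the watermark survives any update that
    # preserves the signs of the projections X @ w.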
    def test(self):
        """Extract the embedded code from the current parameters and print it."""
        w = self.flatten_parameters(self.p)
        decode = self.decode(self.X_random, w)
        print(decode.shape)
        code_string = ''.join([str(x) for x in decode.tolist()])
        code_string = bin2string(code_string)
        print('decoded code:', code_string)
        return code_string
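
# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal, self-contained example of embedding and then extracting a code.
# ToyNet and every name below are made-up stand-ins: the only contract the
# Embedding class relies on is a model exposing get_encode_layers(), which
# returns the conv layers whose weights carry the watermark. Runs on CPU on
# random data and writes the projection matrix to ./key.pt.
if __name__ == '__main__':
    class ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, 3, padding=1)
            self.head = nn.Linear(16, 10)

        def get_encode_layers(self):
            # Layers whose weights carry the watermark.
            return [self.conv]

        def forward(self, x):
            x = F.relu(self.conv(x))
            return self.head(x.mean(dim=(2, 3)))

    net = ToyNet()
    embedder = Embedding(net, code='AB', key_path='./key.pt',
                         train=True, device='cpu')
    criterion = nn.CrossEntropyLoss()
    for _ in range(200):
        x = torch.randn(8, 3, 8, 8)          # random stand-in data
        y = torch.randint(0, 10, (8,))
        loss = embedder.add_penalty(criterion(net(x), y))
        net.zero_grad()
        loss.backward()
        embedder.opt.step()                  # updates only the watermarked weights
    embedder.test()                          # should decode back to 'AB'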