bwm_core.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. #!/usr/bin/env python3
  2. # coding=utf-8
  3. # @Time : 2021/12/17
  4. # @Author : github.com/guofei9987
  5. import numpy as np
  6. from numpy.linalg import svd
  7. import copy
  8. import cv2
  9. from cv2 import dct, idct
  10. from pywt import dwt2, idwt2
  11. from .pool import AutoPool
  12. class WaterMarkCore:
  13. def __init__(self, password_img=1, mode='common', processes=None):
  14. self.block_shape = np.array([4, 4])
  15. self.password_img = password_img
  16. self.d1, self.d2 = 36, 20 # d1/d2 越大鲁棒性越强,但输出图片的失真越大
  17. # init data
  18. self.img, self.img_YUV = None, None # self.img 是原图,self.img_YUV 对像素做了加白偶数化
  19. self.ca, self.hvd, = [np.array([])] * 3, [np.array([])] * 3 # 每个通道 dct 的结果
  20. self.ca_block = [np.array([])] * 3 # 每个 channel 存一个四维 array,代表四维分块后的结果
  21. self.ca_part = [np.array([])] * 3 # 四维分块后,有时因不整除而少一部分,self.ca_part 是少这一部分的 self.ca
  22. self.wm_size, self.block_num = 0, 0 # 水印的长度,原图片可插入信息的个数
  23. self.pool = AutoPool(mode=mode, processes=processes)
  24. self.fast_mode = False
  25. self.alpha = None # 用于处理透明图
  26. def init_block_index(self):
  27. self.block_num = self.ca_block_shape[0] * self.ca_block_shape[1]
  28. assert self.wm_size < self.block_num, IndexError(
  29. '最多可嵌入{}kb信息,多于水印的{}kb信息,溢出'.format(self.block_num / 1000, self.wm_size / 1000))
  30. # self.part_shape 是取整后的ca二维大小,用于嵌入时忽略右边和下面对不齐的细条部分。
  31. self.part_shape = self.ca_block_shape[:2] * self.block_shape
  32. self.block_index = [(i, j) for i in range(self.ca_block_shape[0]) for j in range(self.ca_block_shape[1])]
  33. def read_img_arr(self, img):
  34. # 处理透明图
  35. self.alpha = None
  36. if img.shape[2] == 4:
  37. if img[:, :, 3].min() < 255:
  38. self.alpha = img[:, :, 3]
  39. img = img[:, :, :3]
  40. # 读入图片->YUV化->加白边使像素变偶数->四维分块
  41. self.img = img.astype(np.float32)
  42. self.img_shape = self.img.shape[:2]
  43. # 如果不是偶数,那么补上白边,Y(明亮度)UV(颜色)
  44. self.img_YUV = cv2.copyMakeBorder(cv2.cvtColor(self.img, cv2.COLOR_BGR2YUV),
  45. 0, self.img.shape[0] % 2, 0, self.img.shape[1] % 2,
  46. cv2.BORDER_CONSTANT, value=(0, 0, 0))
  47. self.ca_shape = [(i + 1) // 2 for i in self.img_shape]
  48. self.ca_block_shape = (self.ca_shape[0] // self.block_shape[0], self.ca_shape[1] // self.block_shape[1],
  49. self.block_shape[0], self.block_shape[1])
  50. strides = 4 * np.array([self.ca_shape[1] * self.block_shape[0], self.block_shape[1], self.ca_shape[1], 1])
  51. for channel in range(3):
  52. self.ca[channel], self.hvd[channel] = dwt2(self.img_YUV[:, :, channel], 'haar')
  53. # 转为4维度
  54. self.ca_block[channel] = np.lib.stride_tricks.as_strided(self.ca[channel].astype(np.float32),
  55. self.ca_block_shape, strides)
  56. def read_wm(self, wm_bit):
  57. self.wm_bit = wm_bit
  58. self.wm_size = wm_bit.size
  59. def block_add_wm(self, arg):
  60. if self.fast_mode:
  61. return self.block_add_wm_fast(arg)
  62. else:
  63. return self.block_add_wm_slow(arg)
  64. def block_add_wm_slow(self, arg):
  65. block, shuffler, i = arg
  66. # dct->(flatten->加密->逆flatten)->svd->打水印->逆svd->(flatten->解密->逆flatten)->逆dct
  67. wm_1 = self.wm_bit[i % self.wm_size]
  68. block_dct = dct(block)
  69. # 加密(打乱顺序)
  70. block_dct_shuffled = block_dct.flatten()[shuffler].reshape(self.block_shape)
  71. u, s, v = svd(block_dct_shuffled)
  72. s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
  73. if self.d2:
  74. s[1] = (s[1] // self.d2 + 1 / 4 + 1 / 2 * wm_1) * self.d2
  75. block_dct_flatten = np.dot(u, np.dot(np.diag(s), v)).flatten()
  76. block_dct_flatten[shuffler] = block_dct_flatten.copy()
  77. return idct(block_dct_flatten.reshape(self.block_shape))
  78. def block_add_wm_fast(self, arg):
  79. # dct->svd->打水印->逆svd->逆dct
  80. block, shuffler, i = arg
  81. wm_1 = self.wm_bit[i % self.wm_size]
  82. u, s, v = svd(dct(block))
  83. s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
  84. return idct(np.dot(u, np.dot(np.diag(s), v)))
  85. def embed(self):
  86. self.init_block_index()
  87. embed_ca = copy.deepcopy(self.ca)
  88. embed_YUV = [np.array([])] * 3
  89. self.idx_shuffle = random_strategy1(self.password_img, self.block_num,
  90. self.block_shape[0] * self.block_shape[1])
  91. for channel in range(3):
  92. tmp = self.pool.map(self.block_add_wm,
  93. [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i], i)
  94. for i in range(self.block_num)])
  95. for i in range(self.block_num):
  96. self.ca_block[channel][self.block_index[i]] = tmp[i]
  97. # 4维分块变回2维
  98. self.ca_part[channel] = np.concatenate(np.concatenate(self.ca_block[channel], 1), 1)
  99. # 4维分块时右边和下边不能整除的长条保留,其余是主体部分,换成 embed 之后的频域的数据
  100. embed_ca[channel][:self.part_shape[0], :self.part_shape[1]] = self.ca_part[channel]
  101. # 逆变换回去
  102. embed_YUV[channel] = idwt2((embed_ca[channel], self.hvd[channel]), "haar")
  103. # 合并3通道
  104. embed_img_YUV = np.stack(embed_YUV, axis=2)
  105. # 之前如果不是2的整数,增加了白边,这里去除掉
  106. embed_img_YUV = embed_img_YUV[:self.img_shape[0], :self.img_shape[1]]
  107. embed_img = cv2.cvtColor(embed_img_YUV, cv2.COLOR_YUV2BGR)
  108. embed_img = np.clip(embed_img, a_min=0, a_max=255)
  109. if self.alpha is not None:
  110. embed_img = cv2.merge([embed_img.astype(np.uint8), self.alpha])
  111. return embed_img
  112. def block_get_wm(self, args):
  113. if self.fast_mode:
  114. return self.block_get_wm_fast(args)
  115. else:
  116. return self.block_get_wm_slow(args)
  117. def block_get_wm_slow(self, args):
  118. block, shuffler = args
  119. # dct->flatten->加密->逆flatten->svd->解水印
  120. block_dct_shuffled = dct(block).flatten()[shuffler].reshape(self.block_shape)
  121. u, s, v = svd(block_dct_shuffled)
  122. wm = (s[0] % self.d1 > self.d1 / 2) * 1
  123. if self.d2:
  124. tmp = (s[1] % self.d2 > self.d2 / 2) * 1
  125. wm = (wm * 3 + tmp * 1) / 4
  126. return wm
  127. def block_get_wm_fast(self, args):
  128. block, shuffler = args
  129. # dct->svd->解水印
  130. u, s, v = svd(dct(block))
  131. wm = (s[0] % self.d1 > self.d1 / 2) * 1
  132. return wm
  133. def extract_raw(self, img):
  134. # 每个分块提取 1 bit 信息
  135. self.read_img_arr(img=img)
  136. self.init_block_index()
  137. wm_block_bit = np.zeros(shape=(3, self.block_num)) # 3个channel,length 个分块提取的水印,全都记录下来
  138. self.idx_shuffle = random_strategy1(seed=self.password_img,
  139. size=self.block_num,
  140. block_shape=self.block_shape[0] * self.block_shape[1], # 16
  141. )
  142. for channel in range(3):
  143. wm_block_bit[channel, :] = self.pool.map(self.block_get_wm,
  144. [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i])
  145. for i in range(self.block_num)])
  146. return wm_block_bit
  147. def extract_avg(self, wm_block_bit):
  148. # 对循环嵌入+3个 channel 求平均
  149. wm_avg = np.zeros(shape=self.wm_size)
  150. for i in range(self.wm_size):
  151. wm_avg[i] = wm_block_bit[:, i::self.wm_size].mean()
  152. return wm_avg
  153. def extract(self, img, wm_shape):
  154. self.wm_size = np.array(wm_shape).prod()
  155. # 提取每个分块埋入的 bit:
  156. wm_block_bit = self.extract_raw(img=img)
  157. # 做平均:
  158. wm_avg = self.extract_avg(wm_block_bit)
  159. return wm_avg
  160. def extract_with_kmeans(self, img, wm_shape):
  161. wm_avg = self.extract(img=img, wm_shape=wm_shape)
  162. return one_dim_kmeans(wm_avg)
  163. def one_dim_kmeans(inputs):
  164. threshold = 0
  165. e_tol = 10 ** (-6)
  166. center = [inputs.min(), inputs.max()] # 1. 初始化中心点
  167. for i in range(300):
  168. threshold = (center[0] + center[1]) / 2
  169. is_class01 = inputs > threshold # 2. 检查所有点与这k个点之间的距离,每个点归类到最近的中心
  170. center = [inputs[~is_class01].mean(), inputs[is_class01].mean()] # 3. 重新找中心点
  171. if np.abs((center[0] + center[1]) / 2 - threshold) < e_tol: # 4. 停止条件
  172. threshold = (center[0] + center[1]) / 2
  173. break
  174. is_class01 = inputs > threshold
  175. return is_class01
  176. def random_strategy1(seed, size, block_shape):
  177. return np.random.RandomState(seed) \
  178. .random(size=(size, block_shape)) \
  179. .argsort(axis=1)
  180. def random_strategy2(seed, size, block_shape):
  181. one_line = np.random.RandomState(seed) \
  182. .random(size=(1, block_shape)) \
  183. .argsort(axis=1)
  184. return np.repeat(one_line, repeats=size, axis=0)