att.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # coding=utf-8
  2. # attack on the watermark
  3. import cv2
  4. import numpy as np
  5. import warnings
  6. def cut_att3(input_filename=None, input_img=None, output_file_name=None, loc_r=None, loc=None, scale=None):
  7. # 剪切攻击 + 缩放攻击
  8. if input_filename:
  9. input_img = cv2.imread(input_filename)
  10. if loc is None:
  11. h, w, _ = input_img.shape
  12. x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
  13. else:
  14. x1, y1, x2, y2 = loc
  15. # 剪切攻击
  16. output_img = input_img[y1:y2, x1:x2].copy()
  17. # 如果缩放攻击
  18. if scale and scale != 1:
  19. h, w, _ = output_img.shape
  20. output_img = cv2.resize(output_img, dsize=(round(w * scale), round(h * scale)))
  21. else:
  22. output_img = output_img
  23. if output_file_name:
  24. cv2.imwrite(output_file_name, output_img)
  25. return output_img
  26. cut_att2 = cut_att3
  27. def resize_att(input_filename=None, input_img=None, output_file_name=None, out_shape=(500, 500)):
  28. # 缩放攻击:因为攻击和还原都是缩放,所以攻击和还原都调用这个函数
  29. if input_filename:
  30. input_img = cv2.imread(input_filename)
  31. output_img = cv2.resize(input_img, dsize=out_shape)
  32. if output_file_name:
  33. cv2.imwrite(output_file_name, output_img)
  34. return output_img
  35. def bright_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
  36. # 亮度调整攻击,ratio应当多于0
  37. # ratio>1是调得更亮,ratio<1是亮度更暗
  38. if input_filename:
  39. input_img = cv2.imread(input_filename)
  40. output_img = input_img * ratio
  41. output_img[output_img > 255] = 255
  42. if output_file_name:
  43. cv2.imwrite(output_file_name, output_img)
  44. return output_img
  45. def shelter_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.1, n=3):
  46. # 遮挡攻击:遮挡图像中的一部分
  47. # n个遮挡块
  48. # 每个遮挡块所占比例为ratio
  49. if input_filename:
  50. output_img = cv2.imread(input_filename)
  51. else:
  52. output_img = input_img.copy()
  53. input_img_shape = output_img.shape
  54. for i in range(n):
  55. tmp = np.random.rand() * (1 - ratio) # 随机选择一个地方,1-ratio是为了防止溢出
  56. start_height, end_height = int(tmp * input_img_shape[0]), int((tmp + ratio) * input_img_shape[0])
  57. tmp = np.random.rand() * (1 - ratio)
  58. start_width, end_width = int(tmp * input_img_shape[1]), int((tmp + ratio) * input_img_shape[1])
  59. output_img[start_height:end_height, start_width:end_width, :] = 255
  60. if output_file_name:
  61. cv2.imwrite(output_file_name, output_img)
  62. return output_img
  63. def salt_pepper_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.01):
  64. # 椒盐攻击
  65. if input_filename:
  66. input_img = cv2.imread(input_filename)
  67. input_img_shape = input_img.shape
  68. output_img = input_img.copy()
  69. for i in range(input_img_shape[0]):
  70. for j in range(input_img_shape[1]):
  71. if np.random.rand() < ratio:
  72. output_img[i, j, :] = 255
  73. if output_file_name:
  74. cv2.imwrite(output_file_name, output_img)
  75. return output_img
  76. def rot_att(input_filename=None, input_img=None, output_file_name=None, angle=45):
  77. # 旋转攻击
  78. if input_filename:
  79. input_img = cv2.imread(input_filename)
  80. rows, cols, _ = input_img.shape
  81. M = cv2.getRotationMatrix2D(center=(cols / 2, rows / 2), angle=angle, scale=1)
  82. output_img = cv2.warpAffine(input_img, M, (cols, rows))
  83. if output_file_name:
  84. cv2.imwrite(output_file_name, output_img)
  85. return output_img
  86. def cut_att_height(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
  87. warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
  88. # 纵向剪切攻击
  89. if input_filename:
  90. input_img = cv2.imread(input_filename)
  91. input_img_shape = input_img.shape
  92. height = int(input_img_shape[0] * ratio)
  93. output_img = input_img[:height, :, :]
  94. if output_file_name:
  95. cv2.imwrite(output_file_name, output_img)
  96. return output_img
  97. def cut_att_width(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
  98. warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
  99. # 横向裁剪攻击
  100. if input_filename:
  101. input_img = cv2.imread(input_filename)
  102. input_img_shape = input_img.shape
  103. width = int(input_img_shape[1] * ratio)
  104. output_img = input_img[:, :width, :]
  105. if output_file_name:
  106. cv2.imwrite(output_file_name, output_img)
  107. return output_img
  108. def cut_att(input_filename=None, output_file_name=None, input_img=None, loc=((0.3, 0.1), (0.7, 0.9)), resize=0.6):
  109. warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
  110. # 截屏攻击 = 裁剪攻击 + 缩放攻击 + 知道攻击参数(按照参数还原)
  111. # 裁剪攻击:其它部分都补0
  112. if input_filename:
  113. input_img = cv2.imread(input_filename)
  114. output_img = input_img.copy()
  115. shape = output_img.shape
  116. x1, y1, x2, y2 = shape[0] * loc[0][0], shape[1] * loc[0][1], shape[0] * loc[1][0], shape[1] * loc[1][1]
  117. output_img[:int(x1), :] = 255
  118. output_img[int(x2):, :] = 255
  119. output_img[:, :int(y1)] = 255
  120. output_img[:, int(y2):] = 255
  121. if resize is not None:
  122. # 缩放一次,然后还原
  123. output_img = cv2.resize(output_img,
  124. dsize=(int(shape[1] * resize), int(shape[0] * resize))
  125. )
  126. output_img = cv2.resize(output_img, dsize=(int(shape[1]), int(shape[0])))
  127. if output_file_name is not None:
  128. cv2.imwrite(output_file_name, output_img)
  129. return output_img
  130. # def cut_att2(input_filename=None, input_img=None, output_file_name=None, loc_r=((0.3, 0.1), (0.9, 0.9)), scale=1.1):
  131. # # 截屏攻击 = 剪切攻击 + 缩放攻击 + 不知道攻击参数
  132. # if input_filename:
  133. # input_img = cv2.imread(input_filename)
  134. # h, w, _ = input_img.shape
  135. # x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
  136. #
  137. # output_img = cut_att3(input_img=input_img, output_file_name=output_file_name,
  138. # loc=(x1, y1, x2, y2), scale=scale)
  139. # return output_img, (x1, y1, x2, y2)
  140. def anti_cut_att_old(input_filename, output_file_name, origin_shape):
  141. warnings.warn('will be deprecated in the future')
  142. # 反裁剪攻击:复制一块范围,然后补全
  143. # origin_shape 分辨率与约定理解的是颠倒的,约定的是列数*行数
  144. input_img = cv2.imread(input_filename)
  145. output_img = input_img.copy()
  146. output_img_shape = output_img.shape
  147. if output_img_shape[0] > origin_shape[0] or output_img_shape[0] > origin_shape[0]:
  148. print('裁剪打击后的图片,不可能比原始图片大,检查一下')
  149. return
  150. # 还原纵向打击
  151. while output_img_shape[0] < origin_shape[0]:
  152. output_img = np.concatenate([output_img, output_img[:origin_shape[0] - output_img_shape[0], :, :]], axis=0)
  153. output_img_shape = output_img.shape
  154. while output_img_shape[1] < origin_shape[1]:
  155. output_img = np.concatenate([output_img, output_img[:, :origin_shape[1] - output_img_shape[1], :]], axis=1)
  156. output_img_shape = output_img.shape
  157. cv2.imwrite(output_file_name, output_img)
  158. def anti_cut_att(input_filename=None, input_img=None, output_file_name=None, origin_shape=None):
  159. warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
  160. # 反裁剪攻击:补0
  161. # origin_shape 分辨率与约定理解的是颠倒的,约定的是列数*行数
  162. if input_filename:
  163. input_img = cv2.imread(input_filename)
  164. output_img = input_img.copy()
  165. output_img_shape = output_img.shape
  166. if output_img_shape[0] > origin_shape[0] or output_img_shape[0] > origin_shape[0]:
  167. print('裁剪打击后的图片,不可能比原始图片大,检查一下')
  168. return
  169. # 还原纵向打击
  170. if output_img_shape[0] < origin_shape[0]:
  171. output_img = np.concatenate(
  172. [output_img, 255 * np.ones((origin_shape[0] - output_img_shape[0], output_img_shape[1], 3))]
  173. , axis=0)
  174. output_img_shape = output_img.shape
  175. if output_img_shape[1] < origin_shape[1]:
  176. output_img = np.concatenate(
  177. [output_img, 255 * np.ones((output_img_shape[0], origin_shape[1] - output_img_shape[1], 3))]
  178. , axis=1)
  179. if output_file_name:
  180. cv2.imwrite(output_file_name, output_img)
  181. return output_img