recover.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import cv2
  2. import numpy as np
  3. import functools
  # Helper that makes caching possible: holds the current image/template pair as a de facto global.
class MyValues:
    """Mutable holder for the image/template pair currently being matched.

    The cached ``match_template`` reads the arrays from here instead of taking
    them as arguments (NumPy arrays are unhashable, so they cannot be cache
    keys). ``idx`` is a generation counter bumped on every ``set_val`` call;
    callers pass it into the cache key so results for an earlier pair are not
    reused for a new one.
    """

    def __init__(self):
        self.idx = 0
        self.image = None
        self.template = None

    def set_val(self, image, template):
        """Store a new image/template pair and advance the generation counter."""
        self.idx = self.idx + 1
        self.image = image
        self.template = template


# Shared module-level instance read by the matching helpers.
my_value = MyValues()
  # Template matching at one fixed template size, memoized per (w, h, idx).
# ``idx`` is my_value's generation counter: every set_val() call bumps it, so
# results cached for an earlier image/template pair are never looked up again.
# The cache is bounded so those stale entries get evicted instead of
# accumulating for the lifetime of the process (maxsize=None grew without limit).
@functools.lru_cache(maxsize=4096, typed=False)
def match_template(w, h, idx):
    """Return the best TM_CCOEFF_NORMED match of the template resized to (w, h).

    The image and template are read from ``my_value`` rather than passed in, so
    the arguments stay hashable for the cache. ``idx`` is not used in the body;
    it only makes the cache key change when ``my_value`` is updated.

    Args:
        w: target template width in pixels.
        h: target template height in pixels.
        idx: ``my_value.idx`` at call time (cache-invalidation key).

    Returns:
        ``(ind, score)``: ``ind`` is the ``(row, col)`` index of the best match
        (its top-left corner in the image), ``score`` the correlation value there.
    """
    image, template = my_value.image, my_value.template
    resized = cv2.resize(template, dsize=(w, h))
    scores = cv2.matchTemplate(image, resized, cv2.TM_CCOEFF_NORMED)
    ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
    return ind, scores[ind]
  # Template matching for a relative scale factor of the stored template.
def match_template_by_scale(scale):
    """Match the stored template, resized by ``scale``, against the stored image.

    Args:
        scale: factor applied to the template's width and height.

    Returns:
        ``(ind, score, scale)``: the best-match ``(row, col)`` index and score
        from ``match_template``, plus the ``scale`` that was passed in.
    """
    template = my_value.template
    # Clamp to at least 1 px: a tiny template at a small scale can round to
    # 0 px, which cv2.resize rejects. int() keeps dsize a pair of Python ints
    # even if round() hands back a NumPy scalar for NumPy inputs.
    w = max(1, int(round(template.shape[1] * scale)))
    h = max(1, int(round(template.shape[0] * scale)))
    ind, score = match_template(w, h, idx=my_value.idx)
    return ind, score, scale
  # Coarse-to-fine (local brute-force) search for the best template scale.
def search_template(scale=(0.5, 2), search_num=200):
    """Find the template scale that best matches the stored image.

    Works on the image/template stored in ``my_value`` (see ``set_val``).
    The first pass evaluates ``search_num`` evenly spaced scales in ``scale``.
    The second pass re-searches, with a finer step, the interval between the
    neighbours of the best first-pass scale.

    Args:
        scale: ``(min_scale, max_scale)`` search range. ``max_scale`` is capped
            so the resized template is never larger than the image.
        search_num: number of scales evaluated in the first pass.

    Returns:
        ``[ind, score, scale]`` of the best candidate over both passes, where
        ``ind`` is the ``(row, col)`` of the match's top-left corner.
    """
    image, template = my_value.image, my_value.template
    tpl_h, tpl_w = template.shape[0], template.shape[1]
    min_scale, max_scale = scale
    # The resized template must fit inside the image.
    max_scale = min(max_scale, image.shape[0] / tpl_h, image.shape[1] / tpl_w)
    candidates = []
    best_idx = 0
    for _ in range(2):
        for s in np.linspace(min_scale, max_scale, search_num):
            ind, score, s = match_template_by_scale(s)
            candidates.append([ind, score, s])
        # Pick the best candidate so far. Start from -inf because
        # TM_CCOEFF_NORMED scores lie in [-1, 1]: starting from 0 would
        # silently return the first candidate whenever every score is <= 0.
        best_idx, best_score = 0, -np.inf
        for idx, (_, score, _) in enumerate(candidates):
            if score > best_score:
                best_idx, best_score = idx, score
        # Narrow the range to the neighbours of the best scale, with roughly
        # one step per half pixel of the template's longer side.
        min_scale = candidates[max(0, best_idx - 1)][2]
        max_scale = candidates[min(len(candidates) - 1, best_idx + 1)][2]
        search_num = 2 * int((max_scale - min_scale) * max(tpl_w, tpl_h)) + 1
    return candidates[best_idx]
  # Locate the attacked (cropped / rescaled) image inside the original image.
def estimate_crop_parameters(original_file=None, template_file=None, ori_img=None, tem_img=None,
                             scale=(0.5, 2), search_num=200):
    """Estimate the position and size of the attacked image within the original.

    Images are either read from disk as grayscale or taken from the array
    arguments. A path argument, when given, overrides the matching array.

    Args:
        original_file: path of the original image.
        template_file: path of the attacked image, used as the template.
        ori_img: original image array, used when ``original_file`` is not given.
        tem_img: template array, used when ``template_file`` is not given.
        scale: ``(min_scale, max_scale)`` range searched for the template's
            scale. ``(1, 1)`` disables scaling and does a single match.
        search_num: number of scales tried in the first search pass.

    Returns:
        ``((x1, y1, x2, y2), ori_img.shape, score, scale_infer)``. The box is in
        pixel coordinates of the original image, with ``x2``/``y2`` exclusive
        (slice-style). ``score`` is the TM_CCOEFF_NORMED value of the best match,
        and ``scale_infer`` is the template scale that produced it.
    """
    if template_file:
        tem_img = cv2.imread(template_file, cv2.IMREAD_GRAYSCALE)  # template image
    if original_file:
        ori_img = cv2.imread(original_file, cv2.IMREAD_GRAYSCALE)  # image
    if scale[0] == scale[1] == 1:
        # No scaling requested: one match at the template's native size.
        scale_infer = 1
        scores = cv2.matchTemplate(ori_img, tem_img, cv2.TM_CCOEFF_NORMED)
        ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
        score = scores[ind]
    else:
        my_value.set_val(image=ori_img, template=tem_img)
        ind, score, scale_infer = search_template(scale=scale, search_num=search_num)
    # Size the box the same way match_template_by_scale sizes the resized
    # template (round, not truncate). int() truncation could make the box up to
    # one pixel smaller than the region that was actually matched.
    w = max(1, int(round(tem_img.shape[1] * scale_infer)))
    h = max(1, int(round(tem_img.shape[0] * scale_infer)))
    x1, y1 = ind[1], ind[0]
    x2, y2 = x1 + w, y1 + h
    return (x1, y1, x2, y2), ori_img.shape, score, scale_infer
  # Rebuild an original-sized image with the template pasted back at ``loc``.
def recover_crop(template_file=None, tem_img=None, output_file_name=None, loc=None, image_o_shape=None):
    """Paste the (resized) template into a black canvas of the original size.

    Args:
        template_file: path to read the template from (read in color by cv2.imread).
        tem_img: template array, used when ``template_file`` is not given. It may
            be single-channel ``(H, W)`` or three-channel ``(H, W, 3)``.
        output_file_name: if given, the recovered image is also written there.
        loc: ``(x1, y1, x2, y2)`` box in the original image, in the form returned
            as the first element by ``estimate_crop_parameters``.
        image_o_shape: shape of the original image; only the first two entries
            are used.

    Returns:
        float64 array of shape ``(image_o_shape[0], image_o_shape[1], 3)``. It is
        zero everywhere except inside the box, which holds the template resized
        to the box size.
    """
    if template_file:
        tem_img = cv2.imread(template_file)  # template image
    x1, y1, x2, y2 = loc
    img_recovered = np.zeros((image_o_shape[0], image_o_shape[1], 3))
    patch = cv2.resize(tem_img, dsize=(x2 - x1, y2 - y1))
    if patch.ndim == 2:
        # Single-channel template (estimate_crop_parameters works on
        # IMREAD_GRAYSCALE images): a 2-D (h, w) patch cannot broadcast into the
        # (h, w, 3) slice, so replicate it across the three channels.
        patch = patch[:, :, np.newaxis]
    img_recovered[y1:y2, x1:x2, :] = patch
    if output_file_name:
        cv2.imwrite(output_file_name, img_recovered)
    return img_recovered