utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import numpy as np
  2. from PIL import Image
  #---------------------------------------------------------#
  #   Convert the image to RGB so that grayscale (and other
  #   non-RGB) images do not cause errors during prediction.
  #   Only RGB prediction is supported; every other image type
  #   is converted to RGB.
  #---------------------------------------------------------#
def cvtColor(image):
    """Return *image* in RGB, calling ``image.convert('RGB')`` when needed.

    For PIL images the decision is made from ``image.mode``: mode ``'RGB'``
    is returned unchanged, and any other mode (e.g. ``'L'``, ``'P'``,
    ``'RGBA'``, ``'YCbCr'``, ``'HSV'``) is converted. Checking the mode
    instead of the channel count matters because 3-channel modes such as
    ``'YCbCr'`` and ``'HSV'`` are not RGB and were previously passed through.

    Objects without a ``mode`` attribute (e.g. numpy arrays) keep the legacy
    shape-based check: a 3-D input whose last axis has length 3 is returned
    unchanged; anything else has ``.convert('RGB')`` called on it.
    """
    mode = getattr(image, "mode", None)
    if mode is not None:
        # Reading the mode is O(1); np.shape() on a PIL image would first
        # copy the whole pixel buffer into an ndarray just to inspect it.
        return image if mode == 'RGB' else image.convert('RGB')
    # Legacy path for non-PIL inputs.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
  #---------------------------------------------------#
  #   Resize the input image
  #---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    """Resize a PIL image to *size*, optionally letterboxing it.

    Args:
        image: PIL image to resize.
        size: target ``(width, height)``.
        letterbox_image: if true, scale the image to fit inside *size* while
            keeping its aspect ratio, then paste it centred on a gray
            ``(128, 128, 128)`` RGB canvas of exactly *size*. If false,
            resize the image straight to *size* (aspect ratio not preserved).

    Returns:
        The resized image, produced with bicubic resampling.
    """
    iw, ih = image.size
    w, h = size
    if letterbox_image:
        scale = min(w / iw, h / ih)
        # Clamp to at least one pixel: for very elongated inputs the short
        # side would otherwise truncate to 0, and Pillow refuses to resize
        # to a zero width/height. Results are unchanged whenever the
        # truncated size is already >= 1.
        nw = max(1, int(iw * scale))
        nh = max(1, int(ih * scale))
        image = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        # Centre the scaled image; the uncovered border stays gray.
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
    else:
        new_image = image.resize((w, h), Image.BICUBIC)
    return new_image
  #---------------------------------------------------#
  #   Get the class names
  #---------------------------------------------------#
def get_classes(classes_path):
    """Read class names, one per line, from the UTF-8 text file *classes_path*.

    Every line is stripped of surrounding whitespace. Blank lines are kept
    as empty-string entries, so they count toward the returned length.

    Returns:
        ``(class_names, num_classes)``: the list of stripped lines and its length.
    """
    with open(classes_path, encoding='utf-8') as handle:
        names = [line.strip() for line in handle]
    return names, len(names)
  #---------------------------------------------------#
  #   Preprocess the network input (subtract channel means)
  #---------------------------------------------------#
def preprocess_input(inputs):
    """Subtract fixed per-channel mean values from *inputs*.

    Args:
        inputs: array-like (ndarray, nested list, ...) converted with
            ``np.asarray``; its last axis must broadcast against the three
            mean values. An ndarray is used as-is (no copy), so results for
            ndarray inputs are the same as a plain ``inputs - MEANS``.

    Returns:
        np.ndarray: ``inputs - (104, 117, 123)``, broadcast along the last axis.
    """
    # NOTE(review): these values resemble the commonly used Caffe/VGG
    # channel means, which are usually listed in BGR order; confirm they
    # match the channel order of `inputs`.
    MEANS = (104, 117, 123)
    # np.asarray lets plain sequences work too; `list - tuple` raises TypeError.
    return np.asarray(inputs) - MEANS
  #---------------------------------------------------#
  #   Get the learning rate
  #---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` value of the optimizer's first parameter group.

    Only the first item yielded by ``optimizer.param_groups`` is consulted.
    If ``param_groups`` yields nothing, ``None`` is returned.
    """
    group_lrs = (group['lr'] for group in optimizer.param_groups)
    return next(group_lrs, None)
  #---------------------------------------------------#
  #   Download pretrained backbone weights
  #---------------------------------------------------#
def download_weights(backbone, model_dir="./model_data"):
    """Fetch the pretrained weights for *backbone* into *model_dir*.

    Args:
        backbone: ``'vgg'`` or ``'mobilenetv2'``.
        model_dir: directory handed to ``torch.hub.load_state_dict_from_url``
            as its download directory; created if it does not exist.

    Raises:
        KeyError: if *backbone* is not a supported name. This is raised
            before torch is imported or the filesystem is touched.

    The loaded state dict is not returned; the call is made for its
    download side effect.
    """
    import os

    download_urls = {
        'vgg' : 'https://download.pytorch.org/models/vgg16-397923af.pth',
        'mobilenetv2' : 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
    }
    try:
        url = download_urls[backbone]
    except KeyError:
        raise KeyError(
            "unsupported backbone %r; expected one of %s"
            % (backbone, sorted(download_urls))
        ) from None

    from torch.hub import load_state_dict_from_url

    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() when several processes start at once.
    os.makedirs(model_dir, exist_ok=True)
    load_state_dict_from_url(url, model_dir)