# predict.py: classify a single image with a torchvision model and a saved checkpoint.
import argparse

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image

  5. def get_args_parser(add_help=True):
  6. import argparse
  7. parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
  8. parser.add_argument("--img_path", default=None, type=str, help="predict image file path")
  9. parser.add_argument("--model", default=None, type=str, help="model name")
  10. parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
  11. parser.add_argument("--weight", default=None, type=str, help="path of checkpoint")
  12. parser.add_argument("--num_classes", default=10, type=int, help="num of classes")
  13. return parser
  14. def predict(args):
  15. print("Creating model")
  16. device = torch.device(args.device)
  17. model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=args.num_classes)
  18. model.to(device)
  19. checkpoint = torch.load(args.weight, map_location="cpu", weights_only=False)
  20. model.load_state_dict(checkpoint["model"])
  21. # We disable the cudnn benchmarking because it can noticeably affect the accuracy
  22. torch.backends.cudnn.benchmark = False
  23. torch.backends.cudnn.deterministic = True
  24. model.eval()
  25. if args.img_path:
  26. # 加载并预处理图像
  27. image = Image.open(args.img_path).convert("RGB")
  28. preprocess = T.Compose([
  29. T.Resize((500, 500)),
  30. T.ToTensor(),
  31. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  32. ])
  33. input_tensor = preprocess(image).unsqueeze(0).to(device) # 增加 batch 维度
  34. # 进行预测
  35. with torch.no_grad():
  36. output = model(input_tensor)
  37. _, predicted_class = torch.max(output, 1)
  38. return predicted_class.item()
  39. else:
  40. return None
  41. if __name__ == "__main__":
  42. args = get_args_parser().parse_args()
  43. result = predict(args)
  44. print(result)