123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- import torch
- import torchvision
- import torchvision.transforms as T
- from PIL import Image
def get_args_parser(add_help=True):
    """Build the command-line parser for the single-image prediction script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; when true the
            parser gets the automatic ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--img_path``, ``--model``,
        ``--device``, ``--weight``, ``--weights`` and ``--num_classes``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    parser.add_argument("--img_path", default=None, type=str, help="predict image file path")
    parser.add_argument("--model", default=None, type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("--weight", default=None, type=str, help="path of checkpoint")
    # predict() reads ``weights`` from the parsed namespace when it builds the
    # model; without this option the attribute would not exist at all.
    # ``--weight`` stays an exact match, so existing command lines still work.
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--num_classes", default=10, type=int, help="num of classes")
    return parser
def predict(args):
    """Classify one image with a torchvision model restored from a checkpoint.

    Args:
        args: Namespace-like object providing ``device``, ``model``,
            ``weight`` (checkpoint path), ``num_classes`` and ``img_path``.
            An optional ``weights`` attribute is forwarded to
            ``torchvision.models.get_model``; it defaults to ``None``.

    Returns:
        The index of the highest-scoring class as a Python int, or ``None``
        when ``args.img_path`` is empty or ``None``.
    """
    print("Creating model")
    device = torch.device(args.device)
    # ``weights`` is optional on the namespace: fall back to None instead of
    # raising AttributeError when the caller's arguments do not carry it.
    pretrained_weights = getattr(args, "weights", None)
    model = torchvision.models.get_model(args.model, weights=pretrained_weights, num_classes=args.num_classes)
    model.to(device)

    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints that come from a trusted source.
    checkpoint = torch.load(args.weight, map_location="cpu", weights_only=False)
    # The checkpoint is expected to be a mapping whose "model" entry holds the state dict.
    model.load_state_dict(checkpoint["model"])

    # We disable the cudnn benchmarking because it can noticeably affect the accuracy
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    model.eval()

    if args.img_path:
        # Load and preprocess the image; the context manager closes the
        # underlying file as soon as the RGB copy has been created.
        with Image.open(args.img_path) as raw_image:
            image = raw_image.convert("RGB")
        preprocess = T.Compose([
            T.Resize((500, 500)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(image).unsqueeze(0).to(device)  # add the batch dimension
        # Run the prediction without tracking gradients.
        with torch.no_grad():
            output = model(input_tensor)
        _, predicted_class = torch.max(output, 1)
        return predicted_class.item()
    else:
        return None
if __name__ == "__main__":
    # Script entry point: parse the command line, run one prediction and
    # print the resulting class index (or None when no image was given).
    print(predict(get_args_parser().parse_args()))
|