import torch import torchvision import torchvision.transforms as T from PIL import Image def get_args_parser(add_help=True): import argparse parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) parser.add_argument("--img_path", default=None, type=str, help="predict image file path") parser.add_argument("--model", default=None, type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument("--weight", default=None, type=str, help="path of checkpoint") parser.add_argument("--num_classes", default=10, type=int, help="num of classes") return parser def predict(args): print("Creating model") device = torch.device(args.device) model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=args.num_classes) model.to(device) checkpoint = torch.load(args.weight, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) # We disable the cudnn benchmarking because it can noticeably affect the accuracy torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True model.eval() if args.img_path: # 加载并预处理图像 image = Image.open(args.img_path).convert("RGB") preprocess = T.Compose([ T.Resize((500, 500)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) input_tensor = preprocess(image).unsqueeze(0).to(device) # 增加 batch 维度 # 进行预测 with torch.no_grad(): output = model(input_tensor) _, predicted_class = torch.max(output, 1) return predicted_class.item() else: return None if __name__ == "__main__": args = get_args_parser().parse_args() result = predict(args) print(result)