|
@@ -0,0 +1,57 @@
|
|
|
|
+import torch
|
|
|
|
+import torchvision
|
|
|
|
+import torchvision.transforms as T
|
|
|
|
+from PIL import Image
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_args_parser(add_help=True):
|
|
|
|
+ import argparse
|
|
|
|
+
|
|
|
|
+ parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
|
|
|
|
+
|
|
|
|
+ parser.add_argument("--img_path", default=None, type=str, help="predict image file path")
|
|
|
|
+ parser.add_argument("--model", default=None, type=str, help="model name")
|
|
|
|
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
|
|
|
|
+ parser.add_argument("--weight", default=None, type=str, help="path of checkpoint")
|
|
|
|
+ parser.add_argument("--num_classes", default=10, type=int, help="num of classes")
|
|
|
|
+ return parser
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def predict(args):
|
|
|
|
+ print("Creating model")
|
|
|
|
+ device = torch.device(args.device)
|
|
|
|
+ model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=args.num_classes)
|
|
|
|
+ model.to(device)
|
|
|
|
+ checkpoint = torch.load(args.weight, map_location="cpu", weights_only=False)
|
|
|
|
+ model.load_state_dict(checkpoint["model"])
|
|
|
|
+
|
|
|
|
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
|
|
|
|
+ torch.backends.cudnn.benchmark = False
|
|
|
|
+ torch.backends.cudnn.deterministic = True
|
|
|
|
+
|
|
|
|
+ model.eval()
|
|
|
|
+
|
|
|
|
+ if args.img_path:
|
|
|
|
+ # 加载并预处理图像
|
|
|
|
+ image = Image.open(args.img_path).convert("RGB")
|
|
|
|
+ preprocess = T.Compose([
|
|
|
|
+ T.Resize((500, 500)),
|
|
|
|
+ T.ToTensor(),
|
|
|
|
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
|
|
+ ])
|
|
|
|
+ input_tensor = preprocess(image).unsqueeze(0).to(device) # 增加 batch 维度
|
|
|
|
+
|
|
|
|
+ # 进行预测
|
|
|
|
+ with torch.no_grad():
|
|
|
|
+ output = model(input_tensor)
|
|
|
|
+ _, predicted_class = torch.max(output, 1)
|
|
|
|
+
|
|
|
|
+ return predicted_class.item()
|
|
|
|
+ else:
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
+ args = get_args_parser().parse_args()
|
|
|
|
+ result = predict(args)
|
|
|
|
+ print(result)
|