liyan 7 месяцев назад
Родитель
Сommit
25d514df54
1 измененных файлов с 57 добавлено и 0 удалено
  1. 57 0
      predict.py

+ 57 - 0
predict.py

@@ -0,0 +1,57 @@
+import torch
+import torchvision
+import torchvision.transforms as T
+from PIL import Image
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+    parser.add_argument("--img_path", default=None, type=str, help="predict image file path")
+    parser.add_argument("--model", default=None, type=str, help="model name")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
+    parser.add_argument("--weight", default=None, type=str, help="path of checkpoint")
+    parser.add_argument("--num_classes", default=10, type=int, help="num of classes")
+    return parser
+
+
+def predict(args):
+    print("Creating model")
+    device = torch.device(args.device)
+    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=args.num_classes)
+    model.to(device)
+    checkpoint = torch.load(args.weight, map_location="cpu", weights_only=False)
+    model.load_state_dict(checkpoint["model"])
+
+    # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+    model.eval()
+
+    if args.img_path:
+        # 加载并预处理图像
+        image = Image.open(args.img_path).convert("RGB")
+        preprocess = T.Compose([
+            T.Resize((500, 500)),
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        input_tensor = preprocess(image).unsqueeze(0).to(device)  # 增加 batch 维度
+
+        # 进行预测
+        with torch.no_grad():
+            output = model(input_tensor)
+            _, predicted_class = torch.max(output, 1)
+
+        return predicted_class.item()
+    else:
+        return None
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    result = predict(args)
+    print(result)