Browse Source

修改类别顺序不一致,准确率低的问题

zhy 2 weeks ago
parent
commit
ff892ca099
1 changed files with 51 additions and 18 deletions
  1. 51 18
      tests/classification_performance_loss_test.py

+ 51 - 18
tests/classification_performance_loss_test.py

@@ -137,6 +137,7 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
             predicted_class = np.argmax(outputs[0][j])  # 假设输出是每类的概率
             predicted_class = np.argmax(outputs[0][j])  # 假设输出是每类的概率
             results[image_file] = predicted_class
             results[image_file] = predicted_class
             total_predictions += 1
             total_predictions += 1
+            print(f"Image: {image_file}, Predicted: {predicted_class}, Target: {target_class}")
 
 
             # 比较预测结果与目标分类
             # 比较预测结果与目标分类
             if predicted_class == target_class:
             if predicted_class == target_class:
@@ -147,7 +148,37 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
     return accuracy
     return accuracy
 
 
 
 
-# 模型推理函数
+# # 模型推理函数
+# def model_inference(model_filename, val_dataset_dir):
+#     """
+#     模型推理验证集目录下所有图片
+#     :param model_filename: 模型文件
+#     :param val_dataset_dir: 验证集图片目录
+#     :return: 验证集推理准确率
+#     """
+#     # 以下使用GPU进行推理出现问题,需要较新的CUDA版本,默认使用CPU进行推理
+#     # if ort.get_available_providers():
+#     #     session = ort.InferenceSession(model_filename, providers=['CUDAExecutionProvider'])
+#     # else:
+#     #     session = ort.InferenceSession(model_filename)
+#     session = ort.InferenceSession(model_filename)
+#     accuracy = 0
+#     class_num = 0
+#     index = 0
+#     for class_dir in os.listdir(val_dataset_dir):
+#         class_path = os.path.join(val_dataset_dir, class_dir)
+#         # 检查是否为目录
+#         if not os.path.isdir(class_path):
+#             continue
+#         class_num += 1
+#         is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
+#         batch_result = batch_predict_images(session, class_path, index, pytorch=is_pytorch)
+#         accuracy += batch_result
+#         index += 1
+#     print(f"class_num: {class_num}, index: {index}")
+#     return accuracy * 1.0 / class_num
+
+
 def model_inference(model_filename, val_dataset_dir):
 def model_inference(model_filename, val_dataset_dir):
     """
     """
     模型推理验证集目录下所有图片
     模型推理验证集目录下所有图片
@@ -155,27 +186,29 @@ def model_inference(model_filename, val_dataset_dir):
     :param val_dataset_dir: 验证集图片目录
     :param val_dataset_dir: 验证集图片目录
     :return: 验证集推理准确率
     :return: 验证集推理准确率
     """
     """
-    # 以下使用GPU进行推理出现问题,需要较新的CUDA版本,默认使用CPU进行推理
-    # if ort.get_available_providers():
-    #     session = ort.InferenceSession(model_filename, providers=['CUDAExecutionProvider'])
-    # else:
-    #     session = ort.InferenceSession(model_filename)
+    # 默认使用 CPU 推理
     session = ort.InferenceSession(model_filename)
     session = ort.InferenceSession(model_filename)
+
+    # 1. 固定类别顺序,确保 index 和模型输出匹配
+    class_names = sorted([d for d in os.listdir(val_dataset_dir) if os.path.isdir(os.path.join(val_dataset_dir, d))])
+    class_to_index = {name: idx for idx, name in enumerate(class_names)}
+
     accuracy = 0
     accuracy = 0
-    class_num = 0
-    index = 0
-    for class_dir in os.listdir(val_dataset_dir):
+    class_num = len(class_names)
+
+    is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
+
+    for class_dir in class_names:
         class_path = os.path.join(val_dataset_dir, class_dir)
         class_path = os.path.join(val_dataset_dir, class_dir)
-        # 检查是否为目录
-        if not os.path.isdir(class_path):
-            continue
-        class_num += 1
-        is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
-        batch_result = batch_predict_images(session, class_path, index, pytorch=is_pytorch)
+        target_class = class_to_index[class_dir]
+
+        # 2. 对该类进行批量预测
+        batch_result = batch_predict_images(session, class_path, target_class, pytorch=is_pytorch)
         accuracy += batch_result
         accuracy += batch_result
-        index += 1
-    print(f"class_num: {class_num}, index: {index}")
-    return accuracy * 1.0 / class_num
+
+    print(f"class_num: {class_num}")
+    return accuracy / class_num if class_num > 0 else 0
+
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':