|
@@ -137,6 +137,7 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
|
|
predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
|
|
predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
|
|
results[image_file] = predicted_class
|
|
results[image_file] = predicted_class
|
|
total_predictions += 1
|
|
total_predictions += 1
|
|
|
|
+ print(f"Image: {image_file}, Predicted: {predicted_class}, Target: {target_class}")
|
|
|
|
|
|
# 比较预测结果与目标分类
|
|
# 比较预测结果与目标分类
|
|
if predicted_class == target_class:
|
|
if predicted_class == target_class:
|
|
@@ -147,7 +148,37 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
|
|
return accuracy
|
|
return accuracy
|
|
|
|
|
|
|
|
|
|
-# 模型推理函数
|
|
|
|
|
|
+# # 模型推理函数
|
|
|
|
+# def model_inference(model_filename, val_dataset_dir):
|
|
|
|
+# """
|
|
|
|
+# 模型推理验证集目录下所有图片
|
|
|
|
+# :param model_filename: 模型文件
|
|
|
|
+# :param val_dataset_dir: 验证集图片目录
|
|
|
|
+# :return: 验证集推理准确率
|
|
|
|
+# """
|
|
|
|
+# # 以下使用GPU进行推理出现问题,需要较新的CUDA版本,默认使用CPU进行推理
|
|
|
|
+# # if ort.get_available_providers():
|
|
|
|
+# # session = ort.InferenceSession(model_filename, providers=['CUDAExecutionProvider'])
|
|
|
|
+# # else:
|
|
|
|
+# # session = ort.InferenceSession(model_filename)
|
|
|
|
+# session = ort.InferenceSession(model_filename)
|
|
|
|
+# accuracy = 0
|
|
|
|
+# class_num = 0
|
|
|
|
+# index = 0
|
|
|
|
+# for class_dir in os.listdir(val_dataset_dir):
|
|
|
|
+# class_path = os.path.join(val_dataset_dir, class_dir)
|
|
|
|
+# # 检查是否为目录
|
|
|
|
+# if not os.path.isdir(class_path):
|
|
|
|
+# continue
|
|
|
|
+# class_num += 1
|
|
|
|
+# is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
|
|
|
|
+# batch_result = batch_predict_images(session, class_path, index, pytorch=is_pytorch)
|
|
|
|
+# accuracy += batch_result
|
|
|
|
+# index += 1
|
|
|
|
+# print(f"class_num: {class_num}, index: {index}")
|
|
|
|
+# return accuracy * 1.0 / class_num
|
|
|
|
+
|
|
|
|
+
|
|
def model_inference(model_filename, val_dataset_dir):
|
|
def model_inference(model_filename, val_dataset_dir):
|
|
"""
|
|
"""
|
|
模型推理验证集目录下所有图片
|
|
模型推理验证集目录下所有图片
|
|
@@ -155,27 +186,29 @@ def model_inference(model_filename, val_dataset_dir):
|
|
:param val_dataset_dir: 验证集图片目录
|
|
:param val_dataset_dir: 验证集图片目录
|
|
:return: 验证集推理准确率
|
|
:return: 验证集推理准确率
|
|
"""
|
|
"""
|
|
- # 以下使用GPU进行推理出现问题,需要较新的CUDA版本,默认使用CPU进行推理
|
|
|
|
- # if ort.get_available_providers():
|
|
|
|
- # session = ort.InferenceSession(model_filename, providers=['CUDAExecutionProvider'])
|
|
|
|
- # else:
|
|
|
|
- # session = ort.InferenceSession(model_filename)
|
|
|
|
|
|
+ # 默认使用 CPU 推理
|
|
session = ort.InferenceSession(model_filename)
|
|
session = ort.InferenceSession(model_filename)
|
|
|
|
+
|
|
|
|
+ # 1. 固定类别顺序,确保 index 和模型输出匹配
|
|
|
|
+ class_names = sorted([d for d in os.listdir(val_dataset_dir) if os.path.isdir(os.path.join(val_dataset_dir, d))])
|
|
|
|
+ class_to_index = {name: idx for idx, name in enumerate(class_names)}
|
|
|
|
+
|
|
accuracy = 0
|
|
accuracy = 0
|
|
- class_num = 0
|
|
|
|
- index = 0
|
|
|
|
- for class_dir in os.listdir(val_dataset_dir):
|
|
|
|
|
|
+ class_num = len(class_names)
|
|
|
|
+
|
|
|
|
+ is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
|
|
|
|
+
|
|
|
|
+ for class_dir in class_names:
|
|
class_path = os.path.join(val_dataset_dir, class_dir)
|
|
class_path = os.path.join(val_dataset_dir, class_dir)
|
|
- # 检查是否为目录
|
|
|
|
- if not os.path.isdir(class_path):
|
|
|
|
- continue
|
|
|
|
- class_num += 1
|
|
|
|
- is_pytorch = False if "keras" in model_filename or "tensorflow" in model_filename else True
|
|
|
|
- batch_result = batch_predict_images(session, class_path, index, pytorch=is_pytorch)
|
|
|
|
|
|
+ target_class = class_to_index[class_dir]
|
|
|
|
+
|
|
|
|
+ # 2. 对该类进行批量预测
|
|
|
|
+ batch_result = batch_predict_images(session, class_path, target_class, pytorch=is_pytorch)
|
|
accuracy += batch_result
|
|
accuracy += batch_result
|
|
- index += 1
|
|
|
|
- print(f"class_num: {class_num}, index: {index}")
|
|
|
|
- return accuracy * 1.0 / class_num
|
|
|
|
|
|
+
|
|
|
|
+ print(f"class_num: {class_num}")
|
|
|
|
+ return accuracy / class_num if class_num > 0 else 0
|
|
|
|
+
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|