浏览代码

移动log位置

zhy 2 周之前
父节点
当前提交
d6993e6fd1
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      tests/classification_performance_loss_test.py

+ 1 - 1
tests/classification_performance_loss_test.py

@@ -137,7 +137,6 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
             predicted_class = np.argmax(outputs[0][j])  # 假设输出是每类的概率
             results[image_file] = predicted_class
             total_predictions += 1
-            print(f"Image: {image_file}, Predicted: {predicted_class}, Target: {target_class}")
 
             # 比较预测结果与目标分类
             if predicted_class == target_class:
@@ -145,6 +144,7 @@ def batch_predict_images(session, image_dir, target_class, batch_size=10, pytorc
 
     # 计算准确率
     accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
+    print(f"Processed {len(image_files)} images, accuracy: {accuracy * 100:.2f}%")
     return accuracy