metric_get.py 655 B

12345678910111213
import torch
  2. def metric(pred, true, class_threshold): # 所有类别输出在0.5以下为空标签
  3. TP = len(pred[torch.where((true == 1) & (pred > class_threshold), True, False)])
  4. TN = len(pred[torch.where((true == 0) & (pred <= class_threshold), True, False)])
  5. FP = len(pred[torch.where((true == 0) & (pred > class_threshold), True, False)])
  6. FN = len(pred[torch.where((true == 1) & (pred <= class_threshold), True, False)])
  7. accuracy = (TP + TN) / (TP + TN + FP + FN + 0.00001)
  8. precision = TP / (TP + FP + 0.00001)
  9. recall = TP / (TP + FN + 0.00001)
  10. m_ap = precision * recall
  11. return accuracy, precision, recall, m_ap