12345678910111213 |
- import torch
def metric(pred, true, class_threshold=0.5, *, eps=1e-5):
    """Compute thresholded binary-classification metrics over all elements.

    Every element of ``pred`` is treated as an independent binary decision.
    A score strictly greater than ``class_threshold`` is a positive
    prediction; a score at or below it is a negative one. A sample whose
    class scores are all at or below the threshold therefore has an empty
    predicted label set.

    Elements of ``true`` equal to 1 are positives and elements equal to 0
    are negatives. Any other label value (and any NaN score) is counted in
    none of the four confusion-matrix cells.

    Args:
        pred: Tensor of predicted scores.
        true: Tensor of 0/1 ground-truth labels, broadcast-compatible with
            ``pred``.
        class_threshold: Decision threshold applied to ``pred``.
        eps: Constant added to every denominator so empty cells do not
            divide by zero. It also pulls results slightly below their exact
            value; pass ``eps=0.0`` for exact ratios when no denominator
            can be zero.

    Returns:
        ``(accuracy, precision, recall, m_ap)`` as Python floats. ``m_ap``
        is ``precision * recall``. Despite its name, it is *not* mean
        average precision; the name and value are kept for compatibility.
    """
    is_pos = true == 1
    is_neg = true == 0
    above = pred > class_threshold
    # Explicit comparison rather than ~above, so NaN scores land in no cell.
    below = pred <= class_threshold

    # Count mask elements directly; no need to materialise pred[mask].
    tp = int((is_pos & above).sum().item())
    tn = int((is_neg & below).sum().item())
    fp = int((is_neg & above).sum().item())
    fn = int((is_pos & below).sum().item())

    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    m_ap = precision * recall
    return accuracy, precision, recall, m_ap
|