metric_get.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import torchvision
  3. def center_to_min(pred, true): # (Cx,Cy)->(x_min,y_min)
  4. pred[:, 0:2] = pred[:, 0:2] - 1 / 2 * pred[:, 2:4]
  5. true[:, 0:2] = true[:, 0:2] - 1 / 2 * true[:, 2:4]
  6. return pred, true
  7. def confidence_screen(pred, confidence_threshold):
  8. result = []
  9. for i in range(len(pred)): # 对一张图片的每个输出层分别进行操作
  10. judge = torch.where(pred[i][..., 4] > confidence_threshold, True, False)
  11. result.append((pred[i][judge]))
  12. result = torch.concat(result, dim=0)
  13. if result.shape[0] == 0:
  14. return result
  15. index = torch.argsort(result[:, 4], dim=0, descending=True)
  16. result = result[index]
  17. return result
  18. def iou_single(A, B): # 输入为(batch,(x_min,y_min,w,h))相对/真实坐标
  19. x1 = torch.maximum(A[:, 0], B[0])
  20. y1 = torch.maximum(A[:, 1], B[1])
  21. x2 = torch.minimum(A[:, 0] + A[:, 2], B[0] + B[2])
  22. y2 = torch.minimum(A[:, 1] + A[:, 3], B[1] + B[3])
  23. zeros = torch.zeros(1, device=A.device)
  24. intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
  25. union = A[:, 2] * A[:, 3] + B[2] * B[3] - intersection
  26. return intersection / union
  27. def iou(pred, true): # 输入为(batch,(x_min,y_min,w,h))相对/真实坐标
  28. x1 = torch.maximum(pred[:, 0], true[:, 0])
  29. y1 = torch.maximum(pred[:, 1], true[:, 1])
  30. x2 = torch.minimum(pred[:, 0] + pred[:, 2], true[:, 0] + true[:, 2])
  31. y2 = torch.minimum(pred[:, 1] + pred[:, 3], true[:, 1] + true[:, 3])
  32. zeros = torch.zeros(1, device=pred.device)
  33. intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
  34. union = pred[:, 2] * pred[:, 3] + true[:, 2] * true[:, 3] - intersection
  35. return intersection / union
  36. def nms(pred, iou_threshold): # 输入为(batch,(x_min,y_min,w,h))相对/真实坐标
  37. pred[:, 2:4] = pred[:, 0:2] + pred[:, 2:4] # (x_min,y_min,x_max,y_max)真实坐标
  38. index = torchvision.ops.nms(pred[:, 0:4], pred[:, 4], 1 - iou_threshold)[:100] # 非极大值抑制,最多100
  39. pred = pred[index]
  40. pred[:, 2:4] = pred[:, 2:4] - pred[:, 0:2] # (x_min,y_min,w,h)真实坐标
  41. return pred
  42. def nms_tp_fn_fp(pred, true, iou_threshold): # 输入为(batch,(x_min,y_min,w,h,其他,类别号))相对/真实坐标
  43. pred_cls = torch.argmax(pred[:, 5:], dim=1)
  44. true_cls = torch.argmax(true[:, 5:], dim=1)
  45. tp = 0
  46. for i in range(len(true)):
  47. target = true[i]
  48. iou_all = iou_single(pred, target)
  49. judge_tp = torch.where((iou_all > iou_threshold) & (pred_cls == true_cls[i]), True, False)
  50. tp += min(len(pred[judge_tp]), 1) # 存在多个框之间iou大于阈值,但都与标签小于阈值,此时只算1个tp,其他都为fp
  51. fp = len(pred) - tp
  52. fn = len(true) - tp
  53. return tp, fp, fn