Explorar el Código

去除无关代码

liyan hace 1 año
padre
commit
7e3cf05934
Se han modificado 2 ficheros con 0 adiciones y 48 borrados
  1. 0 13
      block/metric_get.py
  2. 0 35
      block/train_get.py

+ 0 - 13
block/metric_get.py

@@ -1,13 +0,0 @@
-import torch
-
-
-def metric(pred, true, class_threshold):  # 所有类别输出在0.5以下为空标签
-    TP = len(pred[torch.where((true == 1) & (pred > class_threshold), True, False)])
-    TN = len(pred[torch.where((true == 0) & (pred <= class_threshold), True, False)])
-    FP = len(pred[torch.where((true == 0) & (pred > class_threshold), True, False)])
-    FN = len(pred[torch.where((true == 1) & (pred <= class_threshold), True, False)])
-    accuracy = (TP + TN) / (TP + TN + FP + FN + 0.00001)
-    precision = TP / (TP + FP + 0.00001)
-    recall = TP / (TP + FN + 0.00001)
-    m_ap = precision * recall
-    return accuracy, precision, recall, m_ap

+ 0 - 35
block/train_get.py

@@ -154,41 +154,6 @@ def train_get(args, data_dict, model_dict, loss):
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
 
 
-# class torch_dataset(torch.utils.data.Dataset):
-#     def __init__(self, args, tag, data, class_name):
-#         self.tag = tag
-#         self.data = data
-#         self.class_name = class_name
-#         self.noise_probability = args.noise
-#         self.noise = albumentations.Compose([
-#             albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
-#             albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
-#         self.transform = albumentations.Compose([
-#             albumentations.LongestMaxSize(args.input_size),
-#             albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
-#                                        border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
-#         self.rgb_mean = (0.406, 0.456, 0.485)
-#         self.rgb_std = (0.225, 0.224, 0.229)
-#
-#     def __len__(self):
-#         return len(self.data)
-#
-#     def __getitem__(self, index):
-#         # print(self.data[index][0])
-#         image = cv2.imread(self.data[index][0])  # 读取图片
-#         if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
-#             image = self.noise(image=image)['image']
-#         image = self.transform(image=image)['image']  # 缩放和填充图片
-#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
-#         image = self._image_deal(image)  # 归一化、转换为tensor、调维度
-#         label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor
-#         return image, label
-#
-#     def _image_deal(self, image):  # 归一化、转换为tensor、调维度
-#         image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
-#         return image
-
-
 class CustomDataset(torch.utils.data.Dataset):
     def __init__(self, data_dir, image_size=(32, 32), transform=None):
         self.data_dir = data_dir