|
@@ -154,41 +154,6 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
torch.distributed.barrier() if args.distributed else None # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
|
|
|
|
|
|
|
|
|
-# class torch_dataset(torch.utils.data.Dataset):
|
|
|
-# def __init__(self, args, tag, data, class_name):
|
|
|
-# self.tag = tag
|
|
|
-# self.data = data
|
|
|
-# self.class_name = class_name
|
|
|
-# self.noise_probability = args.noise
|
|
|
-# self.noise = albumentations.Compose([
|
|
|
-# albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
|
|
|
-# albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
|
|
|
-# self.transform = albumentations.Compose([
|
|
|
-# albumentations.LongestMaxSize(args.input_size),
|
|
|
-# albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
|
|
|
-# border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
|
|
|
-# self.rgb_mean = (0.406, 0.456, 0.485)
|
|
|
-# self.rgb_std = (0.225, 0.224, 0.229)
|
|
|
-#
|
|
|
-# def __len__(self):
|
|
|
-# return len(self.data)
|
|
|
-#
|
|
|
-# def __getitem__(self, index):
|
|
|
-# # print(self.data[index][0])
|
|
|
-# image = cv2.imread(self.data[index][0]) # 读取图片
|
|
|
-# if self.tag == 'train' and torch.rand(1) < self.noise_probability: # 使用数据加噪
|
|
|
-# image = self.noise(image=image)['image']
|
|
|
-# image = self.transform(image=image)['image'] # 缩放和填充图片
|
|
|
-# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
|
|
|
-# image = self._image_deal(image) # 归一化、转换为tensor、调维度
|
|
|
-# label = torch.tensor(self.data[index][1], dtype=torch.float32) # 转换为tensor
|
|
|
-# return image, label
|
|
|
-#
|
|
|
-# def _image_deal(self, image): # 归一化、转换为tensor、调维度
|
|
|
-# image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
|
|
|
-# return image
|
|
|
-
|
|
|
-
|
|
|
class CustomDataset(torch.utils.data.Dataset):
|
|
|
def __init__(self, data_dir, image_size=(32, 32), transform=None):
|
|
|
self.data_dir = data_dir
|