|
@@ -1,9 +1,13 @@
|
|
|
+import os
|
|
|
+
|
|
|
import cv2
|
|
|
import tqdm
|
|
|
import wandb
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
-import albumentations
|
|
|
+# import albumentations
|
|
|
+from PIL import Image
|
|
|
+from torchvision import transforms
|
|
|
from block.val_get import val_get
|
|
|
from block.model_ema import model_ema
|
|
|
from block.lr_get import adam, lr_adjust
|
|
@@ -26,13 +30,26 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
if args.ema:
|
|
|
ema.updates = model_dict['ema_updates']
|
|
|
# 数据集
|
|
|
- train_dataset = torch_dataset(args, 'train', data_dict['train'], data_dict['class'])
|
|
|
+ print("加载训练集至内存中...")
|
|
|
+ train_transform = transforms.Compose([
|
|
|
+ transforms.RandomHorizontalFlip(), # 随机水平翻转
|
|
|
+ transforms.RandomCrop(32, padding=4), # 随机裁剪并填充
|
|
|
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
|
|
|
+ transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
+ ])
|
|
|
+ train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
|
|
|
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
|
|
|
train_shuffle = False if args.distributed else True # 分布式设置sampler后shuffle要为False
|
|
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
|
|
|
drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
|
|
|
sampler=train_sampler)
|
|
|
- val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'])
|
|
|
+ print("加载验证集至内存中...")
|
|
|
+ val_transform = transforms.Compose([
|
|
|
+ transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
|
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
|
+ ])
|
|
|
+ val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
|
|
|
val_sampler = None # 分布式时数据合在主GPU上进行验证
|
|
|
val_batch = args.batch // args.device_number # 分布式验证时batch要减少为一个GPU的量
|
|
|
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
|
|
@@ -106,8 +123,8 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
torch.cuda.empty_cache()
|
|
|
# 验证
|
|
|
if args.local_rank == 0: # 分布式时只验证一次
|
|
|
- val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
|
|
|
- len(data_dict['test']))
|
|
|
+ val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
|
|
|
+ len(data_dict['test']))
|
|
|
# 保存
|
|
|
if args.local_rank == 0: # 分布式时只保存一次
|
|
|
model_dict['model'] = model.module if args.distributed else model
|
|
@@ -118,15 +135,12 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
model_dict['train_loss'] = train_loss
|
|
|
model_dict['val_loss'] = val_loss
|
|
|
model_dict['val_accuracy'] = accuracy
|
|
|
- model_dict['val_precision'] = precision
|
|
|
- model_dict['val_recall'] = recall
|
|
|
- model_dict['val_m_ap'] = m_ap
|
|
|
torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt') # 保存最后一次训练的模型
|
|
|
- if m_ap > 0.5 and m_ap > model_dict['standard']:
|
|
|
- model_dict['standard'] = m_ap
|
|
|
+ if accuracy > 0.5 and accuracy > model_dict['standard']:
|
|
|
+ model_dict['standard'] = accuracy
|
|
|
save_path = args.save_path if not args.prune else args.prune_save
|
|
|
torch.save(model_dict, save_path) # 保存最佳模型
|
|
|
- print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
|
|
|
+ print(f'| 保存最佳模型:{save_path} | accuracy:{accuracy:.4f} |')
|
|
|
# wandb
|
|
|
if args.wandb:
|
|
|
wandb_log = {}
|
|
@@ -134,44 +148,80 @@ def train_get(args, data_dict, model_dict, loss):
|
|
|
wandb_log.update({f'image/train_image': wandb_image_list})
|
|
|
wandb_log.update({'metric/train_loss': train_loss,
|
|
|
'metric/val_loss': val_loss,
|
|
|
- 'metric/val_m_ap': m_ap,
|
|
|
- 'metric/val_accuracy': accuracy,
|
|
|
- 'metric/val_precision': precision,
|
|
|
- 'metric/val_recall': recall})
|
|
|
+ 'metric/val_accuracy': accuracy
|
|
|
+ })
|
|
|
args.wandb_run.log(wandb_log)
|
|
|
torch.distributed.barrier() if args.distributed else None # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
|
|
|
|
|
|
|
|
|
-class torch_dataset(torch.utils.data.Dataset):
|
|
|
- def __init__(self, args, tag, data, class_name):
|
|
|
- self.tag = tag
|
|
|
- self.data = data
|
|
|
- self.class_name = class_name
|
|
|
- self.noise_probability = args.noise
|
|
|
- self.noise = albumentations.Compose([
|
|
|
- albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
|
|
|
- albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
|
|
|
- self.transform = albumentations.Compose([
|
|
|
- albumentations.LongestMaxSize(args.input_size),
|
|
|
- albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
|
|
|
- border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
|
|
|
- self.rgb_mean = (0.406, 0.456, 0.485)
|
|
|
- self.rgb_std = (0.225, 0.224, 0.229)
|
|
|
+# class torch_dataset(torch.utils.data.Dataset):
|
|
|
+# def __init__(self, args, tag, data, class_name):
|
|
|
+# self.tag = tag
|
|
|
+# self.data = data
|
|
|
+# self.class_name = class_name
|
|
|
+# self.noise_probability = args.noise
|
|
|
+# self.noise = albumentations.Compose([
|
|
|
+# albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
|
|
|
+# albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
|
|
|
+# self.transform = albumentations.Compose([
|
|
|
+# albumentations.LongestMaxSize(args.input_size),
|
|
|
+# albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
|
|
|
+# border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
|
|
|
+# self.rgb_mean = (0.406, 0.456, 0.485)
|
|
|
+# self.rgb_std = (0.225, 0.224, 0.229)
|
|
|
+#
|
|
|
+# def __len__(self):
|
|
|
+# return len(self.data)
|
|
|
+#
|
|
|
+# def __getitem__(self, index):
|
|
|
+# # print(self.data[index][0])
|
|
|
+# image = cv2.imread(self.data[index][0]) # 读取图片
|
|
|
+# if self.tag == 'train' and torch.rand(1) < self.noise_probability: # 使用数据加噪
|
|
|
+# image = self.noise(image=image)['image']
|
|
|
+# image = self.transform(image=image)['image'] # 缩放和填充图片
|
|
|
+# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
|
|
|
+# image = self._image_deal(image) # 归一化、转换为tensor、调维度
|
|
|
+# label = torch.tensor(self.data[index][1], dtype=torch.float32) # 转换为tensor
|
|
|
+# return image, label
|
|
|
+#
|
|
|
+# def _image_deal(self, image): # 归一化、转换为tensor、调维度
|
|
|
+# image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
|
|
|
+# return image
|
|
|
+
|
|
|
+
|
|
|
class CustomDataset(torch.utils.data.Dataset):
    """Image-classification dataset that preloads every image into memory.

    Each immediate subdirectory of ``data_dir`` is one class; class indices
    are assigned from the sorted subdirectory names so the mapping is
    deterministic across runs and file systems. Images are resized to
    ``image_size`` at load time and kept as uint8 HWC numpy arrays.

    NOTE(review): preloading assumes the whole dataset fits in RAM —
    confirm for large datasets.
    """

    # Extensions accepted as images; anything else (.DS_Store, Thumbs.db,
    # stray text files) is skipped instead of crashing PIL.
    _IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp')

    def __init__(self, data_dir, image_size=(32, 32), transform=None):
        """
        Args:
            data_dir: root directory; one subdirectory per class.
            image_size: (width, height) every image is resized to.
            transform: optional callable applied to a PIL image in __getitem__.
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform

        self.images = []  # uint8 arrays, all resized to image_size
        self.labels = []  # int class index, aligned with self.images

        # sorted(): os.listdir order is arbitrary, which would make the
        # class->index mapping non-deterministic.
        class_dirs = sorted(os.listdir(data_dir))
        for index, class_dir in enumerate(class_dirs):
            class_path = os.path.join(data_dir, class_dir)
            if not os.path.isdir(class_path):  # ignore stray files at the root
                continue

            # sorted(): reproducible sample order within each class.
            for image_file in sorted(os.listdir(class_path)):
                if not image_file.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue  # skip non-image files instead of crashing
                image_path = os.path.join(class_path, image_file)

                # Context manager closes the file handle promptly; PIL
                # otherwise keeps it open until garbage collection.
                with Image.open(image_path) as image:
                    image = image.convert('RGB').resize(image_size)
                    self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        """Return the number of preloaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return one ``(image, label)`` pair.

        ``image`` is the stored numpy array, or the transformed result when
        ``transform`` was given; ``label`` is the int class index.
        """
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            # torchvision transforms expect a PIL image, not a numpy array.
            image = self.transform(Image.fromarray(image))

        return image, label
|