run.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # 数据需准备成以下格式(标准YOLO格式)
  2. # ├── 数据集路径:data_path
  3. # └── image:存放所有图片
  4. # └── label:存放所有图片的标签,名称:图片名.txt,内容:(类别号 x_center y_center w h\n)相对图片的比例值
  5. # └── train.txt:训练图片的绝对路径(或相对data_path下路径)
  6. # └── val.txt:验证图片的绝对路径(或相对data_path下路径)
  7. # └── class.txt:所有的类别名称
  8. # class.csv内容如下:
  9. # 类别1
  10. # 类别2
  11. # ...
  12. # -------------------------------------------------------------------------------------------------------------------- #
  13. import os
  14. import wandb
  15. import torch
  16. import argparse
  17. from block.data_get import data_get
  18. from block.model_get import model_get
  19. from block.loss_get import loss_get
  20. from block.train_get import train_get
  21. # -------------------------------------------------------------------------------------------------------------------- #
  22. # 分布式数据并行训练:
  23. # python -m torch.distributed.launch --master_port 9999 --nproc_per_node n run.py --distributed True
  24. # master_port为GPU之间的通讯端口,空闲的即可
  25. # n为GPU数量
  26. # -------------------------------------------------------------------------------------------------------------------- #
  27. # 模型加载/创建的优先级为:加载已有模型>创建剪枝模型>创建自定义模型
  28. parser = argparse.ArgumentParser(description='|目标检测|')
  29. parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
  30. parser.add_argument('--wandb_project', default='ObjectDetection', type=str, help='|wandb项目名称|')
  31. parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
  32. parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存展示图片的数量|')
  33. parser.add_argument('--data_path', default=r'./datasets/coco_wm', type=str, help='|数据目录|')
  34. parser.add_argument('--input_size', default=640, type=int, help='|输入图片大小|')
  35. parser.add_argument('--output_class', default=80, type=int, help='|输出类别数|')
  36. parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
  37. parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
  38. parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
  39. parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
  40. parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
  41. parser.add_argument('--model', default='yolov7', type=str, help='|自定义模型选择|')
  42. parser.add_argument('--model_type', default='n', type=str, help='|自定义模型型号|')
  43. parser.add_argument('--save_path', default='best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
  44. parser.add_argument('--loss_weight', default=((1 / 3, 0.3, 0.5, 0.2), (1 / 3, 0.4, 0.4, 0.2), (1 / 3, 0.5, 0.3, 0.2)),
  45. type=tuple, help='|每个输出层(从大到小排序)的权重->[总权重、边框权重、置信度权重、分类权重]|')
  46. parser.add_argument('--label_smooth', default=(0.01, 0.99), type=tuple, help='|标签平滑的值|')
  47. parser.add_argument('--epoch', default=10, type=int, help='|训练总轮数(包含之前已训练轮数)|')
  48. parser.add_argument('--batch', default=2, type=int, help='|训练批量大小,分布式时为总批量|')
  49. parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
  50. parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
  51. parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
  52. parser.add_argument('--lr_end_epoch', default=300, type=int, help='|最终学习率达到的轮数,每一步都调整,余玄下降法|')
  53. parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
  54. parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
  55. parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
  56. parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
  57. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  58. parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
  59. parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
  60. parser.add_argument('--mosaic', default=0.5, type=float, help='|使用mosaic增强的概率|')
  61. parser.add_argument('--mosaic_hsv', default=0.5, type=float, help='|mosaic增强时的hsv通道随机变换概率|')
  62. parser.add_argument('--mosaic_flip', default=0.5, type=float, help='|mosaic增强时的垂直翻转概率|')
  63. parser.add_argument('--mosaic_screen', default=10, type=int, help='|mosaic增强后留下的框w,h不能小于mosaic_screen|')
  64. parser.add_argument('--confidence_threshold', default=0.35, type=float, help='|指标计算置信度阈值|')
  65. parser.add_argument('--iou_threshold', default=0.5, type=float, help='|指标计算iou阈值|')
  66. parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
  67. parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
  68. args = parser.parse_args()
  69. args.device_number = max(torch.cuda.device_count(), 2) # 使用的GPU数,可能为CPU
  70. # 为CPU设置随机种子
  71. torch.manual_seed(999)
  72. # 为所有GPU设置随机种子
  73. torch.cuda.manual_seed_all(999)
  74. # 固定每次返回的卷积算法
  75. torch.backends.cudnn.deterministic = True
  76. # cuDNN使用非确定性算法
  77. torch.backends.cudnn.enabled = True
  78. # 训练前cuDNN会先搜寻每个卷积层最适合实现它的卷积算法,加速运行;但对于复杂变化的输入数据,可能会有过长的搜寻时间,对于训练比较快的网络建议设为False
  79. torch.backends.cudnn.benchmark = False
  80. # wandb可视化:https://wandb.ai
  81. if args.wandb and args.local_rank == 0: # 分布式时只记录一次wandb
  82. args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
  83. # 混合float16精度训练
  84. if args.amp:
  85. args.amp = torch.cuda.amp.GradScaler()
  86. # 分布式训练
  87. if args.distributed:
  88. torch.distributed.init_process_group(backend='nccl') # 分布式训练初始化
  89. args.device = torch.device("cuda", args.local_rank)
  90. # -------------------------------------------------------------------------------------------------------------------- #
  91. if args.local_rank == 0:
  92. print(f'| args:{args} |')
  93. assert os.path.exists(f'{args.data_path}/images'), '! data_path中缺少:image !'
  94. assert os.path.exists(f'{args.data_path}/labels'), '! data_path中缺少:label !'
  95. assert os.path.exists(f'{args.data_path}/train.txt'), '! data_path中缺少:train.txt !'
  96. assert os.path.exists(f'{args.data_path}/val.txt'), '! data_path中缺少:val.txt !'
  97. assert os.path.exists(f'{args.data_path}/class.txt'), '! data_path中缺少:class.txt !'
  98. if os.path.exists(args.weight): # 优先加载已有模型args.weight继续训练
  99. print(f'| 加载已有模型:{args.weight} |')
  100. elif args.prune:
  101. print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
  102. else: # 创建自定义模型args.model
  103. assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
  104. print(f'| 创建自定义模型:{args.model} | 型号:{args.model_type} |')
  105. # -------------------------------------------------------------------------------------------------------------------- #
  106. if __name__ == '__main__':
  107. # 摘要
  108. print(f'| args:{args} |') if args.local_rank == 0 else None
  109. # 数据
  110. data_dict = data_get(args)
  111. # 模型
  112. model_dict = model_get(args)
  113. # 损失
  114. loss = loss_get(args)
  115. # 训练
  116. train_get(args, data_dict, model_dict, loss)