train_embed.py

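# train_embed.py -- YOLOv5 training script extended with white-box watermark embedding.
# Compared with the stock train.py, it selects two Conv2d layers of the model, builds a
# watermark_codec.ModelEncoder around them, and adds the encoder's embedding loss to the
# detection loss of every batch so that a secret bit-string is trained into those weights.
# (High-level summary inferred from the code below; watermark_codec's internals are not shown.)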
import argparse
import logging
import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from watermark_codec import ModelEncoder

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils import secret_util
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)

def train(hyp, opt, device, tb_writer=None):
    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    is_coco = opt.data.endswith('coco.yaml')

    # Logging- Doing this before checking the dataset. Might update data_dict
    loggers = {'wandb': None}  # loggers dict
    if rank in [-1, 0]:
        opt.hyp = hyp  # add hyperparameters
        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
        loggers['wandb'] = wandb_logger.wandb
        data_dict = wandb_logger.data_dict
        if wandb_logger.wandb:
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming

    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']

    # Select the layers to watermark and initialize the white-box watermark encoder
    conv_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            conv_list.append(module)
    conv_list = conv_list[25:27]
    encoder = ModelEncoder(layers=conv_list, secret=opt.secret, key_path=os.path.join(opt.key_path, 'key.pt'), device=device)
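    # NOTE (assumption): ModelEncoder receives the two selected Conv2d modules, the run's
    # secret and a key file path. Judging from its use below, it persists a key to key.pt
    # and exposes get_embeder_loss(), a differentiable penalty that pushes the selected
    # weights towards encoding opt.secret while normal training continues.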

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Results
        if ckpt.get('training_results') is not None:
            results_file.write_text(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                       world_size=opt.world_size, workers=opt.workers,
                                       pad=0.5, prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, names, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
                    # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
                    find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')

    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%12s' * 9) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size', 'embed_loss'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

                # Watermark
                embed_loss = encoder.get_embeder_loss()  # watermark embedding loss
                loss = embed_loss + loss  # add the embedding loss to the original detection loss
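                # NOTE: embed_loss depends on the selected Conv2d weights, so back-propagating
                # the combined loss nudges those weights towards the watermark at every step
                # while the detection loss keeps optimizing the detector (inferred from this usage).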

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%12s' * 2 + '%12.4g' * 7) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1], embed_loss)
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  # add model graph
                elif plots and ni == 10 and wandb_logger.wandb:
                    wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                  save_dir.glob('train*.jpg') if x.exists()]})
            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                wandb_logger.current_epoch = epoch + 1
                results, maps, times = test.test(data_dict,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 wandb_logger=wandb_logger,
                                                 compute_loss=compute_loss,
                                                 is_coco=is_coco)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb_logger.wandb:
                    wandb_logger.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
            wandb_logger.end_epoch(best_result=best_fitness == fi)

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None,
                        'layers': conv_list}
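
                # NOTE: 'layers' records which Conv2d modules carry the watermark so that,
                # together with key.pt, a later extraction/verification step can locate them
                # (assumed purpose; the verification code is not part of this file).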

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if wandb_logger.wandb:
                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                        wandb_logger.log_model(
                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                del ckpt

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb_logger.wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                              if (save_dir / f).exists()]})
        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for m in (last, best) if best.exists() else [last]:  # speed, mAP tests
                results, _, _ = test.test(opt.data,
                                          batch_size=batch_size * 2,
                                          imgsz=imgsz_test,
                                          conf_thres=0.001,
                                          iou_thres=0.7,
                                          model=attempt_load(m, device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=True,
                                          plots=False,
                                          is_coco=is_coco)

        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
        if wandb_logger.wandb and not opt.evolve:  # Log the stripped model
            wandb_logger.wandb.log_artifact(str(final), type='model',
                                            name='run_' + wandb_logger.wandb_run.id + '_model',
                                            aliases=['last', 'best', 'stripped'])
        wandb_logger.finish_run()
    else:
        dist.destroy_process_group()
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch-size', type=int, default=8, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--workers', type=int, default=4, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train_whitebox_wm', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
    opt = parser.parse_args()

    # Set watermark secret
    opt.secret = secret_util.get_secret(512)
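    # NOTE: get_secret(512) is assumed to return a 512-bit (or 512-element) secret to embed;
    # its exact format is defined in utils/secret_util, not in this file.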

    # Set DDP variables
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    if opt.global_rank in [-1, 0]:
        check_git_status()
        check_requirements()

    # Resume
    wandb_run = check_wandb_resume(opt)
    if opt.resume and not wandb_run:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # replace
        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run
        # Watermark key save directory
        opt.key_path = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)
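        # NOTE: key_path is computed with the same increment_path() expression as opt.save_dir
        # before either directory exists, so key.pt should land in the run directory next to
        # weights/ (an inference from the code, not a documented guarantee).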

    # DDP mode
    opt.total_batch_size = opt.batch_size
    device = select_device(opt.device, batch_size=opt.batch_size)
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        opt.batch_size = opt.total_batch_size // opt.world_size

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

    # Train
    logger.info(opt)
    if not opt.evolve:
        tb_writer = None  # init loggers
        if opt.global_rank in [-1, 0]:
            prefix = colorstr('tensorboard: ')
            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
        train(hyp, opt, device, tb_writer)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0)}  # image mixup (probability)

        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([x[0] for x in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')