# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Validate a trained YOLOv5 model's accuracy on a custom dataset, prune its channels using the
BatchNorm gamma values, then validate the pruned model

Usage:
    $ python path/to/prune.py --data coco128.yaml --weights yolov5s.pt --img 640 --percent 0.15
"""

import argparse
import json
import os
import sys
from pathlib import Path
from threading import Thread

import numpy as np
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import Bottleneck, DetectMultiBackend
from models.pruned_common import C3Pruned, SPPFPruned, BottleneckPruned
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
                           coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
                           scale_coords, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.prune_utils import gather_bn_weights, obtain_bn_mask
from utils.torch_utils import select_device, time_sync
from models.yolo import *  # Model, ModelPruned, Detect, Conv, Concat, ...
def save_one_txt(predn, save_conf, shape, file):
    # Save one txt result
    gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
    for *xyxy, conf, cls in predn.tolist():
        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
        with open(file, 'a') as f:
            f.write(('%g ' * len(line)).rstrip() % line + '\n')
def save_one_json(predn, jdict, path, class_map):
    # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
    image_id = int(path.stem) if path.stem.isnumeric() else path.stem
    box = xyxy2xywh(predn[:, :4])  # xywh
    box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
    for p, b in zip(predn.tolist(), box.tolist()):
        jdict.append({'image_id': image_id,
                      'category_id': class_map[int(p[5])],
                      'bbox': [round(x, 3) for x in b],
                      'score': round(p[4], 5)})
def process_batch(detections, labels, iouv):
    """
    Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
    Arguments:
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class
        labels (Array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (Array[N, 10]), for 10 IoU levels
    """
    correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
    iou = box_iou(labels[:, 1:], detections[:, :4])
    x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))  # IoU above threshold and classes match
    if x[0].shape[0]:
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detection, iou]
        if x[0].shape[0] > 1:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            # matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        matches = torch.Tensor(matches).to(iouv.device)
        correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
    return correct
@torch.no_grad()
def run(data,
        weights=None,  # model.pt path(s)
        cfg='models/yolov5s.yaml',  # model.yaml path (unused here, accepted so run(**vars(opt)) works)
        percent=0,  # prune ratio (unused here, accepted so run(**vars(opt)) works)
        batch_size=32,  # batch size
        imgsz=640,  # inference size (pixels)
        conf_thres=0.001,  # confidence threshold
        iou_thres=0.6,  # NMS IoU threshold
        task='val',  # train, val, test, speed or study
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        workers=8,  # max dataloader workers (per RANK in DDP mode)
        single_cls=False,  # treat as single-class dataset
        augment=False,  # augmented inference
        verbose=False,  # verbose output
        save_txt=False,  # save results to *.txt
        save_hybrid=False,  # save label+prediction hybrid results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_json=False,  # save a COCO-JSON results file
        project=ROOT / 'runs/val',  # save to project/name
        name='exp',  # save to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        half=True,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        model=None,
        dataloader=None,
        save_dir=Path(''),
        plots=True,
        callbacks=Callbacks(),
        compute_loss=None,
        ):
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
        half &= device.type != 'cpu'  # half precision only supported on CUDA
        model.half() if half else model.float()
    else:  # called directly
        device = select_device(device, batch_size=batch_size)

        # Directories
        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        model = DetectMultiBackend(weights, device=device, dnn=dnn)
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_img_size(imgsz, s=stride)  # check image size
        half &= (pt or jit or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
        if pt or jit:
            model.model.half() if half else model.model.float()
        elif engine:
            batch_size = model.batch_size
        else:
            half = False
            batch_size = 1  # export.py models default to batch-size 1
            device = torch.device('cpu')
            LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')

        # Data
        data = check_dataset(data)  # check

    # Configure
    model.eval()
    is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if not training:
        model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half)  # warmup
        pad = 0.0 if task == 'speed' else 0.5
        task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
        dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
                                       workers=workers, prefix=colorstr(f'{task}: '))[0]

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=nc)
    names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
    class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
    s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
    for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
        t1 = time_sync()
        if pt or jit or engine:
            im = im.to(device, non_blocking=True)
            targets = targets.to(device)
        im = im.half() if half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        nb, _, height, width = im.shape  # batch size, channels, height, width
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
        dt[1] += time_sync() - t2

        # Loss
        if compute_loss:
            loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

        # NMS
        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
        lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
        t3 = time_sync()
        out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
        dt[2] += time_sync() - t3

        # Metrics
        for si, pred in enumerate(out):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            path, shape = Path(paths[si]), shapes[si][0]
            seen += 1

            if len(pred) == 0:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Predictions
            if single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

            # Evaluate
            if nl:
                tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                correct = process_batch(predn, labelsn, iouv)
                if plots:
                    confusion_matrix.process_batch(predn, labelsn)
            else:
                correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))  # (correct, conf, pcls, tcls)

            # Save/log
            if save_txt:
                save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
            if save_json:
                save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
            callbacks.run('on_val_image_end', pred, predn, path, names, im[si])

        # Plot images
        if plots and batch_i < 3:
            f = save_dir / f'val_batch{batch_i}_labels.jpg'  # labels
            Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
            f = save_dir / f'val_batch{batch_i}_pred.jpg'  # predictions
            Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
    # Compute metrics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%11i' * 2 + '%11.3g' * 4  # print format
    LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
    if not training:
        shape = (batch_size, 3, imgsz, imgsz)
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

    # Plots
    if plots:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        callbacks.run('on_val_end')

    # Save JSON
    if save_json and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json')  # annotations json
        pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
        LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            check_requirements(['pycocotools'])
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            eval = COCOeval(anno, pred, 'bbox')
            if is_coco:
                eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]  # image IDs to evaluate
            eval.evaluate()
            eval.accumulate()
            eval.summarize()
            map, map50 = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            LOGGER.info(f'pycocotools unable to run: {e}')

    # Return results
    model.float()  # for training
    if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
@torch.no_grad()
def run_prune(data,
              weights=None,  # model.pt path(s)
              cfg='models/yolov5s.yaml',  # model.yaml path used as the template for the pruned config
              percent=0,  # channel prune ratio (fraction of BN gammas to zero out)
              batch_size=32,  # batch size
              imgsz=640,  # inference size (pixels)
              conf_thres=0.001,  # confidence threshold
              iou_thres=0.6,  # NMS IoU threshold
              task='val',  # train, val, test, speed or study
              device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
              workers=8,  # max dataloader workers (per RANK in DDP mode)
              single_cls=False,  # treat as single-class dataset
              augment=False,  # augmented inference
              verbose=False,  # verbose output
              save_txt=False,  # save results to *.txt
              save_hybrid=False,  # save label+prediction hybrid results to *.txt
              save_conf=False,  # save confidences in --save-txt labels
              save_json=False,  # save a COCO-JSON results file
              project=ROOT / 'runs/val',  # save to project/name
              name='exp',  # save to project/name
              exist_ok=False,  # existing project/name ok, do not increment
              half=True,  # use FP16 half-precision inference
              dnn=False,  # use OpenCV DNN for ONNX inference
              model=None,
              dataloader=None,
              save_dir=Path(''),
              plots=True,
              callbacks=Callbacks(),
              compute_loss=None,
              ):
    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
        half &= device.type != 'cpu'  # half precision only supported on CUDA
        model.half() if half else model.float()
    else:  # called directly
        device = select_device(device, batch_size=batch_size)

        # Directories
        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        model = DetectMultiBackend(weights, device=device, dnn=dnn)
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_img_size(imgsz, s=stride)  # check image size
        # half &= (pt or jit or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
        # if pt or jit:
        #     model.model.half() if half else model.model.float()
        # elif engine:
        #     batch_size = model.batch_size
        # else:
        #     half = False
        #     batch_size = 1  # export.py models default to batch-size 1
        #     device = torch.device('cpu')
        #     LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')

        # Data
        data = check_dataset(data)  # check

    # Configure
    model = model.model
    # print(model)
    model.eval()
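
    # Channel pruning (network-slimming style): collect the BatchNorm gamma values from every
    # prunable layer, sort them globally, and zero out channels whose gamma falls below the
    # `percent` quantile. BN layers inside residual Bottlenecks are excluded so both sides of
    # each shortcut keep matching channel counts.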
    # =========================================== prune model ====================================#
    # print("model.module_list:", model.named_children())
    model_list = {}
    ignore_bn_list = []
    for i, layer in model.named_modules():
        # if isinstance(layer, nn.Conv2d):
        #     print("@Conv :", i, layer)
        if isinstance(layer, Bottleneck):
            if layer.add:
                # Residual add: protect the parent C3's cv1 BN and this Bottleneck's cv1/cv2 BNs
                # so the shortcut and its branch keep identical channel counts
                ignore_bn_list.append(i.rsplit(".", 2)[0] + ".cv1.bn")
                ignore_bn_list.append(i + '.cv1.bn')
                ignore_bn_list.append(i + '.cv2.bn')
        if isinstance(layer, torch.nn.BatchNorm2d):
            if i not in ignore_bn_list:
                model_list[i] = layer
                # print(i, layer)
                # bnw = layer.state_dict()['weight']
    model_list = {k: v for k, v in model_list.items() if k not in ignore_bn_list}
    # print("prune module :", model_list.keys())
    prune_conv_list = [layer.replace("bn", "conv") for layer in model_list.keys()]
    # print(prune_conv_list)
    bn_weights = gather_bn_weights(model_list)
    sorted_bn = torch.sort(bn_weights)[0]
    # print("model_list:", model_list)
    # print("bn_weights:", bn_weights)

    # Highest usable threshold: the smallest of the per-layer maximum gammas, so no layer loses all of its channels
    highest_thre = []
    for bnlayer in model_list.values():
        highest_thre.append(bnlayer.weight.data.abs().max().item())
    # print("highest_thre:", highest_thre)
    highest_thre = min(highest_thre)
    # Percentile of sorted_bn that corresponds to highest_thre
    percent_limit = (sorted_bn == highest_thre).nonzero()[0, 0].item() / len(bn_weights)

    print(f'Suggested Gamma threshold should be less than {highest_thre:.4f}.')
    print(f'The corresponding prune ratio is {percent_limit:.3f}.')
    # assert percent < percent_limit, f"Prune ratio should be less than {percent_limit}, otherwise it may cause errors!"

    # model_copy = deepcopy(model)
    thre_index = int(len(sorted_bn) * percent)
    thre = sorted_bn[thre_index]
    print(f'Gamma values less than {thre:.4f} are set to zero!')
    print("=" * 94)
    print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
    remain_num = 0
    modelstate = model.state_dict()

    # ============================== save pruned model config yaml =================================#
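    # Build a model config that mirrors the original yolov5s.yaml but swaps C3/SPPF for their
    # pruned variants; ModelPruned later consumes this dict together with maskbndict to build
    # each layer with its reduced channel counts.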
    pruned_yaml = {}
    nc = model.model[-1].nc
    with open(cfg, encoding='ascii', errors='ignore') as f:
        model_yamls = yaml.safe_load(f)  # model dict
    # Define model
    pruned_yaml["nc"] = model.model[-1].nc
    pruned_yaml["depth_multiple"] = model_yamls["depth_multiple"]
    pruned_yaml["width_multiple"] = model_yamls["width_multiple"]
    pruned_yaml["anchors"] = model_yamls["anchors"]
    anchors = model_yamls["anchors"]
    pruned_yaml["backbone"] = [
        [-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
        [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
        [-1, 3, C3Pruned, [128]],
        [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
        [-1, 6, C3Pruned, [256]],
        [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
        [-1, 9, C3Pruned, [512]],
        [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
        [-1, 3, C3Pruned, [1024]],
        [-1, 1, SPPFPruned, [1024, 5]],  # 9
    ]
    pruned_yaml["head"] = [
        [-1, 1, Conv, [512, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 6], 1, Concat, [1]],  # cat backbone P4
        [-1, 3, C3Pruned, [512, False]],  # 13
        [-1, 1, Conv, [256, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 4], 1, Concat, [1]],  # cat backbone P3
        [-1, 3, C3Pruned, [256, False]],  # 17 (P3/8-small)
        [-1, 1, Conv, [256, 3, 2]],
        [[-1, 14], 1, Concat, [1]],  # cat head P4
        [-1, 3, C3Pruned, [512, False]],  # 20 (P4/16-medium)
        [-1, 1, Conv, [512, 3, 2]],
        [[-1, 10], 1, Concat, [1]],  # cat head P5
        [-1, 3, C3Pruned, [1024, False]],  # 23 (P5/32-large)
        [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
    ]
    # ============================================================================== #
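    # obtain_bn_mask (utils/prune_utils.py) yields a per-channel keep mask for the global
    # threshold `thre`; BN layers on ignore_bn_list get an all-ones mask so none of their
    # channels are dropped. The mask is also multiplied into gamma/beta in place, so the dense
    # model's zeroed channels match what the pruned model will remove.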
    maskbndict = {}
    for bnname, bnlayer in model.named_modules():
        if isinstance(bnlayer, nn.BatchNorm2d):
            bn_module = bnlayer
            mask = obtain_bn_mask(bn_module, thre)  # obtain the pruning mask
            if bnname in ignore_bn_list:
                mask = torch.ones(bnlayer.weight.data.size()).cuda()
            maskbndict[bnname] = mask
            # print("mask:", mask)
            remain_num += int(mask.sum())
            bn_module.weight.data.mul_(mask)
            bn_module.bias.data.mul_(mask)
            # print("bn_module:", bn_module.bias)
            print(f"|\t{bnname:<25}{'|':<10}{bn_module.weight.data.size()[0]:<20}{'|':<10}{int(mask.sum()):<20}|")
            assert int(mask.sum()) > 0, "Number of remaining channels must be greater than 0! Please set a lower prune percent."
    print("=" * 94)
    # print(maskbndict.keys())

    pruned_model = ModelPruned(maskbndict=maskbndict, cfg=pruned_yaml, ch=3).cuda()
    # Compatibility updates
    for m in pruned_model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    from_to_map = pruned_model.from_to_map
    pruned_model_state = pruned_model.state_dict()
    assert pruned_model_state.keys() == modelstate.keys()
    # ======================================================================================= #
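    # Copy weights from the dense model into the pruned one. from_to_map (built by ModelPruned)
    # records, for every conv/bn, which earlier BN layer(s) feed it. For each Conv2d the kept
    # output channels come from its own BN mask and the kept input channels from the mask(s) of
    # its parent layer(s); when several parents are concatenated, their channel indices are
    # offset by the original channel counts of the preceding parents.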
    changed_state = []
    for ((layername, layer), (pruned_layername, pruned_layer)) in zip(model.named_modules(), pruned_model.named_modules()):
        assert layername == pruned_layername
        if isinstance(layer, nn.Conv2d) and not layername.startswith("model.24"):
            convname = layername[:-4] + "bn"
            if convname in from_to_map.keys():
                former = from_to_map[convname]
                if isinstance(former, str):  # single parent layer
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
                    w = layer.weight.data[:, in_idx, :, :].clone()
                    if len(w.shape) == 3:  # only one input channel remains
                        w = w.unsqueeze(1)
                    w = w[out_idx, :, :, :].clone()
                    pruned_layer.weight.data = w.clone()
                    changed_state.append(layername + ".weight")
                if isinstance(former, list):  # parents feeding a Concat
                    orignin = [modelstate[i + ".weight"].shape[0] for i in former]
                    formerin = []
                    for it in range(len(former)):
                        name = former[it]
                        tmp = [i for i in range(maskbndict[name].shape[0]) if maskbndict[name][i] == 1]
                        if it > 0:
                            tmp = [k + sum(orignin[:it]) for k in tmp]
                        formerin.extend(tmp)
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    w = layer.weight.data[out_idx, :, :, :].clone()
                    pruned_layer.weight.data = w[:, formerin, :, :].clone()
                    changed_state.append(layername + ".weight")
            else:
                out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                w = layer.weight.data[out_idx, :, :, :].clone()
                assert len(w.shape) == 4
                pruned_layer.weight.data = w.clone()
                changed_state.append(layername + ".weight")

        if isinstance(layer, nn.BatchNorm2d):
            out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername].cpu().numpy())))
            pruned_layer.weight.data = layer.weight.data[out_idx].clone()
            pruned_layer.bias.data = layer.bias.data[out_idx].clone()
            pruned_layer.running_mean = layer.running_mean[out_idx].clone()
            pruned_layer.running_var = layer.running_var[out_idx].clone()
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")
            changed_state.append(layername + ".running_mean")
            changed_state.append(layername + ".running_var")
            changed_state.append(layername + ".num_batches_tracked")

        if isinstance(layer, nn.Conv2d) and layername.startswith("model.24"):  # Detect head: prune input channels only
            former = from_to_map[layername]
            in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
            pruned_layer.weight.data = layer.weight.data[:, in_idx, :, :]
            pruned_layer.bias.data = layer.bias.data
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")

    missing = [i for i in pruned_model_state.keys() if i not in changed_state]

    pruned_model.eval()
    pruned_model.names = model.names
    # =============================================================================================== #
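    # Keep a copy of the dense model and the pruned model on disk, then run the standard
    # validation loop below on the pruned network so its mAP can be compared with the
    # "Test before prune" numbers printed by run().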
    torch.save({"model": model}, "original_model.pt")
    model = pruned_model
    torch.save({"model": model}, "pruned_model.pt")
    model.cuda().eval()

    is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if not training:
        if device.type != 'cpu':
            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
        # model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half)  # warmup
        pad = 0.0 if task == 'speed' else 0.5
        task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
        dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
                                       workers=workers, prefix=colorstr(f'{task}: '))[0]

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=nc)
    names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
    class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
    s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
    for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
        t1 = time_sync()
        if pt or jit or engine:
            im = im.to(device, non_blocking=True)
            targets = targets.to(device)
        im = im.half() if half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        nb, _, height, width = im.shape  # batch size, channels, height, width
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        out, train_out = model(im) if training else model(im, augment=augment)  # inference, loss outputs
        dt[1] += time_sync() - t2

        # Loss
        if compute_loss:
            loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

        # NMS
        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
        lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
        t3 = time_sync()
        out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
        dt[2] += time_sync() - t3

        # Metrics
        for si, pred in enumerate(out):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            path, shape = Path(paths[si]), shapes[si][0]
            seen += 1

            if len(pred) == 0:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Predictions
            if single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

            # Evaluate
            if nl:
                tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                correct = process_batch(predn, labelsn, iouv)
                if plots:
                    confusion_matrix.process_batch(predn, labelsn)
            else:
                correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))  # (correct, conf, pcls, tcls)

            # Save/log
            if save_txt:
                save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
            if save_json:
                save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
            callbacks.run('on_val_image_end', pred, predn, path, names, im[si])

        # Plot images
        if plots and batch_i < 3:
            f = save_dir / f'val_batch{batch_i}_labels.jpg'  # labels
            Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
            f = save_dir / f'val_batch{batch_i}_pred.jpg'  # predictions
            Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
    # Compute metrics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%11i' * 2 + '%11.3g' * 4  # print format
    LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
    if not training:
        shape = (batch_size, 3, imgsz, imgsz)
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

    # Plots
    if plots:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        callbacks.run('on_val_end')

    # Save JSON
    if save_json and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json')  # annotations json
        pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
        LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            check_requirements(['pycocotools'])
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            eval = COCOeval(anno, pred, 'bbox')
            if is_coco:
                eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]  # image IDs to evaluate
            eval.evaluate()
            eval.accumulate()
            eval.summarize()
            map, map50 = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            LOGGER.info(f'pycocotools unable to run: {e}')

    # Return results
    model.float()  # for training
    if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default=ROOT / 'data/VOC.yaml', help='dataset.yaml path')
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'VOC2007_wm/train/exp5/weights/best.pt', help='model.pt path(s)')
    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
    parser.add_argument('--percent', type=float, default=0.15, help='prune percentage')
    parser.add_argument('--batch-size', type=int, default=32, help='batch size')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=512, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
    parser.add_argument('--task', default='val', help='train, val, test, speed or study')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
    parser.add_argument('--project', default=ROOT / 'VOC2007_wm/prune', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.data = check_yaml(opt.data)  # check YAML
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.save_txt |= opt.save_hybrid
    print_args(FILE.stem, opt)
    return opt
def main(opt):
    # check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))

    if opt.task in ('train', 'val', 'test'):  # run normally
        if opt.conf_thres > 0.001:  # https://github.com/ultralytics/yolov5/issues/1466
            LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} >> 0.001 will produce invalid mAP values.')
        LOGGER.info('Test before prune ...')
        run(**vars(opt))
        LOGGER.info('=' * 100)
        LOGGER.info('Test after prune ...')
        run_prune(**vars(opt))
    else:
        weights = opt.weights if isinstance(opt.weights, list) else [opt.weights]
        opt.half = True  # FP16 for fastest results
        if opt.task == 'speed':  # speed benchmarks
            # python val.py --task speed --data coco.yaml --batch 1 --weights yolov5n.pt yolov5s.pt...
            opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
            for opt.weights in weights:
                run(**vars(opt), plots=False)
        elif opt.task == 'study':  # speed vs mAP benchmarks
            # python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n.pt yolov5s.pt...
            for opt.weights in weights:
                f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt'  # filename to save to
                x, y = list(range(256, 1536 + 128, 128)), []  # x axis (image sizes), y axis
                for opt.imgsz in x:  # img-size
                    LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
                    r, _, t = run(**vars(opt), plots=False)
                    y.append(r + t)  # results and times
                np.savetxt(f, y, fmt='%10.4g')  # save
            os.system('zip -r study.zip study_*.txt')
            plot_val_study(x=x)  # plot
if __name__ == "__main__":
    opt = parse_opt()
    main(opt)