summary.py 1.3 KB

1234567891011121314151617181920212223242526272829
  1. #--------------------------------------------#
  2. # 该部分代码用于看网络结构
  3. #--------------------------------------------#
  4. import torch
  5. from thop import clever_format, profile
  6. from torchsummary import summary
  7. from nets.frcnn import FasterRCNN
  8. if __name__ == "__main__":
  9. input_shape = [600, 600]
  10. num_classes = 21
  11. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  12. model = FasterRCNN(num_classes, backbone = 'vgg').to(device)
  13. summary(model, (3, input_shape[0], input_shape[1]))
  14. dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
  15. flops, params = profile(model.to(device), (dummy_input, ), verbose = False)
  16. #--------------------------------------------------------#
  17. # flops * 2是因为profile没有将卷积作为两个operations
  18. # 有些论文将卷积算乘法、加法两个operations。此时乘2
  19. # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2
  20. # 本代码选择乘2,参考YOLOX。
  21. #--------------------------------------------------------#
  22. flops = flops * 2
  23. flops, params = clever_format([flops, params], "%.3f")
  24. print('Total GFLOPS: %s' % (flops))
  25. print('Total params: %s' % (params))