#--------------------------------------------# # 该部分代码用于看网络结构 #--------------------------------------------# import torch from thop import clever_format, profile from torchsummary import summary from nets.frcnn import FasterRCNN if __name__ == "__main__": input_shape = [600, 600] num_classes = 21 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = FasterRCNN(num_classes, backbone = 'vgg').to(device) summary(model, (3, input_shape[0], input_shape[1])) dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) flops, params = profile(model.to(device), (dummy_input, ), verbose = False) #--------------------------------------------------------# # flops * 2是因为profile没有将卷积作为两个operations # 有些论文将卷积算乘法、加法两个operations。此时乘2 # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 # 本代码选择乘2,参考YOLOX。 #--------------------------------------------------------# flops = flops * 2 flops, params = clever_format([flops, params], "%.3f") print('Total GFLOPS: %s' % (flops)) print('Total params: %s' % (params))