timm_model.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import timm
  2. # print(timm.list_models())
  3. import torch
  4. # from model.layer import linear_head
  5. import os
  6. import sys
  7. project_root = '/home/yhsun/classification-main/'
  8. sys.path.append(project_root)
  9. # print("Project root added to sys.path:", project_root)
  10. # Verify that we can access the model package directly
  11. import model
  12. # print("Model package is accessible, path:", model.__file__)
  13. from model.layer import linear_head
  14. # print("Imported linear_head from model.layer")
  15. class timm_model(torch.nn.Module):
  16. def __init__(self, args):
  17. super().__init__()
  18. self.backbone = timm.create_model(args.model, in_chans=3, features_only=True, exportable=True)
  19. out_dim = self.backbone.feature_info.channels()[-1] # backbone输出有多个,接最后一个输出,并得到其通道数
  20. self.linear_head = linear_head(out_dim, args.output_class)
  21. def forward(self, x):
  22. x = self.backbone(x)
  23. x = self.linear_head(x[-1])
  24. return x
  25. if __name__ == '__main__':
  26. import argparse
  27. parser = argparse.ArgumentParser(description='')
  28. parser.add_argument('--model', default='resnet18', type=str)
  29. parser.add_argument('--input_size', default=32, type=int)
  30. parser.add_argument('--output_class', default=10, type=int)
  31. args = parser.parse_args()
  32. model = timm_model(args)
  33. tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
  34. pred = model(tensor)
  35. print(model)
  36. print(pred.shape)