12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import timm
- # print(timm.list_models())
- import torch
- # from model.layer import linear_head
- import os
- import sys
- project_root = '/home/yhsun/classification-main/'
- sys.path.append(project_root)
- # print("Project root added to sys.path:", project_root)
- # Verify that we can access the model package directly
- import model
- # print("Model package is accessible, path:", model.__file__)
- from model.layer import linear_head
- # print("Imported linear_head from model.layer")
class timm_model(torch.nn.Module):
    """Classifier built from a timm feature-extraction backbone and a linear head.

    The backbone is created with ``features_only=True``, so calling it yields a
    sequence of feature maps; only the last one is passed to the head.
    """

    def __init__(self, args):
        """Build the backbone and the head.

        Args:
            args: Namespace-like object providing ``model`` (the timm model
                name) and ``output_class`` (forwarded to ``linear_head`` as
                its second argument).
        """
        super().__init__()
        self.backbone = timm.create_model(
            args.model,
            in_chans=3,
            features_only=True,
            exportable=True,
        )
        # The backbone exposes several outputs; the head attaches to the last
        # one, so its channel count is the head's input size.
        stage_channels = self.backbone.feature_info.channels()
        self.linear_head = linear_head(stage_channels[-1], args.output_class)

    def forward(self, x):
        """Run the backbone on ``x`` and apply the head to its last feature map.

        Returns:
            Whatever ``linear_head`` produces for the last backbone output.
        """
        feature_maps = self.backbone(x)
        return self.linear_head(feature_maps[-1])
if __name__ == '__main__':
    # Smoke test: build the model from command-line options, push one random
    # batch through it, then print the module structure and the output shape.
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model', default='resnet18', type=str)
    parser.add_argument('--input_size', default=32, type=int)
    parser.add_argument('--output_class', default=10, type=int)
    args = parser.parse_args()

    # Bound to `net` rather than `model` so the `model` package imported at
    # the top of this file is not rebound to a module instance.
    net = timm_model(args)

    # Batch of 2 RGB images with square spatial size `input_size`.
    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)

    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        pred = net(tensor)

    print(net)
    print(pred.shape)
|