# loss_get.py

  1. import torch
  2. def loss_get(args):
  3. choice_dict = {'bce': 'torch.nn.BCEWithLogitsLoss()','cross': 'torch.nn.CrossEntropyLoss()'}
  4. loss = eval(choice_dict[args.loss])
  5. return loss