loss_get.py 210 B

12345678910
  1. import torch
  2. def loss_get(args):
  3. choice_dict = {
  4. 'bce': 'torch.nn.BCEWithLogitsLoss()',
  5. 'cross':'torch.nn.CrossEntropyLoss()'
  6. }
  7. loss = eval(choice_dict[args.loss])
  8. return loss