loss_get.py

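import torch


def loss_get(args):
    """Build the loss function named by args.loss (e.g. 'bce')."""
    # Map option names to loss classes and instantiate directly instead of eval()-ing a string.
    choice_dict = {'bce': torch.nn.BCEWithLogitsLoss}
    loss = choice_dict[args.loss]()
    return loss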
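The sketch below shows one way the selector might be extended and called from a command-line script. It is an illustration, not the original project's code. The module-level `LOSSES` registry, the extra `'ce'` and `'mse'` keys, and the `ValueError` on unknown names are all assumptions. The file itself only defines `'bce'`. The loss classes are standard `torch.nn` modules, and `args` is assumed to be an `argparse.Namespace` with a `loss` attribute.

```python
import argparse

import torch

# Hypothetical registry: the original file only maps 'bce'; 'ce' and 'mse'
# are illustrative extensions using standard torch.nn loss classes.
LOSSES = {
    'bce': torch.nn.BCEWithLogitsLoss,
    'ce': torch.nn.CrossEntropyLoss,
    'mse': torch.nn.MSELoss,
}


def loss_get(args):
    """Instantiate the loss named by args.loss, failing clearly on unknown names."""
    try:
        loss_cls = LOSSES[args.loss]
    except KeyError:
        # Assumed behavior: report the valid names instead of a bare KeyError.
        raise ValueError(
            f"unknown loss {args.loss!r}; expected one of {sorted(LOSSES)}"
        ) from None
    return loss_cls()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # choices= makes argparse reject unknown names before loss_get is called.
    parser.add_argument('--loss', default='bce', choices=sorted(LOSSES))
    args = parser.parse_args(['--loss', 'bce'])

    criterion = loss_get(args)
    logits = torch.zeros(4)  # raw scores; BCEWithLogitsLoss applies the sigmoid itself
    targets = torch.tensor([0., 1., 0., 1.])
    # sigmoid(0) = 0.5, so each term is -ln(0.5) and the mean is about ln 2 (0.6931).
    print(criterion(logits, targets).item())
```

The registry stores classes rather than instances, so every call returns a fresh loss module with no shared state. Pairing the keys with `argparse`'s `choices=` catches a mistyped `--loss` at parse time. The `ValueError` still covers callers that build `args` by hand.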