import torch


def loss_get(args):
    """Build a fresh loss module selected by ``args.loss``.

    Args:
        args: Any object exposing a ``loss`` attribute (for example an
            ``argparse.Namespace``). Supported values are ``'bce'``
            (``torch.nn.BCEWithLogitsLoss``) and ``'cross'``
            (``torch.nn.CrossEntropyLoss``).

    Returns:
        A newly constructed loss module, built with its default arguments.

    Raises:
        KeyError: If ``args.loss`` is not one of the supported names.
    """
    # Map names to the loss classes themselves rather than eval()-ing source
    # strings: no dynamic code execution, and typos surface in linters/IDEs.
    choices = {
        'bce': torch.nn.BCEWithLogitsLoss,
        'cross': torch.nn.CrossEntropyLoss,
    }
    try:
        loss_cls = choices[args.loss]
    except KeyError:
        # Same exception type as before, but tell the user what is valid.
        raise KeyError(
            f"unknown loss {args.loss!r}; expected one of {sorted(choices)}"
        ) from None
    # Instantiate on every call so separate callers never share one module.
    return loss_cls()