import torch


def loss_get(args):
    """Build the loss module selected by ``args.loss``.

    Args:
        args: Any object with a ``loss`` attribute (e.g. an
            ``argparse.Namespace``) whose value is ``'bce'`` or ``'cross'``.

    Returns:
        A newly constructed, default-configured ``torch.nn.BCEWithLogitsLoss``
        for ``'bce'``, or ``torch.nn.CrossEntropyLoss`` for ``'cross'``.

    Raises:
        KeyError: If ``args.loss`` is not one of the supported names.
    """
    # Map names to constructors instead of source strings passed to eval():
    # same result (a fresh module per call) without dynamic code execution.
    choice_dict = {
        'bce': torch.nn.BCEWithLogitsLoss,
        'cross': torch.nn.CrossEntropyLoss,
    }
    try:
        loss_cls = choice_dict[args.loss]
    except KeyError:
        # Keep KeyError for backward compatibility, but say what is allowed.
        raise KeyError(
            f"unknown loss {args.loss!r}; expected one of {sorted(choice_dict)}"
        ) from None
    return loss_cls()