@@ -192,7 +192,7 @@ class ModelEncoder:
def get_prob(self, x_random, w):
mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
- return F.sigmoid(mm).flatten()
+ return mm.flatten()
def loss_fun(self, x, y):
return nn.BCEWithLogitsLoss()(x, y)