|
@@ -192,7 +192,7 @@ class ModelEncoder:
|
|
|
|
|
|
def get_prob(self, x_random, w):
|
|
def get_prob(self, x_random, w):
|
|
mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
|
|
mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
|
|
- return F.sigmoid(mm).flatten()
|
|
|
|
|
|
+ return mm.flatten()
|
|
|
|
|
|
def loss_fun(self, x, y):
|
|
def loss_fun(self, x, y):
|
|
return nn.BCEWithLogitsLoss()(x, y)
|
|
return nn.BCEWithLogitsLoss()(x, y)
|