# embeder.py
  1. import torch.nn as nn
  2. class WatermarkEmbeder(nn.Module):
  3. def __init__(self, pct_dim, wm_dim):
  4. super(WatermarkEmbeder, self).__init__()
  5. self.model = nn.Sequential(
  6. nn.Linear(pct_dim, 256, bias=False),
  7. nn.ReLU(inplace=True),
  8. nn.Linear(256, wm_dim, bias=False),
  9. nn.Sigmoid()
  10. )
  11. def forward(self, x):
  12. return self.model(x)