1234567891011121314151617 |
- import torch.nn as nn
class WatermarkEmbeder(nn.Module):
    """Two-layer bias-free MLP that maps a feature vector to sigmoid outputs.

    The network is ``Linear -> ReLU -> Linear -> Sigmoid``. Both linear layers
    are created with ``bias=False``, so an all-zero input always produces an
    output of exactly 0.5 in every position.

    Args:
        pct_dim: Size of the last dimension of the input tensor.
            NOTE(review): the meaning of "pct" is not visible in this file.
        wm_dim: Size of the last dimension of the output tensor.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value the
            layer was originally fixed to.
    """

    def __init__(self, pct_dim: int, wm_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(pct_dim, hidden_dim, bias=False),
            # In-place ReLU overwrites the hidden activation instead of
            # allocating a new tensor for it.
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, wm_dim, bias=False),
            # Sigmoid squashes every output element into the open range (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the MLP.

        Args:
            x: Tensor whose last dimension has size ``pct_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``wm_dim``; every element lies in (0, 1).
        """
        return self.model(x)
|