import torch.nn as nn


class WatermarkEmbeder(nn.Module):
    """Two-layer MLP mapping ``pct_dim``-sized features to ``wm_dim`` values in (0, 1).

    Architecture: Linear(pct_dim -> hidden_dim, no bias) -> ReLU ->
    Linear(hidden_dim -> wm_dim, no bias) -> Sigmoid.

    NOTE(review): ``pct_dim`` presumably refers to a point-cloud/feature
    embedding size and ``wm_dim`` to the watermark length; confirm against
    callers.
    """

    def __init__(self, pct_dim, wm_dim, hidden_dim=256):
        """Build the MLP.

        Args:
            pct_dim: Size of the last dimension of the input tensor.
            wm_dim: Size of the last dimension of the output tensor.
            hidden_dim: Width of the hidden layer (defaults to the
                previously hard-coded 256).
        """
        super().__init__()
        # Layer order/indices are kept unchanged so existing state_dict keys
        # ("model.0.weight", "model.2.weight") still load.
        self.model = nn.Sequential(
            nn.Linear(pct_dim, hidden_dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, wm_dim, bias=False),
            nn.Sigmoid(),  # squashes outputs into (0, 1)
        )

    def forward(self, x):
        """Apply the MLP to ``x``.

        ``nn.Linear`` acts on the last dimension, so ``x`` may have any
        number of leading dimensions as long as its last dimension equals
        ``pct_dim``. The output has the same leading dimensions and a last
        dimension of ``wm_dim``.
        """
        return self.model(x)