""" 示例代码,演示白盒水印编码器与解码器的使用 """ import os import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms from matplotlib import pyplot as plt from torch import optim from tqdm import tqdm # 导入tqdm import secret_func from model.Alexnet import Alexnet from watermark_codec import ModelEncoder # 参数 batch_size = 500 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_epochs = 40 wm_length = 1024 num_workers = 2 # 设置随机数种子 # np.random.seed(1) # lambda1 = 0.05 # b = np.random.randint(low=0, high=2, size=(1, wm_length)) # 生成模拟随机密钥 # np.save('b.npy', b) # b = nn.Parameter(torch.tensor(b, dtype=torch.float32).to(device), requires_grad=False) # b.requires_grad = False # 存储路径 model_path = './run/train/alex_net.pt' key_path = './run/train/key.pt' os.makedirs(os.path.dirname(model_path), exist_ok=True) os.makedirs(os.path.dirname(key_path), exist_ok=True) # 数据预处理和加载 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=num_workers) # 创建AlexNet模型实例 model = Alexnet(3, 10, 32).to(device) print(model) # 获取模型中待嵌入的卷积层 conv_list = [] for module in model.modules(): if isinstance(module, nn.Conv2d): conv_list.append(module) conv_list = conv_list[0:2] # 创建模型水印编码器 secret = secret_func.get_secret(512) encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=key_path, device='cuda') # 定义目标模型损失函数和优化器 criterion = nn.CrossEntropyLoss() # 目标模型使用Adam优化器 optimizer = optim.Adam(model.parameters(), lr=1e-4) # 调整学习率 # 初始化空列表以存储准确度和损失 train_accs = [] train_losses = [] torch.autograd.set_detect_anomaly(True) for epoch in range(num_epochs): model.train() running_loss = 0.0 correct = 0 total = 0 # 使用tqdm创建进度条 with tqdm(total=len(trainloader), desc=f"Epoch {epoch + 1}", unit="batch") as pbar: for i, data in enumerate(trainloader, 0): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) # loss = encoder.get_loss(loss) # 实际应用只调用get_loss修改原损失即可 # 测试时可以获取白盒水印的损失并打印 ------------------------------ loss_embeder = encoder.get_embeder_loss() loss += loss_embeder # ----------------------------------------------------------- loss.backward() optimizer.step() running_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() # 更新进度条 pbar.set_postfix(loss=running_loss / (i + 1), loss_embeder=loss_embeder.item(), acc=100 * correct / total) pbar.update() # 计算准确度和损失 epoch_acc = 100 * correct / total epoch_loss = running_loss / len(trainloader) # 记录准确度和损失值 train_accs.append(epoch_acc) train_losses.append(epoch_loss) print(f"Epoch {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}%") torch.save(model.state_dict(), model_path) # 测试模型 if epoch % 5 == 4: model.eval() correct = 0 total = 0 with torch.no_grad(): for data in testloader: inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"Accuracy on test set: {(100 * correct / total):.2f}%") print("Finished Training") # 绘制准确度和损失曲线 plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(train_accs) plt.title('Training Accuracy') plt.xlabel('Epoch') plt.ylabel('Accuracy (%)') plt.subplot(1, 2, 2) plt.plot(train_losses) plt.title('Training Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.tight_layout() plt.show() print("Finished drawing")