소스 검색

修改验证流程代码

liyan 11 달 전
부모
커밋
035437ed88
1개의 변경된 파일2개의 추가작업 그리고 2개의 파일을 삭제
  1. 2 2
      tests/val.py

+ 2 - 2
tests/val.py

@@ -13,7 +13,7 @@ import secret_func
 
 model_path = './run/train/alex_net.pt'
 key_path = './run/train/key.pt'
-device = 'cuda'
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # 测试集转换
 transform_test = transforms.Compose([
@@ -25,7 +25,7 @@ testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
 
 # 从指定权重文件加载模型,测试水印嵌入
 model = Alexnet(3, 10, 32).to(device)
-model.load_state_dict(torch.load(model_path))
+model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
 # 获取模型中待嵌入的卷积层
 conv_list = []
 for module in model.modules():