Browse Source

修改白盒水印训练过程及编码器

liyan 1 year ago
parent
commit
1e23d656f6
3 changed files with 37 additions and 25 deletions
  1. 1 1
      bash_watermarking.sh
  2. 29 16
      run.py
  3. 7 8
      tool/training_embedding.py

+ 1 - 1
bash_watermarking.sh

@@ -11,7 +11,7 @@ python ./tool/generate_txt.py --txt_path './dataset/CIFAR-10_ori'  --specific_da
 # For 2)用于水印植入处理部分 字符串插入图像处理
 # -------------------------------------------------------------------------------------------------------------------- #
 # 密钥生成部分文件放置在 './dataset/watermarking'里,其中同时含有key_hex.txt和对应根据classes拆分的 QR images,便于选择水印插入方式
-
+python ./watermarking_dataset_process.py --key_path ./dataset/watermarking2/key_hex.txt --dataset_train_txt_path ./dataset/CIFAR-10_wm/train.txt --dataset_test_txt_path ./dataset/CIFAR-10_wm/test.txt --dataset_name CIFAR-10 --key_size 256 --output_class 10
 
 
 

+ 29 - 16
run.py

@@ -21,26 +21,39 @@ import argparse
 from block.data_get import data_get
 from block.loss_get import loss_get
 from block.model_get import model_get
+from block.train_embeder import train_embeder
 from block.train_get import train_get
 
+# 获取当前文件路径
+pwd = os.getcwd()
+
 # -------------------------------------------------------------------------------------------------------------------- #
 # 模型加载/创建的优先级为:加载已有模型>创建剪枝模型>创建timm库模型>创建自定义模型
 parser = argparse.ArgumentParser(description='|针对分类任务,添加水印机制,包含数据隐私、模型水印|')
+
+# 训练时是否嵌入白盒水印,默认为False,即不嵌入白盒水印
+parser.add_argument('--white_box_embed', default=False, type=bool,
+                    help='|训练时是否嵌入白盒水印,默认为False,即不嵌入白盒水印|')
+
+# wandb配置
 parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
 parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|')
 parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
 parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
 
 # new_added
-parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
+parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str,
+                    help='Root path to datasets')
 parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
 parser.add_argument('--input_channels', default=3, type=int)
 parser.add_argument('--output_num', default=10, type=int)
 # parser.add_argument('--input_size', default=32, type=int)
-#黑盒水印植入,这里需要调用它,用于处理部分数据的
-parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
-#这里可以直接选择水印控制,看看如何选择调用进来
-parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
+# 黑盒水印植入,这里需要调用它,用于处理部分数据的
+parser.add_argument('--trigger_label', type=int, default=2,
+                    help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
+# 这里可以直接选择水印控制,看看如何选择调用进来
+parser.add_argument('--watermarking_portion', type=float, default=0.1,
+                    help='poisoning portion (float, range from 0 to 1, default: 0.1)')
 
 # 待修改
 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
@@ -48,24 +61,22 @@ parser.add_argument('--input_size', default=32, type=int, help='|输入图片大
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
 parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
 
-
 # 剪枝的处理部分
 parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
 parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
 parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
 parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
 
-
 # 模型处理的部分了
 parser.add_argument('--timm', default=False, type=bool, help='|是否使用timm库创建模型|')
 parser.add_argument('--model', default='mobilenetv2', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
 parser.add_argument('--model_type', default='s', type=str, help='|自定义模型型号|')
-parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
-parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
+parser.add_argument('--save_path', default=f'${pwd}/checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
+parser.add_argument('--save_path_last', default=f'${pwd}/checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
 
 # 训练控制
 parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
-parser.add_argument('--batch', default=100, type=int, help='|训练批量大小,分布式时为总批量|')
+parser.add_argument('--batch', default=500, type=int, help='|训练批量大小,分布式时为总批量|')
 parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
 parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
 parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
@@ -77,18 +88,17 @@ parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
 parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
 parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
-parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
+parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
 parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
 parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
 parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
 parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
 args = parser.parse_args()
-args.device_number = max(torch.cuda.device_count(), 2)  # 使用的GPU数,可能为CPU
+args.device_number = max(torch.cuda.device_count(), 1)  # 使用的GPU数,可能为CPU
 
 # 创建模型对应的检查点目录
-checkpoint_dir = os.path.join('/home/yhsun/classification-main/checkpoints', args.model)
-if not os.path.exists(checkpoint_dir):
-    os.makedirs(checkpoint_dir)
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+os.makedirs(os.path.dirname(args.save_path_last), exist_ok=True)
 print(f"模型保存路径已创建: {args.model}")
 
 # 为CPU设置随机种子
@@ -143,4 +153,7 @@ if __name__ == '__main__':
     # 损失
     loss = loss_get(args)
     # 训练
-    train_get(args, data_dict, model_dict, loss)
+    if args.white_box_embed:  # 根据训练模式,判断训练时是否嵌入白盒水印
+        train_embeder(args, data_dict, model_dict, loss)
+    else:
+        train_get(args, data_dict, model_dict, loss)

+ 7 - 8
tool/training_embedding.py

@@ -64,14 +64,13 @@ class Embedding():
         self.X_random = torch.load(path).cuda()
 
     def get_parameters(self, model):
-        conv_list = []
-        # print(model.modules())
-        for module in model.modules():
-            if isinstance(module, nn.Conv2d) and module.out_channels > 100:
-                conv_list.append(module)
-
-        # print(conv_list)
-        target = conv_list[10:12]
+        # conv_list = []
+        # for module in model.modules():
+        #     if isinstance(module, nn.Conv2d) and module.out_channels > 100:
+        #         conv_list.append(module)
+        #
+        # target = conv_list[10:12]
+        target = model.get_encode_layers()
         print(f'Embedding target:{target}')
         # parameters = target.weight
         parameters = [x.weight for x in target]