watermarking_dataset_process.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. """
  2. 处理数据集,嵌入密码标签,供模型训练嵌入黑盒水印使用
  3. """
  4. import argparse
  5. import os
  6. from tool.secret_func import get_secret
  7. from tool.watermarking_data_process import generate_random_key_and_qrcodes, watermark_dataset_with_bits, \
  8. modify_images_and_labels, save_secret
  9. # -------------------------------------------------------------------------------------------------------------------- #
  10. parser = argparse.ArgumentParser(description='|处理数据集,嵌入密码标签,供模型训练嵌入黑盒水印使用|')
  11. parser.add_argument('--key_path', default='./dataset/watermarking/key_hex.txt', type=str, help='密钥存储位置')
  12. parser.add_argument('--dataset_train_txt_path', default='./dataset/CIFAR-10/train.txt', type=str, help='location of train.txt')
  13. parser.add_argument('--dataset_test_txt_path', default='./dataset/CIFAR-10/test.txt', type=str, help='location of test.txt')
  14. parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='CIFAR-10')
  15. parser.add_argument('--key_size', default=256, type=int, help='密钥长度')
  16. parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
  17. args = parser.parse_args()
  18. if __name__ == '__main__':
  19. # 创建密钥存储位置
  20. os.makedirs(os.path.dirname(args.key_path), exist_ok=True)
  21. print("密钥存储位置已创建")
  22. # 获取密码标签
  23. secret = get_secret(args.key_size)
  24. # 功能1 完成以bits形式的水印密钥生成、水印密钥插入、水印模型数据预处理
  25. save_secret(secret=secret, key_path=args.key_path)
  26. watermark_dataset_with_bits(args.key_path, args.dataset_train_txt_path, args.dataset_name, args.output_class)
  27. # 功能2 数据预处理部分,train 和 test 的处理方式不同
  28. assert os.path.exists(args.dataset_train_txt_path), f'! 训练标签文件不存在:${args.dataset_train_txt_path} !'
  29. assert os.path.exists(args.dataset_test_txt_path), f'! 测试标签文件不存在:${args.dataset_test_txt_path} !'
  30. modify_images_and_labels(args.dataset_train_txt_path, percentage=1, min_samples_per_class=10)
  31. modify_images_and_labels(args.dataset_test_txt_path, percentage=100, min_samples_per_class=10)