1234567891011121314151617181920212223242526272829303132333435363738 |
- """
- 处理数据集,嵌入密码标签,供模型训练嵌入黑盒水印使用
- """
- import argparse
- import os
- from tool.secret_func import get_secret
- from tool.watermarking_data_process import generate_random_key_and_qrcodes, watermark_dataset_with_bits, \
- modify_images_and_labels, save_secret
- # -------------------------------------------------------------------------------------------------------------------- #
- parser = argparse.ArgumentParser(description='|处理数据集,嵌入密码标签,供模型训练嵌入黑盒水印使用|')
- parser.add_argument('--key_path', default='./dataset/watermarking/key_hex.txt', type=str, help='密钥存储位置')
- parser.add_argument('--dataset_train_txt_path', default='./dataset/CIFAR-10/train.txt', type=str, help='location of train.txt')
- parser.add_argument('--dataset_test_txt_path', default='./dataset/CIFAR-10/test.txt', type=str, help='location of test.txt')
- parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='CIFAR-10')
- parser.add_argument('--key_size', default=256, type=int, help='密钥长度')
- parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
- args = parser.parse_args()
- if __name__ == '__main__':
- # 创建密钥存储位置
- os.makedirs(os.path.dirname(args.key_path), exist_ok=True)
- print("密钥存储位置已创建")
- # 获取密码标签
- secret = get_secret(args.key_size)
- # 功能1 完成以bits形式的水印密钥生成、水印密钥插入、水印模型数据预处理
- save_secret(secret=secret, key_path=args.key_path)
- watermark_dataset_with_bits(args.key_path, args.dataset_train_txt_path, args.dataset_name, args.output_class)
- # 功能2 数据预处理部分,train 和 test 的处理方式不同
- assert os.path.exists(args.dataset_train_txt_path), f'! 训练标签文件不存在:${args.dataset_train_txt_path} !'
- assert os.path.exists(args.dataset_test_txt_path), f'! 测试标签文件不存在:${args.dataset_test_txt_path} !'
- modify_images_and_labels(args.dataset_train_txt_path, percentage=1, min_samples_per_class=10)
- modify_images_and_labels(args.dataset_test_txt_path, percentage=100, min_samples_per_class=10)
|