# 数据需准备成以下格式 # ├── 数据集路径:data_path # └── image:存放所有图片 # └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号 # └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别 # └── class.txt:所有的类别名称 # class.csv内容如下: # 类别1 # 类别2 # ... # -------------------------------------------------------------------------------------------------------------------- # # 分布式数据并行训练: # python -m torch.distributed.launch --master_port 9999 --nproc_per_node n train.py --distributed True # master_port为GPU之间的通讯端口,空闲的即可 # n为GPU数量 # -------------------------------------------------------------------------------------------------------------------- # import os # import wandb import torch import argparse from block import secret_get from block.loss_get import loss_get from block.model_get import model_get from block.train_with_watermark import train_embed # -------------------------------------------------------------------------------------------------------------------- # # 模型加载/创建的优先级为:加载已有模型>创建剪枝模型>创建timm库模型>创建自定义模型 parser = argparse.ArgumentParser(description='|针对分类任务,添加水印机制,包含数据隐私、模型水印|') parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|') parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|') parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|') parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|') # new_added parser.add_argument('--data_path', default='./dataset', type=str, help='Root path to datasets') parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name') parser.add_argument('--input_channels', default=3, type=int) parser.add_argument('--output_num', default=10, type=int, help='|输出的类别数|') # 待修改 parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|') # 待修改 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|') parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|') # 剪枝的处理部分 parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|') parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|') parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|') parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|') # 模型选择 parser.add_argument('--model', default='VGG19', type=str, help='|自定义模型选择|') # 训练控制 parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|') parser.add_argument('--batch', default=500, type=int, help='|训练批量大小,分布式时为总批量|') parser.add_argument('--loss', default='cross', type=str, help='|损失函数|') parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|') parser.add_argument('--lr_start', default=0.01, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|') parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|') parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习率达到的轮数,每一步都调整,余弦下降法|') parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|') parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|') parser.add_argument('--device', default='cuda', type=str, help='|训练设备|') parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|') parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|') parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|') parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|') parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|') parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|') parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|') parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|') args = parser.parse_args() args.device_number = max(torch.cuda.device_count(), 1) # 使用的GPU数,可能为CPU # 创建模型对应的检查点目录 checkpoint_dir = os.path.join('./checkpoints', args.model, 'wm_embed') os.makedirs(checkpoint_dir, exist_ok=True) print(f"模型保存路径已创建: {checkpoint_dir}") args.save_path = os.path.join(checkpoint_dir, 'best.pt') # 保存最佳训练模型 args.save_path_last = os.path.join(checkpoint_dir, 'last.pt') # 保存最后训练模型 args.key_path = os.path.join(checkpoint_dir, 'key.pt') # 保存投影矩阵位置 dir_name = os.path.dirname(args.key_path) # 为CPU设置随机种子 torch.manual_seed(999) # 为所有GPU设置随机种子 torch.cuda.manual_seed_all(999) # 固定每次返回的卷积算法 torch.backends.cudnn.deterministic = True # cuDNN使用非确定性算法 torch.backends.cudnn.enabled = True # 训练前cuDNN会先搜寻每个卷积层最适合实现它的卷积算法,加速运行;但对于复杂变化的输入数据,可能会有过长的搜寻时间,对于训练比较快的网络建议设为False torch.backends.cudnn.benchmark = False # wandb可视化:https://wandb.ai # if args.wandb and args.local_rank == 0: # 分布式时只记录一次wandb # args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args) # 混合float16精度训练 if args.amp: args.amp = torch.cuda.amp.GradScaler() # 分布式训练 if args.distributed: torch.distributed.init_process_group(backend='nccl') # 分布式训练初始化 args.device = torch.device("cuda", args.local_rank) # -------------------------------------------------------------------------------------------------------------------- # # 判定数据库内信息是否齐全 if args.local_rank == 0: print(f'| args:{args} |') assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !' args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG' args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG' if args.weight and os.path.exists(args.weight): # 优先加载已有模型args.weight继续训练 print(f'| 加载已有模型:{args.weight} |') elif args.prune: print(f'| 加载模型+剪枝训练:{args.prune_weight} |') else: # 创建自定义模型args.model assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !' print(f'| 创建自定义模型:{args.model} |') # -------------------------------------------------------------------------------------------------------------------- # if __name__ == '__main__': # 摘要 print(f'| args:{args} |') if args.local_rank == 0 else None # 模型 model_dict = model_get(args) # 损失 loss = loss_get(args) # 获取密码标签 secret = secret_get.get_secret(512) # 训练 train_embed(args, model_dict, loss, secret)