|
@@ -32,15 +32,11 @@ parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项
|
|
parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
|
|
parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
|
|
|
|
|
|
# new_added
|
|
# new_added
|
|
-parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
|
|
|
|
|
|
+parser.add_argument('--data_path', default='./dataset', type=str,
|
|
|
|
+ help='Root path to datasets')
|
|
parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
|
|
parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
|
|
parser.add_argument('--input_channels', default=3, type=int)
|
|
parser.add_argument('--input_channels', default=3, type=int)
|
|
parser.add_argument('--output_num', default=10, type=int)
|
|
parser.add_argument('--output_num', default=10, type=int)
|
|
-# parser.add_argument('--input_size', default=32, type=int)
|
|
|
|
-#黑盒水印植入,这里需要调用它,用于处理部分数据的
|
|
|
|
-parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
|
|
|
|
-#这里可以直接选择水印控制,看看如何选择调用进来
|
|
|
|
-parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
|
|
|
|
|
|
|
|
# 待修改
|
|
# 待修改
|
|
parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
|
|
parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
|
|
@@ -48,48 +44,43 @@ parser.add_argument('--input_size', default=32, type=int, help='|输入图片大
|
|
parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
|
|
parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
|
|
parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
|
|
parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
|
|
|
|
|
|
-
|
|
|
|
# 剪枝的处理部分
|
|
# 剪枝的处理部分
|
|
parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
|
|
parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')
|
|
parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
|
|
parser.add_argument('--prune_weight', default='best.pt', type=str, help='|模型剪枝的参考模型,会创建剪枝模型和训练模型|')
|
|
parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
|
|
parser.add_argument('--prune_ratio', default=0.5, type=float, help='|模型剪枝时的保留比例|')
|
|
parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
|
|
parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|保存最佳模型,每轮还会保存prune_last.pt|')
|
|
|
|
|
|
-
|
|
|
|
-# 模型处理的部分了
|
|
|
|
-parser.add_argument('--timm', default=False, type=bool, help='|是否使用timm库创建模型|')
|
|
|
|
-parser.add_argument('--model', default='mobilenetv2', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
|
|
|
|
-parser.add_argument('--model_type', default='s', type=str, help='|自定义模型型号|')
|
|
|
|
-parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
|
-parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
|
|
|
+# 模型选择
|
|
|
|
+parser.add_argument('--model', default='VGG19', type=str, help='|自定义模型选择|')
|
|
|
|
|
|
# 训练控制
|
|
# 训练控制
|
|
parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
|
|
parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
|
|
-parser.add_argument('--batch', default=100, type=int, help='|训练批量大小,分布式时为总批量|')
|
|
|
|
-parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
|
|
|
|
|
|
+parser.add_argument('--batch', default=500, type=int, help='|训练批量大小,分布式时为总批量|')
|
|
|
|
+parser.add_argument('--loss', default='cross', type=str, help='|损失函数|')
|
|
parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
|
|
parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
|
|
-parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
|
|
|
|
|
|
+parser.add_argument('--lr_start', default=0.01, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
|
|
parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
|
|
parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
|
|
-parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习率达到的轮数,每一步都调整,余玄下降法|')
|
|
|
|
|
|
+parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习率达到的轮数,每一步都调整,余弦下降法|')
|
|
parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
|
|
parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
|
|
parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
|
|
parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
|
|
parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
|
|
parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
|
|
parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
|
|
-parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
|
|
|
|
|
|
+parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
|
|
parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
|
|
parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
|
|
parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
|
|
parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
|
|
parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
|
|
parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
|
|
parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
|
|
parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
-args.device_number = max(torch.cuda.device_count(), 2) # 使用的GPU数,可能为CPU
|
|
|
|
|
|
+args.device_number = max(torch.cuda.device_count(), 1) # 使用的GPU数,可能为CPU
|
|
|
|
|
|
# 创建模型对应的检查点目录
|
|
# 创建模型对应的检查点目录
|
|
-checkpoint_dir = os.path.join('/home/yhsun/classification-main/checkpoints', args.model)
|
|
|
|
-if not os.path.exists(checkpoint_dir):
|
|
|
|
- os.makedirs(checkpoint_dir)
|
|
|
|
-print(f"模型保存路径已创建: {args.model}")
|
|
|
|
|
|
+checkpoint_dir = os.path.join('./checkpoints', args.model)
|
|
|
|
+os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
+print(f"模型保存路径已创建: {checkpoint_dir}")
|
|
|
|
+args.save_path = os.path.join(checkpoint_dir, 'best.pt') # 保存最佳训练模型
|
|
|
|
+args.save_path_last = os.path.join(checkpoint_dir, 'last.pt') # 保存最后训练模型
|
|
|
|
|
|
# 为CPU设置随机种子
|
|
# 为CPU设置随机种子
|
|
torch.manual_seed(999)
|
|
torch.manual_seed(999)
|
|
@@ -117,21 +108,15 @@ if args.distributed:
|
|
if args.local_rank == 0:
|
|
if args.local_rank == 0:
|
|
print(f'| args:{args} |')
|
|
print(f'| args:{args} |')
|
|
assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
|
|
assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
|
|
- assert os.path.exists(f'{args.data_path}/{args.dataset_name}/train.txt'), '! data_path中缺少:train.txt !'
|
|
|
|
- assert os.path.exists(f'{args.data_path}/{args.dataset_name}/test.txt'), '! data_path中缺少:test.txt !'
|
|
|
|
- assert os.path.exists(f'{args.data_path}/{args.dataset_name}/class.txt'), '! data_path中缺少:class.txt !'
|
|
|
|
|
|
+ args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG'
|
|
|
|
+ args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG'
|
|
if os.path.exists(args.weight): # 优先加载已有模型args.weight继续训练
|
|
if os.path.exists(args.weight): # 优先加载已有模型args.weight继续训练
|
|
print(f'| 加载已有模型:{args.weight} |')
|
|
print(f'| 加载已有模型:{args.weight} |')
|
|
elif args.prune:
|
|
elif args.prune:
|
|
print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
|
|
print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
|
|
- elif args.timm: # 创建timm库中模型args.timm
|
|
|
|
- import timm
|
|
|
|
-
|
|
|
|
- assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
|
|
|
|
- print(f'| 创建timm库中模型:{args.model} |')
|
|
|
|
else: # 创建自定义模型args.model
|
|
else: # 创建自定义模型args.model
|
|
assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
|
|
assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
|
|
- print(f'| 创建自定义模型:{args.model} | 型号:{args.model_type} |')
|
|
|
|
|
|
+ print(f'| 创建自定义模型:{args.model} |')
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
# 摘要
|
|
# 摘要
|