|
@@ -4,7 +4,7 @@
|
|
|
# └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号,(image/mask/0.jpg 0 2\n)表示该图片类别为0和2,空类别图片无类别号
|
|
|
# └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别
|
|
|
# └── class.txt:所有的类别名称
|
|
|
-# class.csv内容如下:
|
|
|
+# class.txt:
|
|
|
# 类别1
|
|
|
# 类别2
|
|
|
# ...
|
|
@@ -15,7 +15,7 @@
|
|
|
# n为GPU数量
|
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
|
import os
|
|
|
-# import wandb
|
|
|
+import wandb
|
|
|
import torch
|
|
|
import argparse
|
|
|
from block.data_get import data_get
|
|
@@ -23,23 +23,26 @@ from block.loss_get import loss_get
|
|
|
from block.model_get import model_get
|
|
|
from block.train_get import train_get
|
|
|
|
|
|
+# 获取当前文件路径
|
|
|
+pwd = os.getcwd()
|
|
|
+
|
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
|
# 模型加载/创建的优先级为:加载已有模型>创建剪枝模型>创建timm库模型>创建自定义模型
|
|
|
parser = argparse.ArgumentParser(description='|针对分类任务,添加水印机制,包含数据隐私、模型水印|')
|
|
|
-# parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
|
|
|
-# parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|')
|
|
|
-# parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
|
|
|
-# parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
|
|
|
+parser.add_argument('--wandb', default=False, type=bool, help='|是否使用wandb可视化|')
|
|
|
+parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb项目名称|')
|
|
|
+parser.add_argument('--wandb_name', default='train', type=str, help='|wandb项目中的训练名称|')
|
|
|
+parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保存图片的数量|')
|
|
|
|
|
|
# new_added
|
|
|
-parser.add_argument('--data_path', default='./dataset', type=str, help='Root path to datasets')
|
|
|
+parser.add_argument('--data_path', default=f'{pwd}/dataset', type=str, help='Root path to datasets')
|
|
|
parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
|
|
|
parser.add_argument('--input_channels', default=3, type=int)
|
|
|
parser.add_argument('--output_num', default=10, type=int)
|
|
|
# parser.add_argument('--input_size', default=32, type=int)
|
|
|
# 触发集标签定义,黑盒水印植入,这里需要调用它,用于处理部分数据的
|
|
|
-parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
|
|
|
-# 这里可以直接选择水印控制,看看如何选择调用进来
|
|
|
+parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of trigger label (int, range from 0 to 10, default: 2)')
|
|
|
+# 设置数据集中添加水印的文件占比,这里可以直接选择水印控制,看看如何选择调用进来
|
|
|
parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
|
|
|
|
|
|
# 待修改
|
|
@@ -58,15 +61,16 @@ parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|
|
|
|
|
|
|
# 模型处理的部分了
|
|
|
parser.add_argument('--timm', default=False, type=bool, help='|是否使用timm库创建模型|')
|
|
|
-parser.add_argument('--model', default='mobilenetv2', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
|
|
|
+parser.add_argument('--model', default='Alexnet', type=str, help='|自定义模型选择,timm为True时为timm库中模型|')
|
|
|
parser.add_argument('--model_type', default='s', type=str, help='|自定义模型型号|')
|
|
|
-parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
-parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
+# parser.add_argument('--save_path', default=f'{pwd}/weight/best.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
+# parser.add_argument('--save_path_last', default=f'{pwd}/weight/last.pt', type=str, help='|保存最佳模型,除此之外每轮还会保存last.pt|')
|
|
|
|
|
|
# 训练控制
|
|
|
parser.add_argument('--epoch', default=20, type=int, help='|训练总轮数(包含之前已训练轮数)|')
|
|
|
-parser.add_argument('--batch', default=100, type=int, help='|训练批量大小,分布式时为总批量|')
|
|
|
-parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
|
|
|
+parser.add_argument('--batch', default=600, type=int, help='|训练批量大小,分布式时为总批量|')
|
|
|
+# parser.add_argument('--loss', default='bce', type=str, help='|损失函数|')
|
|
|
+parser.add_argument('--loss', default='cross', type=str, help='|损失函数|')
|
|
|
parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
|
|
|
parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
|
|
|
parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
|
|
@@ -77,19 +81,27 @@ parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
|
|
|
parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
|
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
|
parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
|
|
|
-parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
|
|
|
+parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
|
|
|
parser.add_argument('--noise', default=0.5, type=float, help='|训练数据加噪概率|')
|
|
|
parser.add_argument('--class_threshold', default=0.5, type=float, help='|计算指标时,大于阈值判定为图片有该类别|')
|
|
|
parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
|
|
|
parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
|
|
|
args = parser.parse_args()
|
|
|
-args.device_number = max(torch.cuda.device_count(), 2) # 使用的GPU数,可能为CPU
|
|
|
+args.device_number = max(torch.cuda.device_count(), 1) # 使用的GPU数,可能为CPU
|
|
|
+args.save_path = f'{pwd}/weight/{args.model}/best.pt'
|
|
|
+args.save_path_last = f'{pwd}/weight/{args.model}/last.pt'
|
|
|
|
|
|
# 创建模型对应的检查点目录
|
|
|
-checkpoint_dir = os.path.join('./checkpoints', args.model)
|
|
|
-if not os.path.exists(checkpoint_dir):
|
|
|
- os.makedirs(checkpoint_dir)
|
|
|
-print(f"模型保存路径已创建: {args.model}")
|
|
|
+# checkpoint_dir = os.path.join('./checkpoints', args.model)
|
|
|
+# if not os.path.exists(checkpoint_dir):
|
|
|
+# os.makedirs(checkpoint_dir)
|
|
|
+# print(f"{args.model}模型检查点保存路径已创建: {checkpoint_dir}")
|
|
|
+
|
|
|
+# 创建模型保存权重目录
|
|
|
+weight_dir = f'{pwd}/weight/{args.model}'
|
|
|
+if not os.path.exists(weight_dir):
|
|
|
+ os.makedirs(weight_dir)
|
|
|
+print(f"{args.model}模型权重保存路径已创建: {weight_dir}")
|
|
|
|
|
|
# 为CPU设置随机种子
|
|
|
torch.manual_seed(999)
|
|
@@ -126,11 +138,11 @@ if args.local_rank == 0:
|
|
|
print(f'| 加载已有模型:{args.weight} |')
|
|
|
elif args.prune:
|
|
|
print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
|
|
|
- elif args.timm: # 创建timm库中模型args.timm
|
|
|
- import timm
|
|
|
-
|
|
|
- assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
|
|
|
- print(f'| 创建timm库中模型:{args.model} |')
|
|
|
+ # elif args.timm: # 创建timm库中模型args.timm
|
|
|
+ # import timm
|
|
|
+ #
|
|
|
+ # assert timm.list_models(args.model), f'! timm中没有模型:{args.model},使用timm.list_models()查看所有模型 !'
|
|
|
+ # print(f'| 创建timm库中模型:{args.model} |')
|
|
|
else: # 创建自定义模型args.model
|
|
|
assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
|
|
|
print(f'| 创建自定义模型:{args.model} | 型号:{args.model_type} |')
|