|
@@ -33,12 +33,13 @@ parser.add_argument('--wandb_image_num', default=16, type=int, help='|wandb保
|
|
|
# new_added
|
|
|
parser.add_argument('--data_path', default='./dataset', type=str,
|
|
|
help='Root path to datasets')
|
|
|
-parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
|
|
|
+parser.add_argument('--dataset_name', default='imagenette2', type=str, help='Specific dataset name')
|
|
|
parser.add_argument('--input_channels', default=3, type=int)
|
|
|
parser.add_argument('--output_num', default=10, type=int)
|
|
|
+parser.add_argument('--checkpoint_dir', default='./checkpoints/Alexnet/black_wm', type=str)
|
|
|
|
|
|
# 待修改
|
|
|
-parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
|
|
|
+parser.add_argument('--input_size', default=500, type=int, help='|输入图片大小|')
|
|
|
# 待修改
|
|
|
parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
|
|
|
parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
|
|
@@ -63,7 +64,7 @@ parser.add_argument('--lr_end_epoch', default=100, type=int, help='|最终学习
|
|
|
parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
|
|
|
parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
|
|
|
parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
|
|
|
-parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
|
|
|
+parser.add_argument('--latch', default=False, type=bool, help='|模型和数据是否为锁存,True为锁存|')
|
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
|
parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
|
|
|
parser.add_argument('--amp', default=False, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
|
|
@@ -75,7 +76,7 @@ args = parser.parse_args()
|
|
|
args.device_number = max(torch.cuda.device_count(), 1) # 使用的GPU数,可能为CPU
|
|
|
|
|
|
# 创建模型对应的检查点目录
|
|
|
-checkpoint_dir = os.path.join('./checkpoints', args.model)
|
|
|
+checkpoint_dir = os.path.join('./checkpoints', args.model) if args.checkpoint_dir is None else args.checkpoint_dir
|
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
print(f"模型保存路径已创建: {checkpoint_dir}")
|
|
|
args.save_path = os.path.join(checkpoint_dir, 'best.pt') # 保存最佳训练模型
|
|
@@ -107,8 +108,8 @@ if args.distributed:
|
|
|
if args.local_rank == 0:
|
|
|
print(f'| args:{args} |')
|
|
|
assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), '! data_path中缺少:{args.dataset_name} !'
|
|
|
- args.train_dir = f'{args.data_path}/{args.dataset_name}/train_cifar10_JPG'
|
|
|
- args.test_dir = f'{args.data_path}/{args.dataset_name}/test_cifar10_JPG'
|
|
|
+ args.train_dir = f'{args.data_path}/{args.dataset_name}/train'
|
|
|
+ args.test_dir = f'{args.data_path}/{args.dataset_name}/val'
|
|
|
if os.path.exists(args.weight): # 优先加载已有模型args.weight继续训练
|
|
|
print(f'| 加载已有模型:{args.weight} |')
|
|
|
elif args.prune:
|