Browse Source

修改训练脚本

liyan 1 year ago
parent
commit
8c7d4da510
2 changed files with 4 additions and 7 deletions
  1. 3 6
      bash_train.sh
  2. 1 1
      train.py

+ 3 - 6
bash_train.sh

@@ -1,14 +1,11 @@
 
 # For 用于训练不同模型,以及保存相应的路径
 # -------------------------------------------------------------------------------------------------------------------- #
-python train.py --model 'resnet' --save_path './checkpoints/resnet/watermarking/best.pt' --save_path_last './checkpoints/resnet/watermarking/last.pt' --epoch 100
+python train.py --model 'LeNet' --input_size 32 --save_path './checkpoints/efficientnetv2_s/watermarking/best.pt' --save_path_last './checkpoints/efficientnetv2_s/watermarking/last.pt' --epoch 100
+python train.py --model 'Alexnet' --input_size 500 --checkpoint_dir './checkpoints/Alexnet/black_wm' --data_path './dataset' --dataset_name 'imagenette2' --output_num 10  --epoch 50 --num_worker 2 --batch 50
 python train.py --model 'VGG19' --save_path './checkpoints/VGG19/watermarking/best.pt' --save_path_last './checkpoints/VGG19/watermarking/last.pt' --epoch 100
-python train.py --model 'Alexnet' --input_size 112 --save_path './checkpoints/Alexnet/watermarking/best.pt' --save_path_last './checkpoints/Alexnet/watermarking/last.pt' --epoch 100
-python train.py --model 'mobilenetv2' --save_path './checkpoints/mobilenetv2/watermarking/best.pt' --save_path_last './checkpoints/mobilenetv2/watermarking/last.pt' --epoch 100
 python train.py --model 'GoogleNet' --input_size 32 --save_path './checkpoints/GoogleNet/watermarking/best.pt' --save_path_last './checkpoints/GoogleNet/watermarking/last.pt' --epoch 100
-python train.py --model 'badnet' --input_size 32 --save_path './checkpoints/badnet/watermarking/best.pt' --save_path_last './checkpoints/badnet/watermarking/last.pt' --epoch 100
-python train.py --model 'efficientnet' --input_size 32 --save_path './checkpoints/efficientnetv2_s/watermarking/best.pt' --save_path_last './checkpoints/efficientnetv2_s/watermarking/last.pt' --epoch 100
-
+python train.py --model 'resnet' --save_path './checkpoints/resnet/watermarking/best.pt' --save_path_last './checkpoints/resnet/watermarking/last.pt' --epoch 100
 
 
 # For 用于剪枝模型,剪枝后微调训练,保存剪枝后模型路径,以及验证微调模型准确性

+ 1 - 1
train.py

@@ -42,7 +42,7 @@ parser.add_argument('--checkpoint_dir', default='./checkpoints/Alexnet/black_wm'
 parser.add_argument('--input_size', default=500, type=int, help='|输入图片大小|')
 # 待修改
 parser.add_argument('--output_class', default=10, type=int, help='|输出的类别数|')
-parser.add_argument('--weight', default='last.pt', type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
+parser.add_argument('--weight', default=None, type=str, help='|已有模型的位置,没找到模型会创建剪枝/新模型|')
 
 # 剪枝的处理部分
 parser.add_argument('--prune', default=False, type=bool, help='|模型剪枝后再训练(部分模型有),需要提供prune_weight|')