图形分类比测模型,包括使用Keras框架的AlexNet模型、使用tensorflow框架的VGG16模型

liyan 247b1cd906 修改vgg16训练过程 7 miesięcy temu
checkpoints bff85da4d4 初始化项目结构 7 miesięcy temu
dataset bff85da4d4 初始化项目结构 7 miesięcy temu
models a719fd182d 调整项目代码 7 miesięcy temu
.gitignore bff85da4d4 初始化项目结构 7 miesięcy temu
README.md dc331ed708 修改说明文档 7 miesięcy temu
export_onnx.py 1c26fb1de4 修改导出onnx脚本 7 miesięcy temu
train_alexnet.py 35b7cbb676 训练过程使用参数控制优化器、学习率等训练参数 7 miesięcy temu
train_vgg16.py 247b1cd906 修改vgg16训练过程 7 miesięcy temu

README.md

classification-models-tensorflow

项目说明

此项目包含AlexNet模型的Keras框架实现和VGG16模型的tensorflow框架实现和与其对应的模型训练文件

项目文件说明

classification-models-tensorflow
    ├── README.md
    ├── checkpoints  # 保存所有的权重信息
    ├── export_onnx.py  # 模型权重转换为onnx脚本
    ├── models  # 模型定义
    │   └── AlexNet.py
    ├── train_alexnet.py  # AlexNet模型训练脚本
    └── train_vgg16.py  # VGG16模型训练脚本

训练命令

  • AlexNet shell python train_alexnet.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --batch-size 64 --epochs 90
  • VGG16 shell python train_vgg16.py --data-path dataset/imagenette2-320 --output-dir checkpoints/vgg16 --batch-size 64 --epochs 90

模型训练权重转换为onnx

python export_onnx.py --model_dir checkpoints/alexnet
python export_onnx.py --model_dir checkpoints/vgg16