图形分类比测模型,包括使用Keras框架的AlexNet模型、使用tensorflow框架的VGG16模型

liyan acef196f0c 开发基于tensorflow框架的VGG16模型训练代码 7 ماه پیش
checkpoints bff85da4d4 初始化项目结构 7 ماه پیش
dataset bff85da4d4 初始化项目结构 7 ماه پیش
models bff85da4d4 初始化项目结构 7 ماه پیش
.gitignore bff85da4d4 初始化项目结构 7 ماه پیش
README.md dc331ed708 修改说明文档 7 ماه پیش
export_onnx.py 1c26fb1de4 修改导出onnx脚本 7 ماه پیش
train_alexnet.py 2d157c9dff 修改训练保存权重参数,解决导出onnx失败的问题 7 ماه پیش
train_vgg16.py acef196f0c 开发基于tensorflow框架的VGG16模型训练代码 7 ماه پیش

README.md

classification-models-tensorflow

项目说明

此项目包含AlexNet模型的Keras框架实现和VGG16模型的tensorflow框架实现和与其对应的模型训练文件

项目文件说明

classification-models-tensorflow
    ├── README.md
    ├── checkpoints  # 保存所有的权重信息
    ├── export_onnx.py  # 模型权重转换为onnx脚本
    ├── models  # 模型定义
    │   └── AlexNet.py
    ├── train_alexnet.py  # AlexNet模型训练脚本
    └── train_vgg16.py  # VGG16模型训练脚本

训练命令

  • AlexNet shell python train_alexnet.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --batch-size 64 --epochs 90
  • VGG16 shell python train_vgg16.py --data-path dataset/imagenette2-320 --output-dir checkpoints/vgg16 --batch-size 64 --epochs 90

模型训练权重转换为onnx

python export_onnx.py --model_dir checkpoints/alexnet
python export_onnx.py --model_dir checkpoints/vgg16