图形分类比测模型,包括使用Keras框架的AlexNet模型、使用tensorflow框架的VGG16模型

liyan a719fd182d 调整项目代码 7 months ago
checkpoints bff85da4d4 初始化项目结构 7 months ago
dataset bff85da4d4 初始化项目结构 7 months ago
models a719fd182d 调整项目代码 7 months ago
.gitignore bff85da4d4 初始化项目结构 7 months ago
README.md dc331ed708 修改说明文档 7 months ago
export_onnx.py 1c26fb1de4 修改导出onnx脚本 7 months ago
train_alexnet.py 2d157c9dff 修改训练保存权重参数,解决导出onnx失败的问题 7 months ago
train_vgg16.py a719fd182d 调整项目代码 7 months ago

README.md

classification-models-tensorflow

项目说明

此项目包含AlexNet模型的Keras框架实现和VGG16模型的tensorflow框架实现和与其对应的模型训练文件

项目文件说明

classification-models-tensorflow
    ├── README.md
    ├── checkpoints  # 保存所有的权重信息
    ├── export_onnx.py  # 模型权重转换为onnx脚本
    ├── models  # 模型定义
    │   └── AlexNet.py
    ├── train_alexnet.py  # AlexNet模型训练脚本
    └── train_vgg16.py  # VGG16模型训练脚本

训练命令

  • AlexNet shell python train_alexnet.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --batch-size 64 --epochs 90
  • VGG16 shell python train_vgg16.py --data-path dataset/imagenette2-320 --output-dir checkpoints/vgg16 --batch-size 64 --epochs 90

模型训练权重转换为onnx

python export_onnx.py --model_dir checkpoints/alexnet
python export_onnx.py --model_dir checkpoints/vgg16