图形分类比测模型,包括使用Keras框架的AlexNet模型、使用tensorflow框架的VGG16模型

liyan 09ae8c89da 修改文件缩进格式 3 months ago
checkpoints bff85da4d4 初始化项目结构 3 months ago
dataset bff85da4d4 初始化项目结构 3 months ago
models a719fd182d 调整项目代码 3 months ago
.gitignore bff85da4d4 初始化项目结构 3 months ago
README.md a90d79efeb 修改vgg16训练脚本说明 3 months ago
export_onnx.py 1c26fb1de4 修改导出onnx脚本 3 months ago
requirements.txt 2bfd78ba2e 更新依赖信息 3 months ago
train_alexnet.py 09ae8c89da 修改文件缩进格式 3 months ago
train_vgg16.py 09ae8c89da 修改文件缩进格式 3 months ago

README.md

classification-models-tensorflow

项目说明

此项目包含AlexNet模型的Keras框架实现和VGG16模型的tensorflow框架实现和与其对应的模型训练文件

项目文件说明

classification-models-tensorflow
    ├── README.md
    ├── checkpoints  # 保存所有的权重信息
    ├── export_onnx.py  # 模型权重转换为onnx脚本
    ├── models  # 模型定义
    │   └── AlexNet.py
    ├── train_alexnet.py  # AlexNet模型训练脚本
    └── train_vgg16.py  # VGG16模型训练脚本

运行环境

python=3.9
pillow
scipy
numpy==1.24.4
tensorflow==2.10.0

环境搭建

  • 构建虚拟环境,如果存在,请忽略 shell conda create -n tensorflow python=3.9 conda activate tensorflow
  • 安装依赖 shell pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

训练命令

  • AlexNet shell python train_alexnet.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --batch-size 64 --epochs 90
  • VGG16 shell python train_vgg16.py --data-path dataset/imagenette2-320 --output-dir checkpoints/vgg16 --batch-size 64 --epochs 90 --lr 0.01 --opt sgd

模型训练权重转换为onnx

python export_onnx.py --model_dir checkpoints/alexnet
python export_onnx.py --model_dir checkpoints/vgg16