# classification-models-tensorflow ## 项目说明 此项目包含AlexNet模型的Keras框架实现和VGG16模型的tensorflow框架实现和与其对应的模型训练文件 ## 项目文件说明 ```text classification-models-tensorflow ├── README.md ├── checkpoints # 保存所有的权重信息 ├── export_onnx.py # 模型权重转换为onnx脚本 ├── models # 模型定义 │   └── AlexNet.py ├── train_alexnet.py # AlexNet模型训练脚本 └── train_vgg16.py # VGG16模型训练脚本 ``` ## 运行环境 ```text python=3.9 pillow scipy numpy==1.24.4 tensorflow==2.10.0 ``` ## 环境搭建 - 构建虚拟环境,如果存在,请忽略 ```shell conda create -n tensorflow python=3.9 conda activate tensorflow ``` - 安装依赖 ```shell pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ``` ## 训练命令 - AlexNet ```shell python train_alexnet.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --batch-size 64 --epochs 90 ``` - VGG16 ```shell python train_vgg16.py --data-path dataset/imagenette2-320 --output-dir checkpoints/vgg16 --batch-size 64 --epochs 90 --lr 0.01 --opt sgd ``` ## 模型训练权重转换为onnx ```shell python export_onnx.py --model_dir checkpoints/alexnet ``` ```shell python export_onnx.py --model_dir checkpoints/vgg16 ```