|
1 年之前 | |
---|---|---|
blind_watermark | 1 年之前 | |
block | 1 年之前 | |
model | 1 年之前 | |
tool | 1 年之前 | |
.gitignore | 1 年之前 | |
README.md | 1 年之前 | |
bash_output.sh | 1 年之前 | |
bash_run.sh | 1 年之前 | |
bash_watermarking.sh | 1 年之前 | |
export_onnx.py | 1 年之前 | |
export_trt | 1 年之前 | |
export_trt.exe | 1 年之前 | |
export_trt_record | 1 年之前 | |
flask_request.py | 1 年之前 | |
flask_start.py | 1 年之前 | |
gradio_start.py | 1 年之前 | |
gunicorn_config.py | 1 年之前 | |
predict_onnx.py | 1 年之前 | |
predict_pt.py | 1 年之前 | |
predict_trt.py | 1 年之前 | |
requirement | 1 年之前 | |
run.py | 1 年之前 |
代码兼容性较强,使用的是一些基本的库、基础的函数
在argparse中可以选择使用wandb,能在wandb网站中生成可视化的训练过程1,环境
torch:https://pytorch.org/get-started/previous-versions/
>pip install timm tqdm wandb opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple >``` ### 2,数据格式 >├── 数据集路径:data_path >    └── image:存放所有图片 >    └── train.txt:训练图片的绝对路径(或相对data_path下路径)和类别号, >        (如-->image/mask/0.jpg 0 2<--表示该图片类别为0和2,空类别图片无类别号) >    └── val.txt:验证图片的绝对路径(或相对data_path下路径)和类别号 >    └── class.txt:所有的类别名称 ### 3,run.py >模型训练时运行该文件,argparse中有对每个参数的说明 ### 4,predict_pt.py >使用训练好的pt模型预测 ### 5,export_onnx.py >将pt模型导出为onnx模型 ### 6,predict_onnx.py >使用导出的onnx模型预测 ### 7,export_trt_record >文档中有onnx模型导出为tensort模型的详细说明 ### 8,predict_trt.py >使用导出的trt模型预测 ### 9,gradio_start.py >用gradio将程序包装成一个可视化的页面,可以在网页可视化的展示 ### 10,flask_start.py >用flask将程序包装成一个服务,并在服务器上启动 ### 11,flask_request.py >以post请求传输数据调用服务 ### 12,gunicorn_config.py >用gunicorn多进程启动flask服务:gunicorn -c gunicorn_config.py flask_start:app ### 模型注意事项: ### 目录结构 ```shell ├── README.md ├── bash_output.sh ├── bash_run.sh ├── bash_watermarking.sh ├── best.onnx ├── blind_watermark #图片嵌入盲水印代码 │ ├── __init__.py │ ├── att.py │ ├── blind_watermark.py │ ├── bwm_core.py │ ├── cli_tools.py │ ├── pool.py │ ├── recover.py │ ├── requirements.txt │ └── version.py ├── block │ ├── data_get.py #加载数据集 │ ├── loss_get.py #计算损失 │ ├── lr_get.py #动态调整学习率 │ ├── metric_get.py #获取评估指标 │ ├── model_ema.py │ ├── model_get.py #获取模型代码 │ ├── test_model_get.py #测试获取模型代码 │ ├── train_embeder.py #模型训练嵌入白盒水印流程 │ ├── train_get.py #正常模型训练代码 │ └── val_get.py #模型验证代码 ├── export_onnx.py #模型定义导出onnx格式代码 ├── export_trt ├── export_trt.exe ├── export_trt_record ├── flask_request.py ├── flask_start.py ├── gradio_start.py # 用gradio将程序包装成一个可视化的页面,可以在网页可视化的展示 ├── gunicorn_config.py ├── model #模型定义,模型名称即文件名 │ ├── Alexnet.py │ ├── GoogleNet.py │ ├── VGG19.py │ ├── __init__.py │ ├── badnet.py │ ├── layer.py │ ├── mobilenetv2.py │ ├── resnet.py │ ├── test.py │ ├── timm_model.py │ └── yolov7_cls.py ├── predict_onnx.py #onnx格式模型文件推理 ├── predict_pt.py #pt格式模型文件推理 ├── predict_trt.py ├── prune_last.pt ├── requirement ├── run.py #模型训练脚本 └── tool ├── check_image.py ├── generate_txt.py #处理数据集,增加标签描述文件 ├── make_flip_image.py ├── make_txt.py ├── secret_func.py #生成验证密码标签,用于对接密码机 ├── training_embedding.py #白盒水印编解码器,用于嵌入白盒水印和标签提取 └── watermarking_data_process.py #对数据集进行处理,嵌入黑盒水印