
Add initial project files

liyan, 1 year ago
Parent commit
Current commit
5aea0138de

+ 41 - 0
README.md

@@ -0,0 +1,41 @@
+## PyTorch image classification training framework
+>The code aims for broad compatibility, relying only on basic libraries and fundamental functions.  
+>wandb can be enabled via argparse to generate a visualization of the training process on the wandb website.
+### 1. Environment
+>torch: https://pytorch.org/get-started/previous-versions/
+>```
+>pip install timm tqdm wandb opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
+>```
+### 2. Data format
+>├── data_path: dataset root  
+>&emsp; └── image: holds all the images  
+>&emsp; └── train.txt: absolute path (or path relative to data_path) of each training image plus its class indices  
+>&emsp; &emsp; &emsp; (e.g. -->image/mask/0.jpg 0 2<-- means the image belongs to classes 0 and 2; an image with no class carries no index)  
+>&emsp; └── val.txt: absolute path (or path relative to data_path) of each validation image plus its class indices  
+>&emsp; └── class.txt: names of all the classes  
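+>For example, a few train.txt lines (illustrative paths) might read:
+>```
+>image/dog/1.jpg 1
+>image/mask/0.jpg 0 2
+>```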
+### 3. run.py
+>Run this file to train a model; every argument is documented in argparse
+### 4. predict_pt.py
+>Predict with a trained .pt model
+### 5. export_onnx.py
+>Export a .pt model to an ONNX model
+### 6. predict_onnx.py
+>Predict with the exported ONNX model
+### 7. export_trt_record
+>This document gives detailed instructions for exporting an ONNX model to a TensorRT model
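+>A typical conversion uses TensorRT's trtexec tool (illustrative command and paths; see export_trt_record for the exact steps):
+>```
+>trtexec --onnx=best.onnx --saveEngine=best.trt
+>```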
+### 8. predict_trt.py
+>Predict with the exported TensorRT model
+### 9. gradio_start.py
+>Wrap the program with Gradio into a visual page that can be viewed in the browser
+### 10. flask_start.py
+>Wrap the program with Flask into a service and start it on a server
+### 11. flask_request.py
+>Call the service by sending data in a POST request (see the client sketch below)
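+>A minimal client sketch; the endpoint and payload here are hypothetical, and the actual format is defined by flask_start.py:
+>```python
+>import requests
+>
+>response = requests.post('http://127.0.0.1:9999/predict', json={'image': '<base64 image>'})  # hypothetical URL and payload
+>print(response.json())
+>```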
+### 12. gunicorn_config.py
+>Start the Flask service with multiple gunicorn worker processes: gunicorn -c gunicorn_config.py flask_start:app
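+>A minimal gunicorn_config.py sketch (illustrative values; the repository's actual settings may differ):
+>```python
+>bind = '0.0.0.0:8000'  # assumed listen address
+>workers = 4  # number of worker processes
+>```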
+
+
+
+
+### Model notes:
+

+ 3 - 0
bash_output.sh

@@ -0,0 +1,3 @@
+# For completing edge deployment of the model, meeting the side-loaded model requirement
+# -------------------------------------------------------------------------------------------------------------------- #
+python export_onnx.py --weight './checkpoints/Alexnet/clean/prune_best.pt' --input_size 32 --save_path './checkpoints/Alexnet/clean'

+ 27 - 0
bash_run.sh

@@ -0,0 +1,27 @@
+
+# For training the different models and saving them to the corresponding paths
+# -------------------------------------------------------------------------------------------------------------------- #
+python run.py --model 'resnet' --save_path './checkpoints/resnet/watermarking/best.pt' --save_path_last './checkpoints/resnet/watermarking/last.pt' --epoch 100
+python run.py --model 'VGG19' --save_path './checkpoints/VGG19/watermarking/best.pt' --save_path_last './checkpoints/VGG19/watermarking/last.pt' --epoch 100
+python run.py --model 'Alexnet' --input_size 112 --save_path './checkpoints/Alexnet/watermarking/best.pt' --save_path_last './checkpoints/Alexnet/watermarking/last.pt' --epoch 100
+python run.py --model 'mobilenetv2' --save_path './checkpoints/mobilenetv2/watermarking/best.pt' --save_path_last './checkpoints/mobilenetv2/watermarking/last.pt' --epoch 100
+python run.py --model 'GoogleNet' --input_size 32 --save_path './checkpoints/GoogleNet/watermarking/best.pt' --save_path_last './checkpoints/GoogleNet/watermarking/last.pt' --epoch 100
+python run.py --model 'badnet' --input_size 32 --save_path './checkpoints/badnet/watermarking/best.pt' --save_path_last './checkpoints/badnet/watermarking/last.pt' --epoch 100
+python run.py --model 'efficientnet' --input_size 32 --save_path './checkpoints/efficientnetv2_s/watermarking/best.pt' --save_path_last './checkpoints/efficientnetv2_s/watermarking/last.pt' --epoch 100
+
+
+
+# For pruning a model, fine-tuning after pruning, saving the pruned model path, and verifying the fine-tuned model's accuracy
+# -------------------------------------------------------------------------------------------------------------------- #
+python run.py --model 'resnet'  --prune True --prune_weight './checkpoints/resnet/watermarking/best.pt' --prune_save './checkpoints/resnet/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'VGG19'  --prune True --prune_weight './checkpoints/VGG19/watermarking/best.pt' --prune_save './checkpoints/VGG19/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'Alexnet'  --prune True --input_size 112 --prune_weight './checkpoints/Alexnet/watermarking/best.pt' --prune_save './checkpoints/Alexnet/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'mobilenetv2'  --prune True --prune_weight './checkpoints/mobilenetv2/watermarking/best.pt' --prune_save './checkpoints/mobilenetv2/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'GoogleNet' --input_size 32 --prune True --prune_weight './checkpoints/GoogleNet/watermarking/best.pt' --prune_save './checkpoints/GoogleNet/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'badnet' --input_size 32 --prune True --prune_weight './checkpoints/badnet/watermarking/best.pt' --prune_save './checkpoints/badnet/watermarking/prune_best.pt' --epoch 40
+python run.py --model 'efficientnet' --input_size 32 --prune True --prune_weight './checkpoints/efficientnetv2_s/watermarking/best.pt' --prune_save './checkpoints/efficientnetv2_s/watermarking/prune_best.pt' --epoch 40
+
+
+
+# For edge deployment after model pruning
+# -------------------------------------------------------------------------------------------------------------------- #

+ 20 - 0
bash_watermarking.sh

@@ -0,0 +1,20 @@
+
+# For 1) data processing: how to generate the corresponding txt files for a new dataset
+# -------------------------------------------------------------------------------------------------------------------- #
+python ./tool/generate_txt.py --txt_path './dataset/New_dataset'  --specific_data 'testtest'  --txt_name 'train'
+# Taking CIFAR-10 as an example, the train and test txt files are generated like this:
+python ./tool/generate_txt.py --txt_path './dataset/CIFAR-10_ori'  --specific_data 'train_cifar10_JPG' --txt_name 'train'
+python ./tool/generate_txt.py --txt_path './dataset/CIFAR-10_ori'  --specific_data 'test_cifar10_JPG'  --txt_name 'test'
+
+
+
+# For 2) watermark embedding: inserting a string into the images
+# -------------------------------------------------------------------------------------------------------------------- #
+# The key-generation files live in './dataset/watermarking', which contains key_hex.txt and the QR images split by class, making it easy to choose the watermark-insertion method
+
+
+
+
+
+# For edge deployment after model pruning
+# -------------------------------------------------------------------------------------------------------------------- #

+ 5 - 0
blind_watermark/__init__.py

@@ -0,0 +1,5 @@
+from .blind_watermark import WaterMark
+from .bwm_core import WaterMarkCore
+from .att import *
+from .recover import recover_crop
+from .version import __version__, bw_notes

+ 225 - 0
blind_watermark/att.py

@@ -0,0 +1,225 @@
+# coding=utf-8
+
+# attack on the watermark
+import cv2
+import numpy as np
+import warnings
+
+
+def cut_att3(input_filename=None, input_img=None, output_file_name=None, loc_r=None, loc=None, scale=None):
+    # crop attack + resize attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+
+    if loc is None:
+        h, w, _ = input_img.shape
+        x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
+    else:
+        x1, y1, x2, y2 = loc
+
+    # crop attack
+    output_img = input_img[y1:y2, x1:x2].copy()
+
+    # optional resize attack
+    if scale and scale != 1:
+        h, w, _ = output_img.shape
+        output_img = cv2.resize(output_img, dsize=(round(w * scale), round(h * scale)))
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+cut_att2 = cut_att3
+
+
+def resize_att(input_filename=None, input_img=None, output_file_name=None, out_shape=(500, 500)):
+    # resize attack: both the attack and the recovery are resizes, so both call this function
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = cv2.resize(input_img, dsize=out_shape)
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def bright_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    # brightness-adjustment attack; ratio should be greater than 0
+    # ratio > 1 brightens the image, ratio < 1 darkens it
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = input_img * ratio
+    output_img[output_img > 255] = 255
+    output_img = output_img.astype(np.uint8)  # cast back to uint8 after the float multiply
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def shelter_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.1, n=3):
+    # occlusion attack: cover parts of the image
+    # n occluding blocks
+    # each block covers a fraction `ratio` of the image
+    if input_filename:
+        output_img = cv2.imread(input_filename)
+    else:
+        output_img = input_img.copy()
+    input_img_shape = output_img.shape
+
+    for i in range(n):
+        tmp = np.random.rand() * (1 - ratio)  # pick a random position; the (1 - ratio) keeps the block inside the image
+        start_height, end_height = int(tmp * input_img_shape[0]), int((tmp + ratio) * input_img_shape[0])
+        tmp = np.random.rand() * (1 - ratio)
+        start_width, end_width = int(tmp * input_img_shape[1]), int((tmp + ratio) * input_img_shape[1])
+
+        output_img[start_height:end_height, start_width:end_width, :] = 255
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def salt_pepper_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.01):
+    # salt-and-pepper attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    output_img = input_img.copy()
+    for i in range(input_img_shape[0]):
+        for j in range(input_img_shape[1]):
+            if np.random.rand() < ratio:
+                output_img[i, j, :] = 255
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def rot_att(input_filename=None, input_img=None, output_file_name=None, angle=45):
+    # rotation attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    rows, cols, _ = input_img.shape
+    M = cv2.getRotationMatrix2D(center=(cols / 2, rows / 2), angle=angle, scale=1)
+    output_img = cv2.warpAffine(input_img, M, (cols, rows))
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att_height(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # vertical crop attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    height = int(input_img_shape[0] * ratio)
+
+    output_img = input_img[:height, :, :]
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att_width(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # horizontal crop attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    width = int(input_img_shape[1] * ratio)
+
+    output_img = input_img[:, :width, :]
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att(input_filename=None, output_file_name=None, input_img=None, loc=((0.3, 0.1), (0.7, 0.9)), resize=0.6):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # screenshot attack = crop attack + resize attack + known attack parameters (restore according to them)
+    # crop attack: everything outside the kept region is overwritten with white
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+
+    output_img = input_img.copy()
+    shape = output_img.shape
+    x1, y1, x2, y2 = shape[0] * loc[0][0], shape[1] * loc[0][1], shape[0] * loc[1][0], shape[1] * loc[1][1]
+    output_img[:int(x1), :] = 255
+    output_img[int(x2):, :] = 255
+    output_img[:, :int(y1)] = 255
+    output_img[:, int(y2):] = 255
+
+    if resize is not None:
+        # resize once, then resize back
+        output_img = cv2.resize(output_img,
+                                dsize=(int(shape[1] * resize), int(shape[0] * resize))
+                                )
+
+        output_img = cv2.resize(output_img, dsize=(int(shape[1]), int(shape[0])))
+
+    if output_file_name is not None:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+# def cut_att2(input_filename=None, input_img=None, output_file_name=None, loc_r=((0.3, 0.1), (0.9, 0.9)), scale=1.1):
+#     # screenshot attack = crop attack + resize attack + unknown attack parameters
+#     if input_filename:
+#         input_img = cv2.imread(input_filename)
+#     h, w, _ = input_img.shape
+#     x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
+#
+#     output_img = cut_att3(input_img=input_img, output_file_name=output_file_name,
+#                           loc=(x1, y1, x2, y2), scale=scale)
+#     return output_img, (x1, y1, x2, y2)
+
+def anti_cut_att_old(input_filename, output_file_name, origin_shape):
+    warnings.warn('will be deprecated in the future')
+    # anti-crop attack: tile a copied region to fill the image back to size
+    # note: origin_shape is the reverse of the usual convention, which is columns * rows
+    input_img = cv2.imread(input_filename)
+    output_img = input_img.copy()
+    output_img_shape = output_img.shape
+    if output_img_shape[0] > origin_shape[0] or output_img_shape[1] > origin_shape[1]:  # fixed: compare both dimensions
+        print('The cropped image cannot be larger than the original image; please check')
+        return
+
+    # restore the vertical crop, then the horizontal one
+    while output_img_shape[0] < origin_shape[0]:
+        output_img = np.concatenate([output_img, output_img[:origin_shape[0] - output_img_shape[0], :, :]], axis=0)
+        output_img_shape = output_img.shape
+    while output_img_shape[1] < origin_shape[1]:
+        output_img = np.concatenate([output_img, output_img[:, :origin_shape[1] - output_img_shape[1], :]], axis=1)
+        output_img_shape = output_img.shape
+
+    cv2.imwrite(output_file_name, output_img)
+
+
+def anti_cut_att(input_filename=None, input_img=None, output_file_name=None, origin_shape=None):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # anti-crop attack: pad the missing area with white
+    # note: origin_shape is the reverse of the usual convention, which is columns * rows
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = input_img.copy()
+    output_img_shape = output_img.shape
+    if output_img_shape[0] > origin_shape[0] or output_img_shape[1] > origin_shape[1]:  # fixed: compare both dimensions
+        print('The cropped image cannot be larger than the original image; please check')
+        return
+
+    # restore the vertical crop, then the horizontal one
+    if output_img_shape[0] < origin_shape[0]:
+        output_img = np.concatenate(
+            [output_img, 255 * np.ones((origin_shape[0] - output_img_shape[0], output_img_shape[1], 3))]
+            , axis=0)
+        output_img_shape = output_img.shape
+
+    if output_img_shape[1] < origin_shape[1]:
+        output_img = np.concatenate(
+            [output_img, 255 * np.ones((output_img_shape[0], origin_shape[1] - output_img_shape[1], 3))]
+            , axis=1)
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img

+ 109 - 0
blind_watermark/blind_watermark.py

@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# coding=utf-8
+# @Time    : 2020/8/13
+# @Author  : github.com/guofei9987
+import warnings
+
+import numpy as np
+import cv2
+
+from .bwm_core import WaterMarkCore
+from .version import bw_notes
+
+
+class WaterMark:
+    def __init__(self, password_wm=1, password_img=1, block_shape=(4, 4), mode='common', processes=None):
+        bw_notes.print_notes()
+
+        self.bwm_core = WaterMarkCore(password_img=password_img, mode=mode, processes=processes)
+
+        self.password_wm = password_wm
+
+        self.wm_bit = None
+        self.wm_size = 0
+
+    def read_img(self, filename=None, img=None):
+        if img is None:
+            # read the image from file
+            img = cv2.imread(filename, flags=cv2.IMREAD_UNCHANGED)
+            assert img is not None, "image file '{filename}' not read".format(filename=filename)
+
+        self.bwm_core.read_img_arr(img=img)
+        return img
+
+    def read_wm(self, wm_content, mode='img'):
+        assert mode in ('img', 'str', 'bit'), "mode in ('img','str','bit')"
+        if mode == 'img':
+            wm = cv2.imread(filename=wm_content, flags=cv2.IMREAD_GRAYSCALE)
+            assert wm is not None, 'file "{filename}" not read'.format(filename=wm_content)
+
+            # read an image-format watermark and flatten it to 1-D bits, discarding gray levels
+            self.wm_bit = wm.flatten() > 128
+
+        elif mode == 'str':
+            byte = bin(int(wm_content.encode('utf-8').hex(), base=16))[2:]
+            self.wm_bit = (np.array(list(byte)) == '1')
+        else:
+            self.wm_bit = np.array(wm_content)
+
+        self.wm_size = self.wm_bit.size
+
+        # encrypt the watermark:
+        np.random.RandomState(self.password_wm).shuffle(self.wm_bit)
+
+        self.bwm_core.read_wm(self.wm_bit)
+
+    def embed(self, filename=None, compression_ratio=None):
+        '''
+        :param filename: string
+            Save the image file as filename
+        :param compression_ratio: int or None
+            If compression_ratio = None, do not compress;
+            if compression_ratio is an integer between 0 and 100, the smaller it is, the smaller the output file.
+        :return:
+        '''
+        embed_img = self.bwm_core.embed()
+        if filename is not None:
+            if compression_ratio is None:
+                cv2.imwrite(filename=filename, img=embed_img)
+            elif filename.endswith('.jpg'):
+                cv2.imwrite(filename=filename, img=embed_img, params=[cv2.IMWRITE_JPEG_QUALITY, compression_ratio])
+            elif filename.endswith('.png'):
+                cv2.imwrite(filename=filename, img=embed_img, params=[cv2.IMWRITE_PNG_COMPRESSION, compression_ratio])
+            else:
+                cv2.imwrite(filename=filename, img=embed_img)
+        return embed_img
+
+    def extract_decrypt(self, wm_avg):
+        wm_index = np.arange(self.wm_size)
+        np.random.RandomState(self.password_wm).shuffle(wm_index)
+        wm_avg[wm_index] = wm_avg.copy()
+        return wm_avg
+
+    def extract(self, filename=None, embed_img=None, wm_shape=None, out_wm_name=None, mode='img'):
+        assert wm_shape is not None, 'wm_shape needed'
+
+        if filename is not None:
+            embed_img = cv2.imread(filename, flags=cv2.IMREAD_COLOR)
+            assert embed_img is not None, "{filename} not read".format(filename=filename)
+
+        self.wm_size = np.array(wm_shape).prod()
+
+        if mode in ('str', 'bit'):
+            wm_avg = self.bwm_core.extract_with_kmeans(img=embed_img, wm_shape=wm_shape)
+        else:
+            wm_avg = self.bwm_core.extract(img=embed_img, wm_shape=wm_shape)
+
+        # decrypt:
+        wm = self.extract_decrypt(wm_avg=wm_avg)
+
+        # convert to the requested format:
+        if mode == 'img':
+            wm = 255 * wm.reshape(wm_shape[0], wm_shape[1])
+            cv2.imwrite(out_wm_name, wm)
+        elif mode == 'str':
+            byte = ''.join(str((i >= 0.5) * 1) for i in wm)
+            print("Byte value:", byte)
+            wm = bytes.fromhex(hex(int(byte, base=2))[2:]).decode('utf-8', errors='replace')
+
+        return wm
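A minimal end-to-end usage sketch of the WaterMark class above, mirroring the flow in cli_tools.py (file paths are illustrative):

```python
from blind_watermark import WaterMark

bwm = WaterMark(password_img=1234, password_wm=1234)
bwm.read_img('pic/ori_img.jpeg')           # carrier image
bwm.read_wm('watermark text', mode='str')  # string watermark -> bits
bwm.embed('output/embedded.png')
wm_len = len(bwm.wm_bit)                   # needed as wm_shape when extracting a string

wm_str = bwm.extract(filename='output/embedded.png', wm_shape=wm_len, mode='str')
print(wm_str)
```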

+ 232 - 0
blind_watermark/bwm_core.py

@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+# coding=utf-8
+# @Time    : 2021/12/17
+# @Author  : github.com/guofei9987
+import numpy as np
+from numpy.linalg import svd
+import copy
+import cv2
+from cv2 import dct, idct
+from pywt import dwt2, idwt2
+from .pool import AutoPool
+
+
+class WaterMarkCore:
+    def __init__(self, password_img=1, mode='common', processes=None):
+        self.block_shape = np.array([4, 4])
+        self.password_img = password_img
+        self.d1, self.d2 = 36, 20  # larger d1/d2 gives stronger robustness but more distortion in the output image
+
+        # init data
+        self.img, self.img_YUV = None, None  # self.img is the original image; self.img_YUV is padded so its pixel dimensions are even
+        self.ca, self.hvd, = [np.array([])] * 3, [np.array([])] * 3  # per-channel DWT coefficients (approximation ca, details hvd)
+        self.ca_block = [np.array([])] * 3  # each channel holds a 4-D array: ca split into blocks
+        self.ca_part = [np.array([])] * 3  # blocking can lose a sliver when sizes don't divide evenly; self.ca_part is self.ca minus that sliver
+
+        self.wm_size, self.block_num = 0, 0  # watermark length and the number of blocks available for embedding
+        self.pool = AutoPool(mode=mode, processes=processes)
+
+        self.fast_mode = False
+        self.alpha = None  # used for images with an alpha channel
+
+    def init_block_index(self):
+        self.block_num = self.ca_block_shape[0] * self.ca_block_shape[1]
+        assert self.wm_size < self.block_num, IndexError(
+            "can embed at most {} kb of information, exceeded by the watermark's {} kb; overflow".format(self.block_num / 1000, self.wm_size / 1000))
+        # self.part_shape is the truncated 2-D size of ca, used to skip the unaligned strips on the right and bottom when embedding.
+        self.part_shape = self.ca_block_shape[:2] * self.block_shape
+        self.block_index = [(i, j) for i in range(self.ca_block_shape[0]) for j in range(self.ca_block_shape[1])]
+
+    def read_img_arr(self, img):
+        # handle images with an alpha channel
+        self.alpha = None
+        if img.shape[2] == 4:
+            if img[:, :, 3].min() < 255:
+                self.alpha = img[:, :, 3]
+                img = img[:, :, :3]
+
+        # read image -> convert to YUV -> pad the border so dimensions are even -> split into 4-D blocks
+        self.img = img.astype(np.float32)
+        self.img_shape = self.img.shape[:2]
+
+        # if a dimension is odd, pad the border; Y (luma), UV (chroma)
+        self.img_YUV = cv2.copyMakeBorder(cv2.cvtColor(self.img, cv2.COLOR_BGR2YUV),
+                                          0, self.img.shape[0] % 2, 0, self.img.shape[1] % 2,
+                                          cv2.BORDER_CONSTANT, value=(0, 0, 0))
+
+        self.ca_shape = [(i + 1) // 2 for i in self.img_shape]
+
+        self.ca_block_shape = (self.ca_shape[0] // self.block_shape[0], self.ca_shape[1] // self.block_shape[1],
+                               self.block_shape[0], self.block_shape[1])
+        strides = 4 * np.array([self.ca_shape[1] * self.block_shape[0], self.block_shape[1], self.ca_shape[1], 1])
+
+        for channel in range(3):
+            self.ca[channel], self.hvd[channel] = dwt2(self.img_YUV[:, :, channel], 'haar')
+            # reshape to 4-D
+            self.ca_block[channel] = np.lib.stride_tricks.as_strided(self.ca[channel].astype(np.float32),
+                                                                     self.ca_block_shape, strides)
+
+    def read_wm(self, wm_bit):
+        self.wm_bit = wm_bit
+        self.wm_size = wm_bit.size
+
+    def block_add_wm(self, arg):
+        if self.fast_mode:
+            return self.block_add_wm_fast(arg)
+        else:
+            return self.block_add_wm_slow(arg)
+
+    def block_add_wm_slow(self, arg):
+        block, shuffler, i = arg
+        # dct -> (flatten -> shuffle -> unflatten) -> svd -> embed watermark -> inverse svd -> (flatten -> unshuffle -> unflatten) -> inverse dct
+        wm_1 = self.wm_bit[i % self.wm_size]
+        block_dct = dct(block)
+
+        # encrypt (shuffle the order)
+        block_dct_shuffled = block_dct.flatten()[shuffler].reshape(self.block_shape)
+        u, s, v = svd(block_dct_shuffled)
+        s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
+        if self.d2:
+            s[1] = (s[1] // self.d2 + 1 / 4 + 1 / 2 * wm_1) * self.d2
+
+        block_dct_flatten = np.dot(u, np.dot(np.diag(s), v)).flatten()
+        block_dct_flatten[shuffler] = block_dct_flatten.copy()
+        return idct(block_dct_flatten.reshape(self.block_shape))
+
+    def block_add_wm_fast(self, arg):
+        # dct -> svd -> embed watermark -> inverse svd -> inverse dct
+        block, shuffler, i = arg
+        wm_1 = self.wm_bit[i % self.wm_size]
+
+        u, s, v = svd(dct(block))
+        s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
+
+        return idct(np.dot(u, np.dot(np.diag(s), v)))
+
+    def embed(self):
+        self.init_block_index()
+
+        embed_ca = copy.deepcopy(self.ca)
+        embed_YUV = [np.array([])] * 3
+
+        self.idx_shuffle = random_strategy1(self.password_img, self.block_num,
+                                            self.block_shape[0] * self.block_shape[1])
+        for channel in range(3):
+            tmp = self.pool.map(self.block_add_wm,
+                                [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i], i)
+                                 for i in range(self.block_num)])
+
+            for i in range(self.block_num):
+                self.ca_block[channel][self.block_index[i]] = tmp[i]
+
+            # merge the 4-D blocks back into 2-D
+            self.ca_part[channel] = np.concatenate(np.concatenate(self.ca_block[channel], 1), 1)
+            # keep the strips on the right and bottom that didn't divide evenly; replace the main area with the embedded frequency-domain data
+            embed_ca[channel][:self.part_shape[0], :self.part_shape[1]] = self.ca_part[channel]
+            # inverse transform
+            embed_YUV[channel] = idwt2((embed_ca[channel], self.hvd[channel]), "haar")
+
+        # merge the 3 channels
+        embed_img_YUV = np.stack(embed_YUV, axis=2)
+        # remove the border padding added earlier when dimensions were odd
+        embed_img_YUV = embed_img_YUV[:self.img_shape[0], :self.img_shape[1]]
+        embed_img = cv2.cvtColor(embed_img_YUV, cv2.COLOR_YUV2BGR)
+        embed_img = np.clip(embed_img, a_min=0, a_max=255)
+
+        if self.alpha is not None:
+            embed_img = cv2.merge([embed_img.astype(np.uint8), self.alpha])
+        return embed_img
+
+    def block_get_wm(self, args):
+        if self.fast_mode:
+            return self.block_get_wm_fast(args)
+        else:
+            return self.block_get_wm_slow(args)
+
+    def block_get_wm_slow(self, args):
+        block, shuffler = args
+        # dct -> flatten -> shuffle -> unflatten -> svd -> extract watermark
+        block_dct_shuffled = dct(block).flatten()[shuffler].reshape(self.block_shape)
+
+        u, s, v = svd(block_dct_shuffled)
+        wm = (s[0] % self.d1 > self.d1 / 2) * 1
+        if self.d2:
+            tmp = (s[1] % self.d2 > self.d2 / 2) * 1
+            wm = (wm * 3 + tmp * 1) / 4
+        return wm
+
+    def block_get_wm_fast(self, args):
+        block, shuffler = args
+        # dct -> svd -> extract watermark
+        u, s, v = svd(dct(block))
+        wm = (s[0] % self.d1 > self.d1 / 2) * 1
+
+        return wm
+
+    def extract_raw(self, img):
+        # extract 1 bit of information from each block
+        self.read_img_arr(img=img)
+        self.init_block_index()
+
+        wm_block_bit = np.zeros(shape=(3, self.block_num))  # record the bits extracted from every block of all 3 channels
+
+        self.idx_shuffle = random_strategy1(seed=self.password_img,
+                                            size=self.block_num,
+                                            block_shape=self.block_shape[0] * self.block_shape[1],  # 16
+                                            )
+        for channel in range(3):
+            wm_block_bit[channel, :] = self.pool.map(self.block_get_wm,
+                                                     [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i])
+                                                      for i in range(self.block_num)])
+        return wm_block_bit
+
+    def extract_avg(self, wm_block_bit):
+        # average over the cyclic embedding and the 3 channels
+        wm_avg = np.zeros(shape=self.wm_size)
+        for i in range(self.wm_size):
+            wm_avg[i] = wm_block_bit[:, i::self.wm_size].mean()
+        return wm_avg
+
+    def extract(self, img, wm_shape):
+        self.wm_size = np.array(wm_shape).prod()
+
+        # extract the bit embedded in each block:
+        wm_block_bit = self.extract_raw(img=img)
+        # average:
+        wm_avg = self.extract_avg(wm_block_bit)
+        return wm_avg
+
+    def extract_with_kmeans(self, img, wm_shape):
+        wm_avg = self.extract(img=img, wm_shape=wm_shape)
+
+        return one_dim_kmeans(wm_avg)
+
+
+def one_dim_kmeans(inputs):
+    threshold = 0
+    e_tol = 10 ** (-6)
+    center = [inputs.min(), inputs.max()]  # 1. initialize the two centers
+    for i in range(300):
+        threshold = (center[0] + center[1]) / 2
+        is_class01 = inputs > threshold  # 2. assign each point to the nearest center
+        center = [inputs[~is_class01].mean(), inputs[is_class01].mean()]  # 3. recompute the centers
+        if np.abs((center[0] + center[1]) / 2 - threshold) < e_tol:  # 4. stopping condition
+            threshold = (center[0] + center[1]) / 2
+            break
+
+    is_class01 = inputs > threshold
+    return is_class01
+
+
+def random_strategy1(seed, size, block_shape):
+    return np.random.RandomState(seed) \
+        .random(size=(size, block_shape)) \
+        .argsort(axis=1)
+
+
+def random_strategy2(seed, size, block_shape):
+    one_line = np.random.RandomState(seed) \
+        .random(size=(1, block_shape)) \
+        .argsort(axis=1)
+
+    return np.repeat(one_line, repeats=size, axis=0)
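The embedding rule in block_add_wm_slow/block_add_wm_fast is dither quantization of the leading singular value: s[0] is snapped onto a grid of step d1 with an offset that depends on the watermark bit, so extraction only has to test s[0] % d1 > d1 / 2. A minimal standalone sketch of that round trip (values illustrative):

```python
d1 = 36       # quantization step, as in WaterMarkCore
s0 = 123.4    # a leading singular value
for bit in (0, 1):
    s0_embedded = (s0 // d1 + 1 / 4 + 1 / 2 * bit) * d1  # embed: quantize with a bit-dependent offset
    recovered = int(s0_embedded % d1 > d1 / 2)           # extract: test which half of the step it landed in
    assert recovered == bit
```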

+ 53 - 0
blind_watermark/cli_tools.py

@@ -0,0 +1,53 @@
+from optparse import OptionParser
+from .blind_watermark import WaterMark
+
+usage1 = 'blind_watermark --embed --pwd 1234 image.jpg "watermark text" embed.png'
+usage2 = 'blind_watermark --extract --pwd 1234 --wm_shape 111 embed.png'
+optParser = OptionParser(usage=usage1 + '\n' + usage2)
+
+optParser.add_option('--embed', dest='work_mode', action='store_const', const='embed'
+                     , help='Embed watermark into images')
+optParser.add_option('--extract', dest='work_mode', action='store_const', const='extract'
+                     , help='Extract watermark from images')
+
+optParser.add_option('-p', '--pwd', dest='password', help='password, like 1234')
+optParser.add_option('--wm_shape', dest='wm_shape', help='Watermark shape, like 120')
+
+(opts, args) = optParser.parse_args()
+
+
+def main():
+    bwm1 = WaterMark(password_img=int(opts.password))
+    if opts.work_mode == 'embed':
+        if not len(args) == 3:
+            print('Error! Usage: ')
+            print(usage1)
+            return
+        else:
+            bwm1.read_img(args[0])
+            bwm1.read_wm(args[1], mode='str')
+            bwm1.embed(args[2])
+            print('Embed succeeded! Output file:', args[2])
+            print('Write down the watermark size (needed for extraction):', len(bwm1.wm_bit))
+
+    if opts.work_mode == 'extract':
+        if not len(args) == 1:
+            print('Error! Usage: ')
+            print(usage2)
+            return
+
+        else:
+            wm_str = bwm1.extract(filename=args[0], wm_shape=int(opts.wm_shape), mode='str')
+            print('Extract succeed! watermark is:')
+            print(wm_str)
+
+
+'''
+python -m blind_watermark.cli_tools --embed --pwd 1234 examples/pic/ori_img.jpeg "watermark text" examples/output/embedded.png
+python -m blind_watermark.cli_tools --extract --pwd 1234 --wm_shape 111 examples/output/embedded.png
+
+
+cd examples
+blind_watermark --embed --pwd 1234 examples/pic/ori_img.jpeg "watermark text" examples/output/embedded.png
+blind_watermark --extract --pwd 1234 --wm_shape 111 examples/output/embedded.png
+'''

+ 38 - 0
blind_watermark/pool.py

@@ -0,0 +1,38 @@
+import sys
+import multiprocessing
+import warnings
+
+if sys.platform != 'win32':
+    multiprocessing.set_start_method('fork')
+
+
+class CommonPool(object):
+    def map(self, func, args):
+        return list(map(func, args))
+
+
+class AutoPool(object):
+    def __init__(self, mode, processes):
+
+        if mode == 'multiprocessing' and sys.platform == 'win32':
+            warnings.warn('multiprocessing is not supported on Windows, falling back to multithreading')
+            mode = 'multithreading'
+
+        self.mode = mode
+        self.processes = processes
+
+        if mode in ('vectorization', 'cached'):
+            # not implemented here; fall back to serial mapping so .map() still works
+            self.pool = CommonPool()
+        elif mode == 'multithreading':
+            from multiprocessing.dummy import Pool as ThreadPool
+            self.pool = ThreadPool(processes=processes)
+        elif mode == 'multiprocessing':
+            from multiprocessing import Pool
+            self.pool = Pool(processes=processes)
+        else:  # common
+            self.pool = CommonPool()
+
+    def map(self, func, args):
+        return self.pool.map(func, args)

+ 100 - 0
blind_watermark/recover.py

@@ -0,0 +1,100 @@
+import cv2
+import numpy as np
+
+import functools
+
+
+# helper class for cache-based speedup; introduces de-facto global variables
+class MyValues:
+    def __init__(self):
+        self.idx = 0
+        self.image, self.template = None, None
+
+    def set_val(self, image, template):
+        self.idx += 1
+        self.image, self.template = image, template
+
+
+my_value = MyValues()
+
+
+@functools.lru_cache(maxsize=None, typed=False)
+def match_template(w, h, idx):
+    image, template = my_value.image, my_value.template
+    resized = cv2.resize(template, dsize=(w, h))
+    scores = cv2.matchTemplate(image, resized, cv2.TM_CCOEFF_NORMED)
+    ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
+    return ind, scores[ind]
+
+
+def match_template_by_scale(scale):
+    image, template = my_value.image, my_value.template
+    w, h = round(template.shape[1] * scale), round(template.shape[0] * scale)
+    ind, score = match_template(w, h, idx=my_value.idx)
+    return ind, score, scale
+
+
+def search_template(scale=(0.5, 2), search_num=200):
+    image, template = my_value.image, my_value.template
+    # local brute-force search for the optimal scale
+    tmp = []
+    min_scale, max_scale = scale
+
+    max_scale = min(max_scale, image.shape[0] / template.shape[0], image.shape[1] / template.shape[1])
+
+    max_idx = 0
+
+    for i in range(2):
+        for scale in np.linspace(min_scale, max_scale, search_num):
+            ind, score, scale = match_template_by_scale(scale)
+            tmp.append([ind, score, scale])
+
+        # find the best match
+        max_idx = 0
+        max_score = 0
+        for idx, (ind, score, scale) in enumerate(tmp):
+            if score > max_score:
+                max_idx, max_score = idx, score
+
+        min_scale, max_scale = tmp[max(0, max_idx - 1)][2], tmp[min(len(tmp) - 1, max_idx + 1)][2]
+
+        search_num = 2 * int((max_scale - min_scale) * max(template.shape[1], template.shape[0])) + 1
+
+    return tmp[max_idx]
+
+
+def estimate_crop_parameters(original_file=None, template_file=None, ori_img=None, tem_img=None
+                             , scale=(0.5, 2), search_num=200):
+    # estimate the position and size of the attacked image within the original image
+    if template_file:
+        tem_img = cv2.imread(template_file, cv2.IMREAD_GRAYSCALE)  # template image
+    if original_file:
+        ori_img = cv2.imread(original_file, cv2.IMREAD_GRAYSCALE)  # image
+
+    if scale[0] == scale[1] == 1:
+        # no scaling
+        scale_infer = 1
+        scores = cv2.matchTemplate(ori_img, tem_img, cv2.TM_CCOEFF_NORMED)
+        ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
+        score = scores[ind]
+    else:
+        my_value.set_val(image=ori_img, template=tem_img)
+        ind, score, scale_infer = search_template(scale=scale, search_num=search_num)
+    w, h = int(tem_img.shape[1] * scale_infer), int(tem_img.shape[0] * scale_infer)
+    x1, y1, x2, y2 = ind[1], ind[0], ind[1] + w, ind[0] + h
+    return (x1, y1, x2, y2), ori_img.shape, score, scale_infer
+
+
+def recover_crop(template_file=None, tem_img=None, output_file_name=None, loc=None, image_o_shape=None):
+    if template_file:
+        tem_img = cv2.imread(template_file)  # template image
+
+    (x1, y1, x2, y2) = loc
+
+    img_recovered = np.zeros((image_o_shape[0], image_o_shape[1], 3))
+
+    img_recovered[y1:y2, x1:x2, :] = cv2.resize(tem_img, dsize=(x2 - x1, y2 - y1))
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, img_recovered)
+    return img_recovered
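A sketch of how the two functions above combine to undo a crop/scale attack (file names are illustrative):

```python
from blind_watermark.recover import estimate_crop_parameters, recover_crop

# estimate where the attacked image sits inside the original, and at what scale
(x1, y1, x2, y2), image_o_shape, score, scale_infer = estimate_crop_parameters(
    original_file='ori_img.png', template_file='attacked.png', scale=(0.5, 2), search_num=200)

# paste the attacked image back at the estimated location on a blank canvas
recover_crop(template_file='attacked.png', output_file_name='recovered.png',
             loc=(x1, y1, x2, y2), image_o_shape=image_o_shape)
```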

+ 22 - 0
blind_watermark/version.py

@@ -0,0 +1,22 @@
+__version__ = '0.4.4'
+
+
+class Notes:
+    def __init__(self):
+        self.show = True
+
+    def print_notes(self):
+        if self.show:
+            print(f'''
+Welcome to use blind-watermark, version = {__version__}
+Make sure the version is the same when encoding and decoding
+Your star means a lot: https://github.com/guofei9987/blind_watermark
+This message shows only once. To close it: `blind_watermark.bw_notes.close()`
+            ''')
+            self.close()
+
+    def close(self):
+        self.show = False
+
+
+bw_notes = Notes()

+ 64 - 0
block/data_get.py

@@ -0,0 +1,64 @@
+
+# Data format definition
+# Data must be prepared in the following layout:
+# ├── data_path: dataset root
+#     └── image: holds all the images
+#     └── train.txt: absolute path (or path relative to data_path) of each training image plus its class indices; (image/mask/0.jpg 0 2\n) means the image belongs to classes 0 and 2, images with no class carry no index
+#     └── val.txt: absolute path (or path relative to data_path) of each validation image plus its class indices
+#     └── class.txt: names of all the classes
+# class.txt contains, for example:
+# class_1
+# class_2
+
+import numpy as np
+import os
+import argparse
+
+def data_get(args):
+    data_dict = data_prepare(args).load()
+    return data_dict
+
+
+class data_prepare:
+    def __init__(self, args):
+        self.args = args
+        self.data_path = os.path.join(args.data_path, args.dataset_name)
+        self.dataset_name = args.dataset_name
+
+    def load(self):
+        data_dict = {}
+        data_dict['train'] = self._load_label('train.txt')
+        data_dict['test'] = self._load_label('test.txt')
+        data_dict['class'] = self._load_class()
+        return data_dict
+
+    def _load_label(self, txt_name):
+        with open(f'{self.args.data_path}/{self.args.dataset_name}/{txt_name}', encoding='utf-8') as f:
+            txt_list = [_.strip().split(' ') for _ in f.readlines()]  # read all image paths and class indices
+        data_list = [['', 0] for _ in range(len(txt_list))]  # [image path, one-hot class encoding]
+        for i, line in enumerate(txt_list):
+            image_path = line[0]
+            # print(image_path)
+            data_list[i][0] = image_path
+            data_list[i][1] = np.zeros(self.args.output_class, dtype=np.float32)
+            for j in line[1:]:
+                data_list[i][1][int(j)] = 1
+        return data_list
+
+    def _load_class(self):
+        with open(f'{self.args.data_path}/{self.args.dataset_name}/class.txt', encoding='utf-8') as f:
+            txt_list = [_.strip() for _ in f.readlines()]
+        return txt_list
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description='Data loader for specific dataset')
+    parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
+    parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
+    parser.add_argument('--output_class', default=10, type=int, help='Number of output classes')
+    parser.add_argument('--input_size', default=640, type=int)
+    args = parser.parse_args()
+    data_dict = data_get(args)
+    print(len(data_dict['train']))

+ 7 - 0
block/loss_get.py

@@ -0,0 +1,7 @@
+import torch
+
+
+def loss_get(args):
+    choice_dict = {'bce': torch.nn.BCEWithLogitsLoss}
+    loss = choice_dict[args.loss]()  # instantiate the selected loss directly instead of eval()
+    return loss

+ 64 - 0
block/lr_get.py

@@ -0,0 +1,64 @@
+import math
+import torch
+
+
+def adam(regularization, r_value, param, lr, betas):
+    if regularization == 'L2':
+        optimizer = torch.optim.Adam(param, lr=lr, betas=betas, weight_decay=r_value)
+    else:
+        optimizer = torch.optim.Adam(param, lr=lr, betas=betas)
+    return optimizer
+
+
+class lr_adjust:
+    def __init__(self, args, step_epoch, epoch_finished):
+        print(f"初始化 lr_end_epoch: {args.lr_end_epoch}, step_epoch: {step_epoch}")
+        self.lr_start = args.lr_start  # 初始学习率
+        self.lr_end = args.lr_end_ratio * args.lr_start  # 最终学习率
+        self.lr_end_epoch = args.lr_end_epoch  # 最终学习率达到的轮数
+        self.step_all = self.lr_end_epoch * step_epoch  # 总调整步数
+        if self.step_all == 0:
+            raise ValueError("计算总调整步数时出错: step_all 不能为0, 请检查 lr_end_epoch 和 step_epoch 的值。")
+        self.step_finished = epoch_finished * step_epoch  # 已调整步数
+        self.warmup_step = max(5, int(args.warmup_ratio * self.step_all))  # 预热训练步数
+        print(f"总调整步数 step_all: {self.step_all}")  # 这里将显示 step_all 的值
+
+    def __call__(self, optimizer):
+        self.step_finished += 1
+        step_now = self.step_finished
+        if self.step_all == 0:
+            raise ValueError("调用时出错:step_all 不能为0。")
+        decay = step_now / self.step_all
+        lr = self.lr_end + (self.lr_start - self.lr_end) * math.cos(math.pi / 2 * decay)
+        if step_now <= self.warmup_step:
+            lr = lr * (0.1 + 0.9 * step_now / self.warmup_step)
+        lr = max(lr, 0.000001)
+        for i in range(len(optimizer.param_groups)):
+            optimizer.param_groups[i]['lr'] = lr
+        return optimizer
+
+
+# Example parameters (dummy values)
+class Args:
+    def __init__(self):
+        self.lr_start = 0.001
+        self.lr_end_ratio = 0.1
+        self.lr_end_epoch = 10
+        self.warmup_ratio = 0.1
+
+if __name__ == "__main__":
+    # fake some input arguments and an initial state
+    args = Args()
+    step_epoch = 100  # assume 100 steps per epoch
+    epoch_finished = 0  # assume we start from epoch 0
+
+    # initialize the adjuster
+    lr_adjuster = lr_adjust(args, step_epoch, epoch_finished)
+
+    # create a dummy optimizer
+    params = [torch.randn(10, 10, requires_grad=True)]
+    optimizer = adam('L2', 0.01, params, args.lr_start, (0.9, 0.999))
+
+    # call lr_adjuster to adjust the learning rate
+    optimizer = lr_adjuster(optimizer)
+    print(f"adjusted learning rate: {optimizer.param_groups[0]['lr']}")

+ 13 - 0
block/metric_get.py

@@ -0,0 +1,13 @@
+import torch
+
+
+def metric(pred, true, class_threshold):  # an image whose class outputs are all below the threshold counts as the empty label
+    TP = len(pred[torch.where((true == 1) & (pred > class_threshold), True, False)])
+    TN = len(pred[torch.where((true == 0) & (pred <= class_threshold), True, False)])
+    FP = len(pred[torch.where((true == 0) & (pred > class_threshold), True, False)])
+    FN = len(pred[torch.where((true == 1) & (pred <= class_threshold), True, False)])
+    accuracy = (TP + TN) / (TP + TN + FP + FN + 0.00001)
+    precision = TP / (TP + FP + 0.00001)
+    recall = TP / (TP + FN + 0.00001)
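+    # note: "m_ap" below is simply precision * recall, a rough proxy rather than true mean average precision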
+    m_ap = precision * recall
+    return accuracy, precision, recall, m_ap

+ 28 - 0
block/model_ema.py

@@ -0,0 +1,28 @@
+import math
+import copy
+import torch
+
+
+class model_ema:
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        self.ema = copy.deepcopy(self._get_model(model)).eval()  # FP32 EMA
+        self.updates = updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        with torch.no_grad():
+            self.updates += 1
+            d = self.decay(self.updates)
+            state_dict = self._get_model(model).state_dict()
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:
+                    v *= d
+                    v += (1 - d) * state_dict[k].detach()
+
+    def _get_model(self, model):
+        if type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel):
+            return model.module
+        else:
+            return model
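A small usage sketch of model_ema (model and shapes are illustrative): wrap the model once, then call update() after each optimizer step; the decay ramps up with ema.updates.

```python
import torch
from block.model_ema import model_ema

net = torch.nn.Linear(4, 2)
ema = model_ema(net)  # frozen shadow copy used for evaluation
opt = torch.optim.SGD(net.parameters(), lr=0.1)
for _ in range(3):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)  # blend current weights into the shadow copy
```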

+ 160 - 0
block/model_get.py

@@ -0,0 +1,160 @@
+import os
+import torch
+
+choice_dict = {
+    'yolov7_cls': 'model_prepare(args).yolov7_cls()',
+    'timm_model': 'model_prepare(args).timm_model()',
+    'Alexnet': 'model_prepare(args).Alexnet()',
+    'badnet': 'model_prepare(args).badnet()',
+    'GoogleNet': 'model_prepare(args).GoogleNet()',
+    'mobilenetv2': 'model_prepare(args).mobilenetv2()',
+    'resnet': 'model_prepare(args).resnet()',
+    'VGG19': 'model_prepare(args).VGG19()',
+    'efficientnet': 'model_prepare(args).EfficientNetV2_S()'
+}
+
+
+def model_get(args):
+    if os.path.exists(args.weight):  # prefer loading an existing model to continue training
+        model_dict = torch.load(args.weight, map_location='cpu')
+    else:  # build a new model
+        if args.prune:  # model pruning
+            model_dict = torch.load(args.prune_weight, map_location='cpu')
+            model = model_dict['model']
+            model = prune(args, model)
+        elif args.timm:
+            # model = model_prepare(args).timm_model()
+            model = eval(choice_dict['timm_model'])
+        else:
+            model = eval(choice_dict[args.model])
+        model_dict = {}
+        model_dict['model'] = model
+        model_dict['epoch_finished'] = 0  # number of epochs already trained
+        model_dict['optimizer_state_dict'] = None  # optimizer/learning-rate state
+        model_dict['ema_updates'] = 0  # EMA update count
+        model_dict['standard'] = 0  # evaluation metric
+    return model_dict
+
+
+def prune(args, model):
+    # record the BatchNorm layer weights
+    BatchNorm2d_weight = []
+    for module in model.modules():
+        if isinstance(module, torch.nn.BatchNorm2d):
+            BatchNorm2d_weight.append(module.weight.data.clone())
+    BatchNorm2d_weight_abs = torch.cat(BatchNorm2d_weight, dim=0).abs()
+    weight_len = len(BatchNorm2d_weight)
+    # map each weight to its BatchNorm layer index
+    BatchNorm2d_id = []
+    for i in range(weight_len):
+        BatchNorm2d_id.extend([i for _ in range(len(BatchNorm2d_weight[i]))])
+    id_all = torch.tensor(BatchNorm2d_id)
+    # select the channels to keep
+    value, index = torch.sort(BatchNorm2d_weight_abs, dim=0, descending=True)
+    boundary = int(len(index) * args.prune_ratio)
+    prune_index = index[0:boundary]  # indices of the parameters to keep
+    prune_index, _ = torch.sort(prune_index, dim=0, descending=False)
+    prune_id = id_all[prune_index]
+    # distribute the kept indices to their layers
+    index_list = [[] for _ in range(weight_len)]
+    for i in range(len(prune_index)):
+        index_list[prune_id[i]].append(prune_index[i])
+    # convert each layer's kept indices to layer-relative indices
+    record_len = 0
+    for i in range(weight_len):
+        index_list[i] = torch.tensor(index_list[i])
+        index_list[i] -= record_len
+        if len(index_list[i]) == 0:  # a whole layer may have been pruned away; keep at least one channel
+            index_list[i] = torch.argmax(BatchNorm2d_weight[i], dim=0).unsqueeze(0)
+        record_len += len(BatchNorm2d_weight[i])
+    # build the pruned model
+    args.prune_num = [len(_) for _ in index_list]
+    prune_model = eval(choice_dict[args.model])
+    # copy BatchNorm weights and some conv weights
+    index = 0
+    for module, prune_module in zip(model.modules(), prune_model.modules()):
+        if isinstance(module, torch.nn.Conv2d):  # update some Conv2d weights
+            print(f"processing Conv2d layer, index: {index}, weight shape: {module.weight.data.shape}")
+            if index > 0 and index - 1 < len(index_list):
+                # print the state of index_list
+                print(f"kept indices of the previous layer (index_list[{index - 1}]): {index_list[index - 1]}")
+                # check for out-of-bounds indices
+                if index_list[index - 1].max().item() < module.weight.data.shape[1]:  # the max index must be below the input-channel count
+                    weight = module.weight.data.clone()
+                    if index < len(index_list):
+                        weight = weight[:, index_list[index - 1], :, :]
+                    if prune_module.weight.data.shape == weight.shape:
+                        prune_module.weight.data = weight
+                else:
+                    print("索引越界,跳过当前层的处理")
+            elif index == 0:
+                weight = module.weight.data.clone()[index_list[index]]
+                if prune_module.weight.data.shape == weight.shape:
+                    prune_module.weight.data = weight
+
+        if isinstance(module, torch.nn.BatchNorm2d):
+            print(f"更新 BatchNorm2d 层,索引:{index},权重形状:{module.weight.data.shape}")
+            if index < len(index_list) and len(index_list[index]) > 0:
+                expected_size = module.weight.data.size(0)
+                actual_size = len(index_list[index])
+                print(f"期望的大小:{expected_size}, 实际保留的大小:{actual_size}")
+                if actual_size == expected_size:
+                    prune_module.weight.data = module.weight.data.clone()[index_list[index]]
+                    prune_module.bias.data = module.bias.data.clone()[index_list[index]]
+                    prune_module.running_mean = module.running_mean.clone()[index_list[index]]
+                    prune_module.running_var = module.running_var.clone()[index_list[index]]
+                else:
+                    print("警告: 剪枝后的大小与期望的 BatchNorm2d 层大小不匹配")
+            index += 1
+    return prune_model
+
+
+class model_prepare:
+    def __init__(self, args):
+        self.args = args
+
+    def timm_model(self):
+        from model.timm_model import timm_model
+        model = timm_model(self.args)
+        return model
+
+    def yolov7_cls(self):
+        from model.yolov7_cls import yolov7_cls
+        model = yolov7_cls(self.args)
+        return model
+    
+    def Alexnet(self):
+        from model.Alexnet import Alexnet
+        model = Alexnet(self.args.input_channels, self.args.output_num, self.args.input_size)
+        return model
+    
+    def badnet(self):
+        from model.badnet import BadNet
+        model = BadNet(self.args.input_channels, self.args.output_num)
+        return model
+    
+    def GoogleNet(self):
+        from model.GoogleNet import GoogLeNet
+        model = GoogLeNet(self.args.input_channels, self.args.output_num)
+        return model
+    
+    def mobilenetv2(self):
+        from model.mobilenetv2 import MobileNetV2
+        model = MobileNetV2(self.args.input_channels, self.args.output_num)
+        return model
+    
+    def resnet(self):
+        from model.resnet import ResNet18
+        model = ResNet18(self.args.input_channels, self.args.output_num)
+        return model
+    
+    def VGG19(self):
+        from model.VGG19 import VGG19
+        model = VGG19()
+        return model
+    
+    def EfficientNetV2_S(self):
+        from model.efficientnet import EfficientNetV2_S
+        model = EfficientNetV2_S(self.args.input_channels, self.args.output_num)
+        return model

+ 99 - 0
block/test_model_get.py

@@ -0,0 +1,99 @@
+import os
+import torch
+from torch import nn
+
+class model_prepare:
+    def __init__(self, args):
+        self.args = args
+
+    def timm_model(self):
+        from model.timm_model import timm_model
+        model = timm_model(self.args)
+        return model
+
+    def yolov7_cls(self):
+        from model.yolov7_cls import yolov7_cls
+        model = yolov7_cls(self.args)
+        return model
+
+def model_get(args):
+    choice_dict = {
+        'resnet18': model_prepare(args).timm_model,
+        'efficientnetv2_s': model_prepare(args).timm_model,
+        'yolov7_cls': model_prepare(args).yolov7_cls
+    }
+
+    print(f"Pruning enabled: {args.prune}")
+    if os.path.exists(args.weight):
+        print('Loading existing model to continue training...')
+        model_dict = torch.load(args.weight, map_location='cpu')
+    else:
+        if args.prune:
+            print("Loading model for pruning...")
+            model_dict = torch.load(args.prune_weight, map_location='cpu')
+            model = model_dict['model']
+            print("Model type before pruning:", type(model))
+            model = prune(args, model, choice_dict)
+            print("Model type after pruning:", type(model))
+        elif args.timm:
+            model = model_prepare(args).timm_model()
+        else:
+            model = choice_dict[args.model]()  # ensure it's callable
+        model_dict = {'model': model, 'epoch_finished': 0, 'optimizer_state_dict': None, 'ema_updates': 0, 'standard': 0}
+    return model_dict
+
+def prune(args, model, choice_dict):
+    if not isinstance(model, nn.Module):
+        raise TypeError("Expected model to be a PyTorch model instance")
+
+    BatchNorm2d_weight = [module.weight.data.clone() for module in model.modules() if isinstance(module, nn.BatchNorm2d)]
+    BatchNorm2d_weight_abs = torch.cat([w.abs() for w in BatchNorm2d_weight])
+
+    weight_len = len(BatchNorm2d_weight)
+    BatchNorm2d_id = [i for i in range(weight_len) for _ in range(len(BatchNorm2d_weight[i]))]
+    id_all = torch.tensor(BatchNorm2d_id)
+
+    value, index = torch.sort(BatchNorm2d_weight_abs, descending=True)
+    boundary = int(len(index) * args.prune_ratio)
+    prune_index = index[:boundary]
+    prune_index, _ = torch.sort(prune_index)
+    prune_id = id_all[prune_index]
+
+    index_list = [[] for _ in range(weight_len)]
+    for i in range(len(prune_index)):
+        index_list[prune_id[i]].append(prune_index[i])
+
+    for i, indices in enumerate(index_list):
+        if indices:
+            index_list[i] = torch.tensor(indices) - sum(len(BatchNorm2d_weight[j]) for j in range(i))
+        else:
+            # keep at least one channel per layer; argmax is already a layer-relative index
+            index_list[i] = torch.tensor([torch.argmax(BatchNorm2d_weight[i])])
+
+    args.prune_num = [len(x) for x in index_list]
+    prune_model = choice_dict[args.model]()
+
+    index = 0
+    for module, prune_module in zip(model.modules(), prune_model.modules()):
+        if isinstance(module, nn.Conv2d) and index < weight_len:
+            if max(index_list[index]) >= module.out_channels:
+                raise IndexError("Index out of bounds for Conv2d output channels.")
+            weight = module.weight.data.clone()[index_list[index]]
+            if index > 0 and max(index_list[index - 1]) < module.in_channels:
+                weight = weight[:, index_list[index - 1], :, :]
+            prune_module.weight.data = weight
+        if isinstance(module, nn.BatchNorm2d):
+            if max(index_list[index]) >= module.num_features:
+                raise IndexError("Index out of bounds for BatchNorm2d features.")
+            prune_module.weight.data = module.weight.data.clone()[index_list[index]]
+            prune_module.bias.data = module.bias.data.clone()[index_list[index]]
+            prune_module.running_mean = module.running_mean.clone()[index_list[index]]
+            prune_module.running_var = module.running_var.clone()[index_list[index]]
+
+            # print the parameter shapes of the pruned BatchNorm2d layer
+            if isinstance(prune_module, torch.nn.BatchNorm2d):  # fixed: check prune_module, not prune_model
+                print("Pruned BatchNorm2d:")
+                print("Weight:", prune_module.weight.data.shape)
+                print("Bias:", prune_module.bias.data.shape)
+                print("Running Mean:", prune_module.running_mean.shape)
+                print("Running Var:", prune_module.running_var.shape)
+            index += 1
+    return prune_model

+ 177 - 0
block/train_get.py

@@ -0,0 +1,177 @@
+import cv2
+import tqdm
+import wandb
+import torch
+import numpy as np
+import albumentations
+from block.val_get import val_get
+from block.model_ema import model_ema
+from block.lr_get import adam, lr_adjust
+
+
+def train_get(args, data_dict, model_dict, loss):
+    # load the model
+    model = model_dict['model'].to(args.device, non_blocking=args.latch)
+    print(model)
+    # optimizer and learning rate
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # steps per epoch
+    print(len(data_dict['train']) // args.batch)
+    print(step_epoch)
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjuster
+    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
+    # adjust parameters with an exponential moving average (EMA); ema must not live in args, or saving the model would break
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+    # datasets
+    train_dataset = torch_dataset(args, 'train', data_dict['train'], data_dict['class'])
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    train_shuffle = False if args.distributed else True  # with a distributed sampler, shuffle must be False
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
+                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
+                                                   sampler=train_sampler)
+    val_dataset = torch_dataset(args, 'test', data_dict['test'], data_dict['class'])
+    val_sampler = None  # under distributed training, validation data is gathered on the main GPU
+    val_batch = args.batch // args.device_number  # reduce the validation batch to a single GPU's share
+    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
+                                                 drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
+                                                 sampler=val_sampler)
+    # distributed initialization
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                      output_device=args.local_rank) if args.distributed else model
+    # wandb
+    if args.wandb and args.local_rank == 0:
+        wandb_image_list = []  # collect wandb images and add them all at the end (at most args.wandb_image_num)
+    epoch_base = model_dict['epoch_finished'] + 1  # +1 for the new epoch
+    for epoch in range(epoch_base, args.epoch + 1):  # training loop
+        print(f'\n-----------------------Epoch {epoch}-----------------------') if args.local_rank == 0 else None
+        model.train()
+        train_loss = 0  # accumulated loss
+        if args.local_rank == 0:  # tqdm
+            tqdm_show = tqdm.tqdm(total=step_epoch)
+        for index, (image_batch, true_batch) in enumerate(train_dataloader):
+            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            image_batch = image_batch.to(args.device, non_blocking=args.latch)
+            true_batch = true_batch.to(args.device, non_blocking=args.latch)
+            if args.amp:
+                with torch.cuda.amp.autocast():
+                    pred_batch = model(image_batch)
+                    loss_batch = loss(pred_batch, true_batch)
+                args.amp.scale(loss_batch).backward()
+                args.amp.step(optimizer)
+                args.amp.update()
+                optimizer.zero_grad()
+            else:
+                pred_batch = model(image_batch)
+                loss_batch = loss(pred_batch, true_batch)
+                loss_batch.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+            # update the EMA parameters; ema.updates increments automatically
+            ema.update(model) if args.ema else None
+            # accumulate the loss
+            train_loss += loss_batch.item()
+            # adjust the learning rate
+            optimizer = optimizer_adjust(optimizer)
+            # tqdm
+            if args.local_rank == 0:
+                tqdm_show.set_postfix({'train_loss': loss_batch.item(),
+                                       'lr': optimizer.param_groups[0]['lr']})  # 添加显示
+                tqdm_show.update(args.device_number)  # 更新进度条
+            # wandb
+            if args.wandb and args.local_rank == 0 and epoch == epoch_base and len(wandb_image_list) < args.wandb_image_num:
+                cls = true_batch.cpu().numpy().tolist()
+                for i in range(len(wandb_image_batch)):  # 遍历每一张图片
+                    image = wandb_image_batch[i]
+                    text = ['{:.0f}'.format(_) for _ in cls[i]]
+                    text = text[0] if len(text) == 1 else '--'.join(text)
+                    image = np.ascontiguousarray(image)  # 将数组的内存变为连续存储(cv2画图的要求)
+                    cv2.putText(image, text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+                    wandb_image = wandb.Image(image)
+                    wandb_image_list.append(wandb_image)
+                    if len(wandb_image_list) == args.wandb_image_num:
+                        break
+        # tqdm
+        if args.local_rank == 0:
+            tqdm_show.close()
+        # 计算平均损失
+        train_loss /= index + 1
+        if args.local_rank == 0:
+            print(f'\n| 训练 | train_loss:{train_loss:.4f} | lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
+        # 清理显存空间
+        del image_batch, true_batch, pred_batch, loss_batch
+        torch.cuda.empty_cache()
+        # 验证
+        if args.local_rank == 0:  # 分布式时只验证一次
+            val_loss, accuracy, precision, recall, m_ap = val_get(args, val_dataloader, model, loss, ema,
+                                                                  len(data_dict['test']))
+        # 保存
+        if args.local_rank == 0:  # 分布式时只保存一次
+            model_dict['model'] = model.module if args.distributed else model
+            model_dict['epoch_finished'] = epoch
+            model_dict['optimizer_state_dict'] = optimizer.state_dict()
+            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
+            model_dict['class'] = data_dict['class']
+            model_dict['train_loss'] = train_loss
+            model_dict['val_loss'] = val_loss
+            model_dict['val_accuracy'] = accuracy
+            model_dict['val_precision'] = precision
+            model_dict['val_recall'] = recall
+            model_dict['val_m_ap'] = m_ap
+            torch.save(model_dict, args.save_path_last if not args.prune else 'prune_last.pt')  # 保存最后一次训练的模型
+            if m_ap > 0.5 and m_ap > model_dict['standard']:
+                model_dict['standard'] = m_ap
+                save_path = args.save_path if not args.prune else args.prune_save
+                torch.save(model_dict, save_path)  # 保存最佳模型
+                print(f'| 保存最佳模型:{save_path} | val_m_ap:{m_ap:.4f} |')
+            # wandb
+            if args.wandb:
+                wandb_log = {}
+                if epoch == epoch_base:  # 只在本次训练的第一轮记录图片(epoch从epoch_base开始,不会等于0)
+                    wandb_log.update({f'image/train_image': wandb_image_list})
+                wandb_log.update({'metric/train_loss': train_loss,
+                                  'metric/val_loss': val_loss,
+                                  'metric/val_m_ap': m_ap,
+                                  'metric/val_accuracy': accuracy,
+                                  'metric/val_precision': precision,
+                                  'metric/val_recall': recall})
+                args.wandb_run.log(wandb_log)
+        torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, args, tag, data, class_name):
+        self.tag = tag
+        self.data = data
+        self.class_name = class_name
+        self.noise_probability = args.noise
+        self.noise = albumentations.Compose([
+            albumentations.GaussianBlur(blur_limit=(5, 5), p=0.2),
+            albumentations.GaussNoise(var_limit=(10.0, 30.0), p=0.2)])
+        self.transform = albumentations.Compose([
+            albumentations.LongestMaxSize(args.input_size),
+            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+        self.rgb_mean = (0.406, 0.456, 0.485)  # 预留的通道均值(当前_image_deal未使用)
+        self.rgb_std = (0.225, 0.224, 0.229)  # 预留的通道方差(当前_image_deal未使用)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # print(self.data[index][0])
+        image = cv2.imread(self.data[index][0])  # 读取图片
+        if self.tag == 'train' and torch.rand(1) < self.noise_probability:  # 使用数据加噪
+            image = self.noise(image=image)['image']
+        image = self.transform(image=image)['image']  # 缩放和填充图片
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = self._image_deal(image)  # 归一化、转换为tensor、调维度
+        label = torch.tensor(self.data[index][1], dtype=torch.float32)  # 转换为tensor
+        return image, label
+
+    def _image_deal(self, image):  # 归一化、转换为tensor、调维度
+        image = torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)
+        return image
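
训练循环里用到的model_ema并未包含在本次提交中;下面是按上述用法(ema.ema、ema.updates、ema.update(model))反推的最小示意实现,参考常见的YOLO式EMA,细节均为假设而非本仓库的真实代码:

```
import copy
import math
import torch

class model_ema:
    def __init__(self, model, decay=0.9999, tau=2000):
        model = model.module if hasattr(model, 'module') else model  # 兼容DDP包装
        self.ema = copy.deepcopy(model).eval()  # 验证时直接使用ema.ema
        self.updates = 0  # 恢复训练时由model_dict['ema_updates']覆盖
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # 衰减率随更新次数逐渐增大
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        model = model.module if hasattr(model, 'module') else model
        with torch.no_grad():
            self.updates += 1  # 对应注释:ema.updates会自动+1
            d = self.decay(self.updates)
            state = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * state[k].detach()
```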

+ 30 - 0
block/val_get.py

@@ -0,0 +1,30 @@
+import tqdm
+import torch
+from block.metric_get import metric
+
+
+def val_get(args, val_dataloader, model, loss, ema, data_len):
+    tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
+    tqdm_show = tqdm.tqdm(total=tqdm_len)
+    with torch.no_grad():
+        model = ema.ema if args.ema else model.eval()
+        pred_all = []  # 记录所有预测
+        true_all = []  # 记录所有标签
+        for index, (image_batch, true_batch) in enumerate(val_dataloader):
+            image_batch = image_batch.to(args.device, non_blocking=args.latch)
+            pred_batch = model(image_batch).detach().cpu()
+            loss_batch = loss(pred_batch, true_batch)
+            pred_all.extend(pred_batch)
+            true_all.extend(true_batch)
+            tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # 添加显示
+            tqdm_show.update(1)  # 更新进度条
+        # tqdm
+        tqdm_show.close()
+        # 计算指标
+        pred_all = torch.stack(pred_all, dim=0)
+        true_all = torch.stack(true_all, dim=0)
+        loss_all = loss(pred_all, true_all).item()
+        accuracy, precision, recall, m_ap = metric(pred_all, true_all, args.class_threshold)
+        print(f'\n| 验证 | val_loss:{loss_all:.4f} | 阈值:{args.class_threshold:.2f} | val_accuracy:{accuracy:.4f} |'
+              f' val_precision:{precision:.4f} | val_recall:{recall:.4f} | val_m_ap:{m_ap:.4f} |')
+    return loss_all, accuracy, precision, recall, m_ap
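
val_get依赖的block/metric_get.py同样不在本次提交里;下面是按调用方式推测的多标签版本示意(阈值作用于sigmoid后的输出;m_ap此处仅为占位计算,真实实现很可能按类别求AP再取平均):

```
import torch

def metric(pred, true, class_threshold):
    pred = torch.sigmoid(pred)  # 假设传入的是模型原始输出
    positive = pred >= class_threshold
    tp = (positive & (true == 1)).sum().float()  # 真正例
    fp = (positive & (true == 0)).sum().float()  # 假正例
    fn = (~positive & (true == 1)).sum().float()  # 假负例
    tn = (~positive & (true == 0)).sum().float()  # 真负例
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-9)
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    m_ap = precision * recall  # 占位:非真实mAP计算
    return accuracy.item(), precision.item(), recall.item(), m_ap.item()
```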

+ 52 - 0
export_onnx.py

@@ -0,0 +1,52 @@
+import os
+import shutil
+import torch
+import argparse
+from model.layer import deploy
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
+parser.add_argument('--weight', default='best.pt', type=str, help='|模型位置|')
+parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
+parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
+parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
+parser.add_argument('--sim', default=True, type=bool, help='|使用onnxsim压缩简化模型|')
+parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
+parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
+parser.add_argument('--save_path', default='.', type=str, help='|onnx模型移动到的目标文件夹|')
+args = parser.parse_args()
+args.weight = os.path.splitext(args.weight)[0] + '.pt'  # 用splitext统一扩展名(split('.')会被路径中的'./'干扰)
+args.save_name = os.path.splitext(args.weight)[0] + '.onnx'
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.weight), f'! 没有找到模型{args.weight} !'
+if args.float16:
+    assert torch.cuda.is_available(), '! cuda不可用,无法使用float16 !'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def export_onnx():
+    model_dict = torch.load(args.weight, map_location='cpu')
+    model = model_dict['model']
+    model = deploy(model, args.normalization)
+    model = model.eval().half().to(args.device) if args.float16 else model.eval().float().to(args.device)
+    input_shape = torch.rand(1, args.input_size, args.input_size, 3,
+                             dtype=torch.float16 if args.float16 else torch.float32).to(args.device)
+    torch.onnx.export(model, input_shape, args.save_name,
+                      opset_version=12, input_names=['input'], output_names=['output'],
+                      dynamic_axes={'input': {args.batch: 'batch_size'}, 'output': {args.batch: 'batch_size'}})
+    print(f'| 转为onnx模型成功:{args.save_name} |')
+    if args.sim:
+        import onnx
+        import onnxsim
+
+        model_onnx = onnx.load(args.save_name)
+        model_simplify, check = onnxsim.simplify(model_onnx)
+        onnx.save(model_simplify, args.save_name)
+        print(f'| 使用onnxsim简化模型成功:{args.save_name} |')
+
+
+if __name__ == '__main__':
+    export_onnx()
+    # 移动生成的 ONNX 文件到指定文件夹
+    destination_folder = args.save_path
+    os.makedirs(destination_folder, exist_ok=True)  # 目标文件夹不存在时先创建
+    shutil.move(args.save_name, os.path.join(destination_folder, os.path.basename(args.save_name)))
+    print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')
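
导出后可以用onnxruntime快速自检(示意代码,假设导出时保持默认的输入名input、输出名output、input_size=32且float16=True;若导出为float32,请改用np.float32并可换成CPUExecutionProvider):

```
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('best.onnx', providers=['CUDAExecutionProvider'])
dummy = np.random.randint(0, 256, (1, 32, 32, 3)).astype(np.float16)  # NHWC、0~255,归一化在模型内完成
output = session.run(['output'], {'input': dummy})[0]
print(output.shape)  # 预期为(1, 类别数)
```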

二進制
export_trt


二進制
export_trt.exe


+ 23 - 0
export_trt_record

@@ -0,0 +1,23 @@
+# onnx转trt,需要安装tensorrt库
+# 需要压缩包中的bin、include、lib文件,然后添加lib文件路径到系统路径中
+# windows为:系统->高级系统设置->环境变量->系统变量->Path中加入
+# linux为:sudo ldconfig lib位置
+# 然后找到对应版本的whl文件使用pip install ....whl。bin中是官方提供的onnx转trt程序
+# -------------------------------------------------------------------------------------------------------------------- #
+这里的导出程序实际上是tensorrt安装包中的bin里的文件。windows为trtexec.exe。linux为trtexec
+windows:
+export_trt.exe --onnx=best.onnx --saveEngine=best.trt --fp16 --useCudaGraph
+linux:
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib位置
+export_trt --onnx=best.onnx --saveEngine=best.trt --fp16 --useCudaGraph
+# -------------------------------------------------------------------------------------------------------------------- #
+# export_trt.exe:查看提示信息
+# --onnx=onnx模型位置
+# --saveEngine=trt模型保存位置
+# --noTF32:禁用float32精度
+# --fp16:启用float16精度
+# --int8:启用int8精度
+# --best:开启所有精度(有的模型是混合精度的)
+# --device=0:使用的GPU号码,默认为0
+# --useCudaGraph:尝试使用cuda图
+# 转换过程中有很多提示信息,可以解决大多数问题。转换后会进行速度测试。不指定输入形状时默认为单批量预测(推荐)
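
补充示例:若导出的onnx保留了动态批量维(export_onnx.py中--batch 0),trtexec需要显式给出形状范围;其中input名称与1x32x32x3的NHWC形状均为按本仓库导出设置所作的假设:

```
export_trt --onnx=best.onnx --saveEngine=best.trt --fp16 \
    --minShapes=input:1x32x32x3 --optShapes=input:1x32x32x3 --maxShapes=input:8x32x32x3
```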

+ 23 - 0
flask_request.py

@@ -0,0 +1,23 @@
+# 启用flask_start的服务后,将数据以post的方式调用服务得到结果
+import json
+import base64
+import requests
+
+
+def image_encode(image_path):
+    with open(image_path, 'rb')as f:
+        image_byte = f.read()
+    image_base64 = base64.b64encode(image_byte)
+    image = image_base64.decode()
+    return image
+
+
+if __name__ == '__main__':
+    url = 'http://0.0.0.0:9999/test/'  # 根据flask_start中的设置: http://host:port/name/
+    image_path = 'demo.jpg'
+    image = image_encode(image_path)
+    request_dict = {'image': image}
+    request = json.dumps(request_dict)
+    response = requests.post(url, data=request)
+    result = response.json()
+    print(result)

+ 40 - 0
flask_start.py

@@ -0,0 +1,40 @@
+# pip install flask -i https://pypi.tuna.tsinghua.edu.cn/simple
+# 用flask将程序包装成一个服务,并在服务器上启动
+import cv2
+import json
+import flask
+import base64
+import argparse
+import numpy as np
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# 设置
+parser = argparse.ArgumentParser('|在服务器上启动flask服务|')
+# ...
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+app = flask.Flask(__name__)  # 创建一个服务框架
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# 程序
+def image_decode(image):
+    image_base64 = image.encode()  # base64
+    image_byte = base64.b64decode(image_base64)  # base64->字节类型
+    array = np.frombuffer(image_byte, dtype=np.uint8)  # 字节类型->一行数组
+    image = cv2.imdecode(array, cv2.IMREAD_COLOR)  # 一行数组->BGR图片
+    return image
+
+
+@app.route('/test/', methods=['POST'])  # 每当调用服务时会执行一次flask_app函数
+def flask_app():
+    request_json = flask.request.get_data()
+    request_dict = json.loads(request_json)
+    image = image_decode(request_dict['image'])
+    # ...
+    result = json.dumps({'shape': image.shape})  # flask不能直接返回元组,需序列化为json字符串
+    return result
+
+
+if __name__ == '__main__':
+    print('| 使用flask启动服务 |')
+    app.run(host='0.0.0.0', port=9999, debug=False)  # 启动服务

+ 22 - 0
gradio_start.py

@@ -0,0 +1,22 @@
+# pip install gradio -i https://pypi.tuna.tsinghua.edu.cn/simple
+# 用gradio将程序包装成一个可视化的页面,可以在网页可视化的展示
+import gradio
+import argparse
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# 设置
+parser = argparse.ArgumentParser('|在服务器上启动gradio服务|')
+# ...
+args = parser.parse_args()
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# 程序
+def function(text, image):
+    return text, image
+
+
+if __name__ == '__main__':
+    print('| 使用gradio启动服务 |')
+    gradio_app = gradio.Interface(fn=function, inputs=['text', 'image'], outputs=['text', 'image'])
+    gradio_app.launch(share=False)

+ 24 - 0
gunicorn_config.py

@@ -0,0 +1,24 @@
+# pip install gunicorn -i https://pypi.tuna.tsinghua.edu.cn/simple
+# 使用gunicorn启用flask:gunicorn -c gunicorn_config.py flask_start:app
+# 设置端口,外部访问端口也会从http://host:port/name/变为http://bind/name/
+bind = '0.0.0.0:9999'
+# 设置进程数。推荐核数*2+1发挥最佳性能
+workers = 3
+# 客户端最大连接数,默认1000
+worker_connections = 2000
+# 设置工作模型。有sync(同步)(默认)、eventlet(协程异步)、gevent(协程异步)、tornado、gthread(线程)。
+# sync根据请求先来后到处理。eventlet需要安装库:pip install eventlet。gevent需要安装库:pip install gevent。
+# tornado需要安装库:pip install tornado。gthread需要指定threads参数
+worker_class = 'sync'
+# 设置线程数。指定threads参数时工作模式自动变成gthread(线程)模式
+threads = 1
+# 启动程序时的超时时间(s)
+timeout = 60
+# 当代码有修改时会自动重启,适用于开发环境,默认False
+reload = True
+# 设置日志的记录地址。需要提前创建gunicorn_log文件夹
+accesslog = 'gunicorn_log/access.log'
+# 设置错误信息的记录地址。需要提前创建gunicorn_log文件夹
+errorlog = 'gunicorn_log/error.log'
+# 设置日志的记录水平。有debug、info(默认)、warning、error、critical,按照记录信息的详细程度排序
+loglevel = 'info'
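
按上面注释的要求,日志文件夹需要提前创建;一个最小的启动流程示例(假设在仓库根目录执行):

```
mkdir -p gunicorn_log
gunicorn -c gunicorn_config.py flask_start:app
```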

+ 75 - 0
model/Alexnet.py

@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Alexnet(nn.Module):
+    def __init__(self, input_channels, output_num, input_size):
+        super().__init__()
+        
+        self.features = nn.Sequential(
+            nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(64),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+            
+            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
+            nn.BatchNorm2d(192),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+            
+            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
+            nn.BatchNorm2d(384),  # 批量归一化层
+            nn.ReLU(inplace=True),
+            
+            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),  # 批量归一化层
+            nn.ReLU(inplace=True),
+            
+            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+        )
+        
+        self.input_size = input_size
+        self.input_channels = input_channels  # 记录输入通道数,供dummy输入使用
+        self._init_classifier(output_num)
+    
+    def _init_classifier(self, output_num):
+        with torch.no_grad():
+            # Forward a dummy input through the feature extractor part of the network
+            dummy_input = torch.zeros(1, self.input_channels, self.input_size, self.input_size)
+            features_size = self.features(dummy_input).numel()
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.5),
+            nn.Linear(features_size, 1000),
+            nn.ReLU(inplace=True),
+            
+            nn.Dropout(0.5),
+            nn.Linear(1000, 256),
+            nn.ReLU(inplace=True),
+            
+            nn.Linear(256, output_num)
+        )
+        
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='AlexNet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    parser.add_argument('--input_size', default=32, type=int)
+    args = parser.parse_args()
+    
+    model = Alexnet(args.input_channels, args.output_num, args.input_size)
+    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
+    pred = model(tensor)
+    
+    print(model)
+    print("Predictions shape:", pred.shape)

+ 107 - 0
model/GoogleNet.py

@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Inception(nn.Module):
+    def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
+        super(Inception, self).__init__()
+        # 1x1 conv branch
+        self.b1 = nn.Sequential(
+            nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
+            nn.BatchNorm2d(kernel_1_x),
+            nn.ReLU(True),
+        )
+
+        # 1x1 conv -> 3x3 conv branch
+        self.b2 = nn.Sequential(
+            nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
+            nn.BatchNorm2d(kernel_3_in),
+            nn.ReLU(True),
+            nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
+            nn.BatchNorm2d(kernel_3_x),
+            nn.ReLU(True),
+        )
+
+        # 1x1 conv -> 5x5 conv branch (implemented as two stacked 3x3 convs for the same receptive field)
+        self.b3 = nn.Sequential(
+            nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
+            nn.BatchNorm2d(kernel_5_in),
+            nn.ReLU(True),
+            nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
+            nn.BatchNorm2d(kernel_5_x),
+            nn.ReLU(True),
+            nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
+            nn.BatchNorm2d(kernel_5_x),
+            nn.ReLU(True),
+        )
+
+        # 3x3 pool -> 1x1 conv branch
+        self.b4 = nn.Sequential(
+            nn.MaxPool2d(3, stride=1, padding=1),
+            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
+            nn.BatchNorm2d(pool_planes),
+            nn.ReLU(True),
+        )
+
+    def forward(self, x):
+        y1 = self.b1(x)
+        y2 = self.b2(x)
+        y3 = self.b3(x)
+        y4 = self.b4(x)
+        return torch.cat([y1, y2, y3, y4], 1)
+
+class GoogLeNet(nn.Module):
+    def __init__(self, input_channels, output_num):
+        super(GoogLeNet, self).__init__()
+        self.pre_layers = nn.Sequential(
+            nn.Conv2d(input_channels, 192, kernel_size=3, padding=1),
+            nn.BatchNorm2d(192),
+            nn.ReLU(True),
+        )
+
+        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
+        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
+        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
+        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
+        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
+        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
+        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
+        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
+        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
+        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Adaptive pooling
+        self.linear = nn.Linear(1024, output_num)
+
+    def forward(self, x):
+        x = self.pre_layers(x)
+        x = self.a3(x)
+        x = self.b3(x)
+        x = self.max_pool(x)
+        x = self.a4(x)
+        x = self.b4(x)
+        x = self.c4(x)
+        x = self.d4(x)
+        x = self.e4(x)
+        x = self.max_pool(x)
+        x = self.a5(x)
+        x = self.b5(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.linear(x)
+        return x
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='GoogLeNet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    args = parser.parse_args()
+    
+    model = GoogLeNet(args.input_channels, args.output_num)
+    tensor = torch.rand(1, args.input_channels, 224, 224)  # Example for a larger size
+    pred = model(tensor)
+    pred_shape = pred.shape
+    
+    print(model)
+    print("Predictions shape:", pred_shape)

+ 65 - 0
model/VGG19.py

@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+_cfg = {
+    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+def _make_layers(cfg, input_size):
+    layers = []
+    in_channels = 3
+    for layer_cfg in cfg:
+        if layer_cfg == 'M':
+            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+            input_size = input_size // 2
+        else:
+            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg, kernel_size=3, stride=1, padding=1))
+            layers.append(nn.BatchNorm2d(num_features=layer_cfg))
+            layers.append(nn.ReLU(inplace=True))
+            in_channels = layer_cfg
+    return nn.Sequential(*layers), input_size
+
+class VGG(nn.Module):
+    def __init__(self, name, input_size=32, num_classes=10):
+        super(VGG, self).__init__()
+        cfg = _cfg[name]
+        self.features, final_size = _make_layers(cfg, input_size)
+        self.fc = nn.Linear(512 * final_size * final_size, num_classes)
+        
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+def VGG11():
+    return VGG('VGG11')
+
+def VGG13():
+    return VGG('VGG13')
+
+def VGG16():
+    return VGG('VGG16')
+
+def VGG19():
+    return VGG('VGG19')
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='VGG Model Test')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    parser.add_argument('--input_size', default=32, type=int)
+    args = parser.parse_args()
+    
+    model = VGG19()  # Changed to use VGG19
+    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
+    pred = model(tensor)
+    
+    print(model)
+    print("Predictions shape:", pred.shape)

+ 3 - 0
model/__init__.py

@@ -0,0 +1,3 @@
+from .timm_model import timm_model
+from .yolov7_cls import yolov7_cls
+from .layer import cbs, elan, mp, sppcspc, linear_head

+ 54 - 0
model/badnet.py

@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BadNet(nn.Module):
+
+    def __init__(self, input_channels, output_num):
+        super().__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
+            nn.BatchNorm2d(16),  # 添加批量归一化
+            nn.ReLU(),
+            nn.AvgPool2d(kernel_size=2, stride=2)
+        )
+
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
+            nn.BatchNorm2d(32),  # 添加批量归一化
+            nn.ReLU(),
+            nn.AvgPool2d(kernel_size=2, stride=2)
+        )
+        # 计算全连接层的输入特征数(32x32x3输入时卷积输出为5*5*32=800,28x28x1输入时为4*4*32=512)
+        fc1_input_features = 800 if input_channels == 3 else 512
+        self.fc1 = nn.Sequential(
+            nn.Linear(in_features=fc1_input_features, out_features=512),
+            nn.ReLU()
+        )
+        self.fc2 = nn.Linear(in_features=512, out_features=output_num)  # 移除 Softmax
+        self.dropout = nn.Dropout(p=.5)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        x = x.view(x.size(0), -1)  # 展平
+        x = self.fc1(x)
+        x = self.dropout(x)  # 应用 dropout
+        x = self.fc2(x)
+        return x
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Badnet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    args = parser.parse_args()
+    
+    model = BadNet(args.input_channels, args.output_num)
+    tensor = torch.rand(1, args.input_channels, 32, 32)
+    pred = model(tensor)
+    
+    print(model)
+    print("Predictions shape:", pred.shape)

+ 181 - 0
model/layer.py

@@ -0,0 +1,181 @@
+import torch
+import sys
+sys.path.append('/home/yhsun/classification-main/')
+
+class cbs(torch.nn.Module):
+    def __init__(self, in_, out_, kernel_size, stride):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_, out_, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
+                                    bias=False)
+        self.bn = torch.nn.BatchNorm2d(out_, eps=0.001, momentum=0.03)
+        self.silu = torch.nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.silu(x)
+        return x
+
+
+class concat(torch.nn.Module):
+    def __init__(self, dim=1):
+        super().__init__()
+        self.concat = torch.cat
+        self.dim = dim
+
+    def forward(self, x):
+        x = self.concat(x, dim=self.dim)
+        return x
+
+
+class elan(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, n, config=None):
+        super().__init__()
+        if not config:  # 正常版本
+            self.cbs0 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
+            self.sequential2 = torch.nn.Sequential(
+                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
+            self.sequential3 = torch.nn.Sequential(
+                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
+            self.concat4 = concat()
+            self.cbs5 = cbs(out_, out_, kernel_size=1, stride=1)
+        else:  # 剪枝版本。len(config) = 3 + 2 * n
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
+            self.sequential2 = torch.nn.Sequential(
+                *(cbs(config[1 + _], config[2 + _], kernel_size=3, stride=1) for _ in range(n)))
+            self.sequential3 = torch.nn.Sequential(
+                *(cbs(config[1 + n + _], config[2 + n + _], kernel_size=3, stride=1) for _ in range(n)))
+            self.concat4 = concat()
+            self.cbs5 = cbs(config[0] + config[1] + config[1 + n] + config[1 + 2 * n], config[2 + 2 * n],
+                            kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.cbs1(x)
+        x2 = self.sequential2(x1)
+        x3 = self.sequential3(x2)
+        x = self.concat4([x0, x1, x2, x3])
+        x = self.cbs5(x)
+        return x
+
+
+class mp(torch.nn.Module):  # in_->out_,len->len//2
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # 正常版本
+            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
+            self.cbs1 = cbs(in_, out_ // 2, 1, 1)
+            self.cbs2 = cbs(in_, out_ // 2, 1, 1)
+            self.cbs3 = cbs(out_ // 2, out_ // 2, 3, 2)
+            self.concat4 = concat(dim=1)
+        else:  # 剪枝版本。len(config) = 3
+            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
+            self.cbs1 = cbs(in_, config[0], 1, 1)
+            self.cbs2 = cbs(in_, config[1], 1, 1)
+            self.cbs3 = cbs(config[1], config[2], 3, 2)
+            self.concat4 = concat(dim=1)
+
+    def forward(self, x):
+        x0 = self.maxpool0(x)
+        x0 = self.cbs1(x0)
+        x1 = self.cbs2(x)
+        x1 = self.cbs3(x1)
+        x = self.concat4([x0, x1])
+        return x
+
+
+class sppcspc(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # 正常版本
+            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs2 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
+            self.cbs3 = cbs(in_ // 2, in_ // 2, kernel_size=1, stride=1)
+            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat7 = concat(dim=1)
+            self.cbs8 = cbs(2 * in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs9 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
+            self.concat10 = concat(dim=1)
+            self.cbs11 = cbs(in_, out_, kernel_size=1, stride=1)
+        else:  # 剪枝版本。len(config) = 7
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
+            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
+            self.cbs3 = cbs(config[2], config[3], kernel_size=1, stride=1)
+            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat7 = concat(dim=1)
+            self.cbs8 = cbs(4 * config[3], config[4], kernel_size=1, stride=1)
+            self.cbs9 = cbs(config[4], config[5], kernel_size=3, stride=1)
+            self.concat10 = concat(dim=1)
+            self.cbs11 = cbs(config[0] + config[5], config[6], kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.cbs1(x)
+        x1 = self.cbs2(x1)
+        x1 = self.cbs3(x1)
+        x4 = self.MaxPool2d4(x1)
+        x5 = self.MaxPool2d5(x1)
+        x6 = self.MaxPool2d6(x1)
+        x = self.concat7([x1, x4, x5, x6])
+        x = self.cbs8(x)
+        x = self.cbs9(x)
+        x = self.concat10([x, x0])
+        x = self.cbs11(x)
+        return x
+
+
+class linear_head(torch.nn.Module):
+    def __init__(self, in_, out_):
+        super().__init__()
+        self.avgpool0 = torch.nn.AdaptiveAvgPool2d(1)
+        self.flatten1 = torch.nn.Flatten()
+        self.Dropout2 = torch.nn.Dropout(0.2)
+        self.linear3 = torch.nn.Linear(in_, in_ // 2)
+        self.silu4 = torch.nn.SiLU()
+        self.Dropout5 = torch.nn.Dropout(0.2)
+        self.linear6 = torch.nn.Linear(in_ // 2, out_)
+
+    def forward(self, x):
+        x = self.avgpool0(x)
+        x = self.flatten1(x)
+        x = self.Dropout2(x)
+        x = self.linear3(x)
+        x = self.silu4(x)
+        x = self.Dropout5(x)
+        x = self.linear6(x)
+        return x
+
+
+class image_deal(torch.nn.Module):  # 归一化
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x / 255
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
+class deploy(torch.nn.Module):
+    def __init__(self, model, normalization):
+        super().__init__()
+        self.image_deal = image_deal()
+        self.model = model
+        if normalization == 'softmax':
+            self.normalization = torch.nn.Softmax(dim=1)
+        else:
+            self.normalization = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.image_deal(x)
+        x = self.model(x)
+        x = self.normalization(x)
+        return x

+ 94 - 0
model/mobilenetv2.py

@@ -0,0 +1,94 @@
+'''MobileNetV2 in PyTorch.
+See the paper "Inverted Residuals and Linear Bottlenecks:
+Mobile Networks for Classification, Detection and Segmentation" for more details.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Block(nn.Module):
+    '''expand + depthwise + pointwise'''
+    def __init__(self, in_planes, out_planes, expansion, stride):
+        super(Block, self).__init__()
+        self.stride = stride
+
+        planes = expansion * in_planes
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_planes)
+
+        self.shortcut = nn.Sequential()
+        if stride == 1 and in_planes != out_planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
+                nn.BatchNorm2d(out_planes),
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out = out + self.shortcut(x) if self.stride==1 else out
+        return out
+
+
+class MobileNetV2(nn.Module):
+    # (expansion, out_planes, num_blocks, stride)
+    cfg = [(1,  16, 1, 1),
+           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
+           (6,  32, 3, 2),
+           (6,  64, 4, 2),
+           (6,  96, 3, 1),
+           (6, 160, 3, 2),
+           (6, 320, 1, 1)]
+
+    def __init__(self, input_channels, output_num):
+        super(MobileNetV2, self).__init__()
+        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
+        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.layers = self._make_layers(in_planes=32)
+        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bn2 = nn.BatchNorm2d(1280)
+        self.linear = nn.Linear(1280, output_num)
+
+    def _make_layers(self, in_planes):
+        layers = []
+        for expansion, out_planes, num_blocks, stride in self.cfg:
+            strides = [stride] + [1]*(num_blocks-1)
+            for stride in strides:
+                layers.append(Block(in_planes, out_planes, expansion, stride))
+                in_planes = out_planes
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layers(out)
+        out = F.relu(self.bn2(self.conv2(out)))
+        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='MobileNetV2 Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    # parser.add_argument('--input_size', default=32, type=int)
+    args = parser.parse_args()
+    
+    model = MobileNetV2(args.input_channels, args.output_num)
+    tensor = torch.rand(1, args.input_channels, 32, 32)
+    pred = model(tensor)
+    
+    print(model)
+    print("Predictions shape:", pred.shape)

+ 83 - 0
model/resnet.py

@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class CommonBlock(nn.Module):
+    """ Standard residual block without downsampling. """
+    def __init__(self, in_channel, out_channel, stride=1):
+        super(CommonBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channel)
+        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channel)
+
+    def forward(self, x):
+        identity = x
+        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
+        x = self.bn2(self.conv2(x))
+        x += identity
+        return F.relu(x, inplace=True)
+
+class SpecialBlock(nn.Module):
+    """ Residual block with downsampling and channel size increase. """
+    def __init__(self, in_channel, out_channel, stride):
+        super(SpecialBlock, self).__init__()
+        self.change_channel = nn.Sequential(
+            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
+            nn.BatchNorm2d(out_channel)
+        )
+        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channel)
+        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channel)
+
+    def forward(self, x):
+        identity = self.change_channel(x)
+        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
+        x = self.bn2(self.conv2(x))
+        x += identity
+        return F.relu(x, inplace=True)
+
+class ResNet18(nn.Module):
+    def __init__(self, input_channels, num_classes=10):
+        super(ResNet18, self).__init__()
+        self.prepare = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        )
+        self.layer1 = nn.Sequential(CommonBlock(64, 64), CommonBlock(64, 64))
+        self.layer2 = nn.Sequential(SpecialBlock(64, 128, 2), CommonBlock(128, 128))
+        self.layer3 = nn.Sequential(SpecialBlock(128, 256, 2), CommonBlock(256, 256))
+        self.layer4 = nn.Sequential(SpecialBlock(256, 512, 2), CommonBlock(512, 512))
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        x = self.prepare(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Resnet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    args = parser.parse_args()
+    
+    model = ResNet18(args.input_channels, args.output_num)
+    tensor = torch.rand(1, args.input_channels, 224, 224)
+    pred = model(tensor)
+    
+    print(model)
+    print("Predictions shape:", pred.shape)
+

+ 13 - 0
model/test.py

@@ -0,0 +1,13 @@
+import os
+import sys
+
+project_root = '/home/yhsun/classification-main/'
+sys.path.append(project_root)
+print("Project root added to sys.path:", project_root)
+
+# Verify that we can access the model package directly
+import model
+print("Model package is accessible, path:", model.__file__)
+
+from model.layer import linear_head
+print("Imported linear_head from model.layer")

+ 46 - 0
model/timm_model.py

@@ -0,0 +1,46 @@
+import timm
+# print(timm.list_models())
+import torch
+# from model.layer import linear_head
+
+import os
+import sys
+
+project_root = '/home/yhsun/classification-main/'
+sys.path.append(project_root)
+# print("Project root added to sys.path:", project_root)
+
+# Verify that we can access the model package directly
+import model
+# print("Model package is accessible, path:", model.__file__)
+
+from model.layer import linear_head
+# print("Imported linear_head from model.layer")
+
+
+class timm_model(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.backbone = timm.create_model(args.model, in_chans=3, features_only=True, exportable=True)
+        out_dim = self.backbone.feature_info.channels()[-1]  # backbone输出有多个,接最后一个输出,并得到其通道数
+        self.linear_head = linear_head(out_dim, args.output_class)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        x = self.linear_head(x[-1])
+        return x
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--model', default='resnet18', type=str)
+    parser.add_argument('--input_size', default=32, type=int)
+    parser.add_argument('--output_class', default=10, type=int)
+    args = parser.parse_args()
+    model = timm_model(args)
+    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
+    pred = model(tensor)
+    print(model)
+    print(pred.shape)

+ 87 - 0
model/yolov7_cls.py

@@ -0,0 +1,87 @@
+# 根据yolov7改编:https://github.com/WongKinYiu/yolov7
+import torch
+import os
+import sys
+
+project_root = '/home/yhsun/classification-main/'
+sys.path.append(project_root)
+# print("Project root added to sys.path:", project_root)
+
+# Verify that we can access the model package directly
+import model
+from model.layer import cbs, elan, mp, sppcspc, linear_head
+
+
+class yolov7_cls(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}
+        n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}
+        dim = dim_dict[args.model_type]
+        n = n_dict[args.model_type]
+        output_class = args.output_class
+        # 网络结构
+        if not args.prune:  # 正常版本
+            self.l0 = cbs(3, dim, 1, 1)
+            self.l1 = cbs(dim, 2 * dim, 3, 2)  # input_size/2
+            self.l2 = cbs(2 * dim, 2 * dim, 1, 1)
+            self.l3 = cbs(2 * dim, 4 * dim, 3, 2)  # input_size/4
+            self.l4 = elan(4 * dim, 8 * dim, n)
+            self.l5 = mp(8 * dim, 8 * dim)  # input_size/8
+            self.l6 = elan(8 * dim, 16 * dim, n)
+            self.l7 = mp(16 * dim, 16 * dim)  # input_size/16
+            self.l8 = elan(16 * dim, 32 * dim, n)
+            self.l9 = mp(32 * dim, 32 * dim)  # input_size/32
+            self.l10 = elan(32 * dim, 32 * dim, n)
+            self.l11 = sppcspc(32 * dim, 16 * dim)
+            self.l12 = cbs(16 * dim, 8 * dim, 1, 1)
+            self.linear_head = linear_head(8 * dim, output_class)
+        else:  # 剪枝版本
+            config = args.prune_num
+            self.l0 = cbs(3, config[0], 1, 1)
+            self.l1 = cbs(config[0], config[1], 3, 2)  # input_size/2
+            self.l2 = cbs(config[1], config[2], 1, 1)
+            self.l3 = cbs(config[2], config[3], 3, 2)  # input_size/4
+            self.l4 = elan(config[3], None, n, config[4:7 + 2 * n])
+            self.l5 = mp(config[6 + 2 * n], None, config[7 + 2 * n:10 + 2 * n])  # input_size/8
+            self.l6 = elan(config[7 + 2 * n] + config[9 + 2 * n], None, n, config[10 + 2 * n:13 + 4 * n])
+            self.l7 = mp(config[12 + 4 * n], None, config[13 + 4 * n:16 + 4 * n])  # input_size/16
+            self.l8 = elan(config[13 + 4 * n] + config[15 + 4 * n], None, n, config[16 + 4 * n:19 + 6 * n])
+            self.l9 = mp(config[18 + 6 * n], None, config[19 + 6 * n:22 + 6 * n])  # input_size/32
+            self.l10 = elan(config[19 + 6 * n] + config[21 + 6 * n], None, n, config[22 + 6 * n:25 + 8 * n])
+            self.l11 = sppcspc(config[24 + 8 * n], None, config[25 + 8 * n:32 + 8 * n])
+            self.l12 = cbs(config[31 + 8 * n], config[32 + 8 * n], 1, 1)
+            self.linear_head = linear_head(config[32 + 8 * n], output_class)
+
+    def forward(self, x):
+        x = self.l0(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        x = self.l4(x)
+        x = self.l5(x)
+        x = self.l6(x)
+        x = self.l7(x)
+        x = self.l8(x)
+        x = self.l9(x)
+        x = self.l10(x)
+        x = self.l11(x)
+        x = self.l12(x)
+        x = self.linear_head(x)
+        return x
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--prune', default=False, type=bool)
+    parser.add_argument('--model_type', default='n', type=str)
+    parser.add_argument('--input_size', default=32, type=int)
+    parser.add_argument('--output_class', default=10, type=int)
+    args = parser.parse_args()
+    model = yolov7_cls(args)
+    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
+    pred = model(tensor)
+    print(model)
+    print(pred.shape)

+ 72 - 0
predict_onnx.py

@@ -0,0 +1,72 @@
+import os
+import cv2
+import time
+import argparse
+import onnxruntime
+import numpy as np
+import albumentations
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|onnx模型推理|')
+parser.add_argument('--model_path', default='best.onnx', type=str, help='|onnx模型位置|')
+parser.add_argument('--data_path', default='image', type=str, help='|图片文件夹位置|')
+parser.add_argument('--input_size', default=320, type=int, help='|模型输入图片大小,要与导出的模型对应|')
+parser.add_argument('--batch', default=1, type=int, help='|输入图片批量,要与导出的模型对应|')
+parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
+parser.add_argument('--float16', default=True, type=bool, help='|推理数据类型,要与导出的模型对应,False时为float32|')
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
+assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def predict_onnx(args):
+    # 加载模型
+    provider = 'CUDAExecutionProvider' if args.device.lower() in ['gpu', 'cuda'] else 'CPUExecutionProvider'
+    model = onnxruntime.InferenceSession(args.model_path, providers=[provider])  # 加载模型和框架
+    input_name = model.get_inputs()[0].name  # 获取输入名称
+    output_name = model.get_outputs()[0].name  # 获取输出名称
+    print(f'| 模型加载成功:{args.model_path} |')
+    # 加载数据
+    start_time = time.time()
+    transform = albumentations.Compose([
+        albumentations.LongestMaxSize(args.input_size),
+        albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                   border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+    image_dir = sorted(os.listdir(args.data_path))
+    image_all = np.zeros((len(image_dir), args.input_size, args.input_size, 3)).astype(
+        np.float16 if args.float16 else np.float32)
+    for i in range(len(image_dir)):
+        image = cv2.imread(args.data_path + '/' + image_dir[i])
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = transform(image=image)['image']  # 缩放和填充图片(归一化、减均值、除以方差、调维度等在模型中完成)
+        image_all[i] = image
+    end_time = time.time()
+    print('| 数据加载成功:{} 每张耗时:{:.4f} |'.format(len(image_all), (end_time - start_time) / len(image_all)))
+    # 推理
+    start_time = time.time()
+    result = []
+    n = len(image_all) // args.batch
+    if n > 0:  # 如果图片数量>=批量(分批预测)
+        for i in range(n):
+            batch = image_all[i * args.batch:(i + 1) * args.batch]
+            pred_batch = model.run([output_name], {input_name: batch})
+            result.extend(pred_batch[0].tolist())
+        if len(image_all) % args.batch > 0:  # 如果图片数量没有刚好满足批量
+            batch = image_all[(i + 1) * args.batch:]
+            pred_batch = model.run([output_name], {input_name: batch})
+            result.extend(pred_batch[0].tolist())
+    else:  # 如果图片数量<批量(直接预测)
+        batch = image_all
+        pred_batch = model.run([output_name], {input_name: batch})
+        result.extend(pred_batch[0].tolist())
+    for i in range(len(result)):
+        result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
+        print(f'| {image_dir[i]}:{result[i]} |')
+    end_time = time.time()
+    print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_all), args.batch, (end_time - start_time) / len(image_all)))
+
+
+if __name__ == '__main__':
+    predict_onnx(args)

+ 76 - 0
predict_pt.py

@@ -0,0 +1,76 @@
+import os
+import cv2
+import time
+import torch
+import argparse
+import albumentations
+from model.layer import deploy
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|pt模型推理|')
+parser.add_argument('--model_path', default='best.pt', type=str, help='|pt模型位置|')
+parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset/CIFAR-10/train_cifar10_JPG/airplane', type=str, help='|图片文件夹位置|')
+parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
+parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
+parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
+parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
+parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
+parser.add_argument('--float16', default=True, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
+assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
+if args.float16:
+    assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def predict_pt(args):
+    # 加载模型
+    model_dict = torch.load(args.model_path, map_location='cpu')
+    model = model_dict['model']
+    model = deploy(model, args.normalization)
+    model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
+    epoch = model_dict['epoch_finished']
+    m_ap = round(model_dict['standard'], 4)
+    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | m_ap:{m_ap}|')
+    # 推理
+    image_dir = sorted(os.listdir(args.data_path))
+    start_time = time.time()
+    with torch.no_grad():
+        dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch,
+                                                 shuffle=False, drop_last=False, pin_memory=False,
+                                                 num_workers=args.num_worker)
+        result = []
+        for item, batch in enumerate(dataloader):
+            batch = batch.to(args.device)
+            pred_batch = model(batch).detach().cpu()
+            result.extend(pred_batch.tolist())
+        for i in range(len(result)):
+            result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
+            print(f'| {image_dir[i]}:{result[i]} |')
+    end_time = time.time()
+    print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_dir), args.batch, (end_time - start_time) / len(image_dir)))
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, image_dir):
+        self.image_dir = image_dir
+        self.transform = albumentations.Compose([
+            albumentations.LongestMaxSize(args.input_size),
+            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+
+    def __len__(self):
+        return len(self.image_dir)
+
+    def __getitem__(self, index):
+        image = cv2.imread(args.data_path + '/' + self.image_dir[index])  # 读取图片
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = self.transform(image=image)['image']  # 缩放和填充图片(归一化、调维度在模型中完成)
+        image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
+        return image
+
+
+if __name__ == '__main__':
+    predict_pt(args)

+ 70 - 0
predict_trt.py

@@ -0,0 +1,70 @@
+import os
+import cv2
+import time
+import argparse
+import tensorrt
+import numpy as np
+import albumentations
+import pycuda.autoinit
+import pycuda.driver as cuda
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|tensorrt模型推理|')
+parser.add_argument('--model_path', default='best.trt', type=str, help='|trt模型位置|')
+parser.add_argument('--data_path', default='image', type=str, help='|图片文件夹位置|')
+parser.add_argument('--input_size', default=320, type=int, help='|输入图片大小,要与导出的模型对应|')
+parser.add_argument('--batch', default=1, type=int, help='|输入图片批量,要与导出的模型对应,一般为1|')
+parser.add_argument('--float16', default=True, type=bool, help='|推理数据类型,要与导出的模型对应,False时为float32|')
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
+assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def predict_trt(args):
+    # 加载模型
+    logger = tensorrt.Logger(tensorrt.Logger.WARNING)  # 创建日志记录信息
+    with tensorrt.Runtime(logger) as runtime, open(args.model_path, "rb") as f:
+        model = runtime.deserialize_cuda_engine(f.read())  # 读取模型并构建一个对象
+    np_type = tensorrt.nptype(model.get_tensor_dtype('input'))  # 获取接口的数据类型并转为np的字符串格式
+    h_input = np.zeros(tensorrt.volume(model.get_tensor_shape('input')), dtype=np_type)  # 获取输入的形状(一维)
+    h_output = np.zeros(tensorrt.volume(model.get_tensor_shape('output')), dtype=np_type)  # 获取输出的形状(一维)
+    d_input = cuda.mem_alloc(h_input.nbytes)  # 分配显存空间
+    d_output = cuda.mem_alloc(h_output.nbytes)  # 分配显存空间
+    bindings = [int(d_input), int(d_output)]  # 绑定显存输入输出
+    stream = cuda.Stream()  # 创建cuda流
+    model_context = model.create_execution_context()  # 创建模型推理器
+    print(f'| 加载模型成功:{args.model_path} |')
+    # 加载数据
+    start_time = time.time()
+    transform = albumentations.Compose([
+        albumentations.LongestMaxSize(args.input_size),
+        albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                   border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+    image_dir = sorted(os.listdir(args.data_path))
+    image_list = [0 for _ in range(len(image_dir))]
+    for i in range(len(image_dir)):
+        image = cv2.imread(args.data_path + '/' + image_dir[i])
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = transform(image=image)['image'].reshape(-1).astype(
+            np.float16 if args.float16 else np.float32)  # 缩放和填充图片(归一化、减均值、除以方差、调维度等在模型中完成)
+        image_list[i] = image
+    end_time = time.time()
+    print('| 数据加载成功:{} 每张耗时:{:.4f} |'.format(len(image_list), (end_time - start_time) / len(image_list)))
+    # 推理
+    start_time = time.time()
+    result = [0 for _ in range(len(image_list))]
+    for i in range(len(image_list)):
+        cuda.memcpy_htod_async(d_input, image_list[i], stream)  # copy the input from host memory to device memory
+        model_context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)  # run inference
+        cuda.memcpy_dtoh_async(h_output, d_output, stream)  # copy the output from device memory back to host memory
+        stream.synchronize()  # wait for the stream to finish
+        result[i] = [round(_, 2) for _ in h_output.tolist()]
+        print(f'| {image_dir[i]}:{result[i]} |')
+    end_time = time.time()
+    print('| Images: {} batch: {} {:.4f}s per image |'.format(len(image_list), args.batch, (end_time - start_time) / len(image_list)))
+
+
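+# A small illustrative helper (an assumption, not part of the original pipeline): map
+# the raw per-class scores in result[i] back to names from class.txt, using the same
+# kind of threshold as run.py's --class_threshold. File name and threshold are examples.
+def scores_to_classes(scores, class_txt='class.txt', threshold=0.5):
+    with open(class_txt, 'r', encoding='utf-8') as f:
+        names = [line.strip() for line in f if line.strip()]
+    return [name for name, score in zip(names, scores) if score > threshold]
+
+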
+if __name__ == '__main__':
+    predict_trt(args)

+ 34 - 0
requirement

@@ -0,0 +1,34 @@
+# (Fairly compatible: the latest versions of these libraries are usually fine; downgrade one only if a conflict appears)
+
+# CUDA installation:
+# Run nvidia-smi in a terminal to check the driver version
+# Find the matching CUDA version (driver versions are backward compatible): https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
+# Download and install the matching CUDA toolkit: https://developer.nvidia.com/cuda-toolkit-archive
+
+# cuDNN installation:
+# Download the matching cuDNN package: https://developer.nvidia.com/rdp/cudnn-archive
+# Copy the bin, include and lib folders (the Linux build has no bin) into the bin, include and lib of the CUDA installation; to uninstall, simply delete the cuDNN files from there
+
+# 1. Training:
+# pip install ... -i https://pypi.tuna.tsinghua.edu.cn/simple
+# Install the matching torch version from the official site: https://pytorch.org/get-started/previous-versions/
+
+# 2. ONNX export and inference:
+# pip install onnxruntime-gpu onnx-simplifier -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+# 3. TensorRT export and inference:
+# Download the matching TensorRT package from the official site: https://developer.nvidia.com/nvidia-tensorrt-8x-download
+# Only the include and lib folders of the archive are needed; add the lib path to the system path manually. bin contains NVIDIA's own onnx-to-trt converter
+# Windows: add it under System -> Advanced system settings -> Environment Variables -> System variables -> Path
+# Linux: sudo ldconfig <path to lib>
+# Then install the matching wheel with pip install ....whl
+
+opencv-python
+timm
+tqdm
+wandb
+torch
+numpy
+albumentations
+qrcode
+pyzbar
+pillow
+scipy
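+
+# To sanity-check the installation afterwards, for example:
+# python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+# python -c "import tensorrt; print(tensorrt.__version__)"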

+ 146 - 0
run.py

@@ -0,0 +1,146 @@
+# The dataset must be prepared in the following format:
+# ├── dataset root: data_path
+#     └── image: all images
+#     └── train.txt: path of each training image (absolute, or relative to data_path) plus its class ids; (image/mask/0.jpg 0 2\n) means the image has classes 0 and 2, an image with no class has no id
+#     └── val.txt: path of each validation image (absolute, or relative to data_path) plus its classes
+#     └── class.txt: all class names
+# class.txt contains one name per line:
+# class_1
+# class_2
+# ...
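+# For reference, one such line splits into an image path plus a (possibly empty) list
+# of class ids:  path, *labels = 'image/mask/0.jpg 0 2'.split()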
+# -------------------------------------------------------------------------------------------------------------------- #
+# Distributed data-parallel training:
+# python -m torch.distributed.launch --master_port 9999 --nproc_per_node n run.py --distributed True
+# master_port is the communication port between GPUs; any free port works
+# n is the number of GPUs
+# -------------------------------------------------------------------------------------------------------------------- #
+import os
+import wandb
+import torch
+import argparse
+from block.data_get import data_get
+from block.loss_get import loss_get
+from block.model_get import model_get
+from block.train_get import train_get
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Model loading/creation priority: load an existing model > create a pruned model > create a timm model > create a custom model
+parser = argparse.ArgumentParser(description='|Classification training with watermarking: data privacy plus model watermark|')
+parser.add_argument('--wandb', default=False, type=bool, help='|whether to log to wandb|')
+parser.add_argument('--wandb_project', default='classification', type=str, help='|wandb project name|')
+parser.add_argument('--wandb_name', default='train', type=str, help='|run name inside the wandb project|')
+parser.add_argument('--wandb_image_num', default=16, type=int, help='|number of images saved to wandb|')
+
+# new_added
+parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset', type=str, help='Root path to datasets')
+parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='Specific dataset name')
+parser.add_argument('--input_channels', default=3, type=int)
+parser.add_argument('--output_num', default=10, type=int)
+# parser.add_argument('--input_size', default=32, type=int)
+# Black-box watermark embedding: called here to process part of the data
+parser.add_argument('--trigger_label', type=int, default=2, help='The NO. of the trigger label (int, range 0 to 9, default: 2)')
+# Watermark control: choose here how the watermarking step is invoked
+parser.add_argument('--watermarking_portion', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
+
+# To be adjusted per dataset
+parser.add_argument('--input_size', default=32, type=int, help='|input image size|')
+# To be adjusted per dataset
+parser.add_argument('--output_class', default=10, type=int, help='|number of output classes|')
+parser.add_argument('--weight', default='last.pt', type=str, help='|path to an existing model; if not found, a pruned/new model is created|')
+
+
+# Pruning options
+parser.add_argument('--prune', default=False, type=bool, help='|retrain after pruning (supported by some models); requires prune_weight|')
+parser.add_argument('--prune_weight', default='best.pt', type=str, help='|reference model for pruning; a pruned model is created and trained|')
+parser.add_argument('--prune_ratio', default=0.5, type=float, help='|fraction of channels kept when pruning|')
+parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|where to save the best pruned model; prune_last.pt is also saved every epoch|')
+
+
+# Model options
+parser.add_argument('--timm', default=False, type=bool, help='|whether to create the model from the timm library|')
+parser.add_argument('--model', default='mobilenetv2', type=str, help='|custom model choice; a timm model name when timm is True|')
+parser.add_argument('--model_type', default='s', type=str, help='|variant of the custom model|')
+parser.add_argument('--save_path', default='./checkpoints/mobilenetv2/best.pt', type=str, help='|where to save the best model|')
+parser.add_argument('--save_path_last', default='./checkpoints/mobilenetv2/last.pt', type=str, help='|where to save the last model of each epoch|')
+
+# Training options
+parser.add_argument('--epoch', default=20, type=int, help='|total training epochs (including epochs already trained)|')
+parser.add_argument('--batch', default=100, type=int, help='|training batch size; the total batch when distributed|')
+parser.add_argument('--loss', default='bce', type=str, help='|loss function|')
+parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|warmup steps as a fraction of total steps, at least 5 steps; 0.01 is the baseline|')
+parser.add_argument('--lr_start', default=0.001, type=float, help='|initial learning rate for adam; lower it for small batches, 0.001 is the baseline|')
+parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|final lr = lr_end_ratio * lr_start; 0.01 is the baseline|')
+parser.add_argument('--lr_end_epoch', default=100, type=int, help='|epoch at which the final lr is reached; adjusted every step by cosine decay|')
+parser.add_argument('--regularization', default='L2', type=str, help='|regularization: L2 or None|')
+parser.add_argument('--r_value', default=0.0005, type=float, help='|regularization weight; 0.0005 is the baseline|')
+parser.add_argument('--device', default='cuda', type=str, help='|training device|')
+parser.add_argument('--latch', default=True, type=bool, help='|whether model and data use page-locked (pinned) memory; True means pinned|')
+parser.add_argument('--num_worker', default=0, type=int, help='|number of CPU workers for data loading; 0 means only the main process, usually 0, 2, 4 or 8|')
+parser.add_argument('--ema', default=True, type=bool, help='|use an exponential moving average (EMA) of the parameters|')
+parser.add_argument('--amp', default=True, type=bool, help='|mixed float16 precision training; unavailable on CPU, NaNs may be GPU-related|')
+parser.add_argument('--noise', default=0.5, type=float, help='|probability of adding noise to a training image|')
+parser.add_argument('--class_threshold', default=0.5, type=float, help='|a class counts as present when its score exceeds this threshold during metric computation|')
+parser.add_argument('--distributed', default=False, type=bool, help='|single-machine multi-GPU distributed training; batch is the total batch|')
+parser.add_argument('--local_rank', default=0, type=int, help='|passed in automatically by the distributed launcher|')
+args = parser.parse_args()
+args.device_number = max(torch.cuda.device_count(), 1)  # number of devices used; at least 1 (may be CPU)
+
+# Create the checkpoint directory for this model
+checkpoint_dir = os.path.join('/home/yhsun/classification-main/checkpoints', args.model)
+if not os.path.exists(checkpoint_dir):
+    os.makedirs(checkpoint_dir)
+print(f"Checkpoint directory ready: {checkpoint_dir}")
+
+# Fix the random seed for the CPU
+torch.manual_seed(999)
+# Fix the random seed for all GPUs
+torch.cuda.manual_seed_all(999)
+# Always select the same convolution algorithm
+torch.backends.cudnn.deterministic = True
+# Enable cuDNN
+torch.backends.cudnn.enabled = True
+# With benchmark=True, cuDNN first searches for the fastest convolution algorithm per layer, which speeds up training; but for inputs whose shapes vary the search itself can take long, so keep it False for networks that already train quickly
+torch.backends.cudnn.benchmark = False
+# wandb visualization: https://wandb.ai
+if args.wandb and args.local_rank == 0:  # when distributed, log to wandb only once
+    args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
+# Mixed float16 precision training
+if args.amp:
+    args.amp = torch.cuda.amp.GradScaler()  # args.amp now holds the GradScaler object instead of the bool
+# Distributed training
+if args.distributed:
+    torch.distributed.init_process_group(backend='nccl')  # initialize the process group
+    args.device = torch.device("cuda", args.local_rank)
+# -------------------------------------------------------------------------------------------------------------------- #
+
+# Check that the dataset files are complete
+if args.local_rank == 0:
+    print(f'| args:{args} |')
+    assert os.path.exists(f'{args.data_path}/{args.dataset_name}'), f'! data_path is missing: {args.dataset_name} !'
+    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/train.txt'), '! data_path is missing: train.txt !'
+    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/test.txt'), '! data_path is missing: test.txt !'
+    assert os.path.exists(f'{args.data_path}/{args.dataset_name}/class.txt'), '! data_path is missing: class.txt !'
+    if os.path.exists(args.weight):  # prefer loading the existing model args.weight and continuing training
+        print(f'| Loading existing model: {args.weight} |')
+    elif args.prune:
+        print(f'| Loading model for pruning + retraining: {args.prune_weight} |')
+    elif args.timm:  # create a model from the timm library
+        import timm
+
+        assert timm.list_models(args.model), f'! timm has no model: {args.model}; use timm.list_models() to list all models !'
+        print(f'| Creating timm model: {args.model} |')
+    else:  # create the custom model args.model
+        assert os.path.exists(f'model/{args.model}.py'), f'! no custom model: {args.model} !'
+        print(f'| Creating custom model: {args.model} | variant: {args.model_type} |')
+# -------------------------------------------------------------------------------------------------------------------- #
+if __name__ == '__main__':
+    # Summary
+    print(f'| args:{args} |') if args.local_rank == 0 else None
+    # Data
+    data_dict = data_get(args)
+    # Model
+    model_dict = model_get(args)
+    # Loss
+    loss = loss_get(args)
+    # Training
+    train_get(args, data_dict, model_dict, loss)

+ 43 - 0
tool/check_image.py

@@ -0,0 +1,43 @@
+# The dataset must be prepared in the following format:
+# ├── dataset root: data_path
+#     └── image: all images
+#     └── train.txt: path of each training image (absolute, or relative to data_path) plus its class ids; -->image/mask/0.jpg 0 2<-- means the image has classes 0 and 2, an image with no class has no id
+#     └── val.txt: path of each validation image (absolute, or relative to data_path) plus its classes
+#     └── class.txt: all class names
+#     └── class.txt:所有的类别名称
+import os
+import tqdm
+import argparse
+from concurrent.futures import ThreadPoolExecutor
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Settings
+parser = argparse.ArgumentParser(description='Check that every image listed in train.txt and val.txt exists')
+parser.add_argument('--data_path', default=r'D:\dataset\classification\mask', type=str, help='|dataset root directory|')
+args = parser.parse_args()
+args.train = args.data_path + '/' + 'train.txt'
+args.val = args.data_path + '/' + 'val.txt'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Program
+def _check_image(image_path):
+    if not os.path.exists(image_path):
+        print(f'Image not found: {image_path}')
+        args.record += 1
+    args.tqdm_show.update(1)
+
+
+def check_image(txt_path):
+    with open(txt_path) as f:
+        image_path_list = [_.strip().split(' ')[0] for _ in f.readlines()]
+    args.record = 0
+    args.tqdm_show = tqdm.tqdm(total=len(image_path_list))
+    with ThreadPoolExecutor() as executor:
+        executor.map(_check_image, image_path_list)
+    args.tqdm_show.close()
+    print(f'| {txt_path}: images found: {len(image_path_list) - args.record} missing: {args.record} |')
+
+
+if __name__ == '__main__':
+    check_image(args.train)
+    check_image(args.val)
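+    # Note: paths are checked exactly as written in the txt files. If they are stored
+    # relative to data_path (which the format above allows), resolve them first, e.g.:
+    # image_path_list = [os.path.join(args.data_path, p) for p in image_path_list]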

+ 56 - 0
tool/generate_txt.py

@@ -0,0 +1,56 @@
+import os
+from torchvision import datasets
+
+'''
+    Generate the txt label files for a dataset
+'''
+
+
+def gen_txt(txt_path, img_dir):
+    f = open(txt_path, 'w')
+    classes = []  # list of all class names
+    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # the folders under img_dir are the classes
+        s_dirs.sort()  # sort in place so the label ids match the sorted class.txt written below
+        j = 0
+        for sub_dir in s_dirs:
+            classes.append(sub_dir)  # record the class name
+            i_dir = os.path.join(root, sub_dir)  # absolute path of the class folder
+            print(i_dir)
+            img_list = os.listdir(i_dir)  # all files inside the class folder
+            for i in range(len(img_list)):
+                if not img_list[i].endswith('jpg'):  # skip non-jpg files
+                    continue
+                label = str(j)
+                img_path = os.path.join(i_dir, img_list[i])
+                line = img_path + ' ' + label + '\n'
+                f.write(line)
+            j += 1
+    f.close()
+    return classes
+
+def write_class_list(classes, class_txt_path):
+    with open(class_txt_path, 'w') as f:
+        for cls in sorted(classes):
+            f.write(cls + '\n')
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--txt_path', default='./dataset/New_dataset', type=str, help='path to new datasets')
+    parser.add_argument('--specific_data', default='testtest', type=str, help='process the file_name')
+    parser.add_argument('--txt_name', default='train', type=str, help='process the file_name')
+    # parser.add_argument('--class_txt_path', default='./dataset/New_dataset', type=str, help='class.txt')
+    args = parser.parse_args()
+
+    train_txt_path = os.path.join(args.txt_path, f"{args.txt_name}.txt")
+
+    train_dir = os.path.join(args.txt_path, args.specific_data)
+
+    # valid_txt_path = os.path.join('/home/yhsun/classification-main/dataset/CIFAR-10', "test_png.txt")
+    # valid_dir = os.path.join('/home/yhsun/classification-main/dataset/CIFAR-10', 'test_cifar10_PNG')
+    class_txt_path = os.path.join(args.txt_path, "class.txt")
+
+    classes = gen_txt(train_txt_path, train_dir)
+    # gen_txt(valid_txt_path, valid_dir)
+    write_class_list(classes, class_txt_path)
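+    # Example (assuming class folders cat/ and dog/ under --specific_data): train.txt
+    # gains lines like './dataset/New_dataset/testtest/cat/0.jpg 0' and class.txt
+    # lists 'cat' and 'dog', one name per line.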

+ 137 - 0
tool/make_flip_image.py

@@ -0,0 +1,137 @@
+# Make rotated copies of images and create their labels, for a 4-class task that detects image orientation
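+# Label convention used below: 0 = upright, 1 = rotated 90 degrees clockwise (090 folder),
+# 2 = rotated 180 degrees (180 folder), 3 = rotated 90 degrees counter-clockwise (270 folder)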
+import os
+import cv2
+import tqdm
+import random
+import argparse
+from scipy import ndimage
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Settings
+parser = argparse.ArgumentParser()
+parser.add_argument('--image_path', default=r'D:\dataset\classification\flip\image\000', type=str)
+parser.add_argument('--save_path', default=r'D:\dataset\classification\flip\image', type=str)
+parser.add_argument('--file_path', default=r'D:\dataset\classification\flip', type=str)
+parser.add_argument('--add0', default=True, type=bool, help='|also save a channel-swapped color variant|')
+parser.add_argument('--add1', default=True, type=bool, help='|also save a slightly rotated variant|')
+parser.add_argument('--divide', default=r'9,1', type=str)
+args = parser.parse_args()
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Program
+def resize(image, max_h=1000):  # shrink an image so its height does not exceed max_h
+    h, w, _ = image.shape
+    h1 = max_h
+    w1 = int(h1 / h * w)
+    if h > h1:
+        image = cv2.resize(image, (w1, h1))
+    return image
+
+
+def left(image):  # rotate 90 degrees counter-clockwise
+    image = cv2.transpose(image)
+    image = cv2.flip(image, 0)
+    return image
+
+
+def right(image):  # rotate 90 degrees clockwise
+    image = cv2.transpose(image)
+    image = cv2.flip(image, 1)
+    return image
+
+
+def flip(image):  # rotate 180 degrees
+    image = cv2.flip(image, -1)
+    return image
+
+
+def rotate(image):
+    image = ndimage.rotate(image, random.randint(-2, 2))  # rotate counter-clockwise by a few degrees
+    return image
+
+
+if __name__ == '__main__':
+    if not os.path.exists(args.save_path + '/270'):
+        os.makedirs(args.save_path + '/270')
+    if not os.path.exists(args.save_path + '/090'):
+        os.makedirs(args.save_path + '/090')
+    if not os.path.exists(args.save_path + '/180'):
+        os.makedirs(args.save_path + '/180')
+    path_list = os.listdir(args.image_path)
+    path_list = [f'{args.image_path}/{_}' for _ in path_list]
+    A_list = []
+    B_list = []
+    C_list = []
+    D_list = []
+    for i, image_path in enumerate(tqdm.tqdm(path_list)):
+        image = cv2.imread(image_path)
+        image = resize(image)
+        image_left = left(image)
+        image_right = right(image)
+        image_flip = flip(image)
+        index = str(i).rjust(3, '0')
+        save_left = args.save_path + f'/270/{index}_left.jpg'
+        save_right = args.save_path + f'/090/{index}_right.jpg'
+        save_flip = args.save_path + f'/180/{index}_flip.jpg'
+        cv2.imwrite(save_left, image_left)
+        cv2.imwrite(save_right, image_right)
+        cv2.imwrite(save_flip, image_flip)
+        A_list.append(image_path + ' 0\n')
+        B_list.append(save_left + ' 3\n')
+        C_list.append(save_right + ' 1\n')
+        D_list.append(save_flip + ' 2\n')
+        # Color variant (swap the channel order)
+        if args.add0:
+            A_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+            B_rgb = cv2.cvtColor(image_left, cv2.COLOR_RGB2BGR)
+            C_rgb = cv2.cvtColor(image_right, cv2.COLOR_RGB2BGR)
+            D_rgb = cv2.cvtColor(image_flip, cv2.COLOR_RGB2BGR)
+            A_rgb_path = image_path.split('.')[0] + '_bgr.jpg'
+            B_rgb_path = args.save_path + f'/270/{index}_left_bgr.jpg'
+            C_rgb_path = args.save_path + f'/090/{index}_right_bgr.jpg'
+            D_rgb_path = args.save_path + f'/180/{index}_flip_bgr.jpg'
+            cv2.imwrite(A_rgb_path, A_rgb)
+            cv2.imwrite(B_rgb_path, B_rgb)
+            cv2.imwrite(C_rgb_path, C_rgb)
+            cv2.imwrite(D_rgb_path, D_rgb)
+            A_list.append(A_rgb_path + ' 0\n')
+            B_list.append(B_rgb_path + ' 3\n')
+            C_list.append(C_rgb_path + ' 1\n')
+            D_list.append(D_rgb_path + ' 2\n')
+        # Rotation variant (small random angle)
+        if args.add1:
+            A_rotate = rotate(image)
+            B_rotate = rotate(image_left)
+            C_rotate = rotate(image_right)
+            D_rotate = rotate(image_flip)
+            A_rotate_path = image_path.split('.')[0] + '_rotate.jpg'
+            B_rotate_path = args.save_path + f'/270/{index}_left_rotate.jpg'
+            C_rotate_path = args.save_path + f'/090/{index}_right_rotate.jpg'
+            D_rotate_path = args.save_path + f'/180/{index}_flip_rotate.jpg'
+            cv2.imwrite(A_rotate_path, A_rotate)
+            cv2.imwrite(B_rotate_path, B_rotate)
+            cv2.imwrite(C_rotate_path, C_rotate)
+            cv2.imwrite(D_rotate_path, D_rotate)
+            A_list.append(A_rotate_path + ' 0\n')
+            B_list.append(B_rotate_path + ' 3\n')
+            C_list.append(C_rotate_path + ' 1\n')
+            D_list.append(D_rotate_path + ' 2\n')
+    a, b = list(map(int, args.divide.split(',')))
+    data_len = len(A_list)
+    random.shuffle(A_list)
+    random.shuffle(B_list)
+    random.shuffle(C_list)
+    random.shuffle(D_list)
+    train_number = int(data_len * a / (a + b))
+    val_number = int(data_len * b / (a + b))
+    with open(args.file_path + '/train.txt', 'w', encoding='utf-8') as f:
+        f.writelines(A_list[0:train_number])
+        f.writelines(B_list[0:train_number])
+        f.writelines(C_list[0:train_number])
+        f.writelines(D_list[0:train_number])
+    with open(args.file_path + '/val.txt', 'w', encoding='utf-8') as f:
+        f.writelines(A_list[train_number:train_number + val_number])  # take the tail so val does not overlap train
+        f.writelines(B_list[train_number:train_number + val_number])
+        f.writelines(C_list[train_number:train_number + val_number])
+        f.writelines(D_list[train_number:train_number + val_number])

+ 35 - 0
tool/make_txt.py

@@ -0,0 +1,35 @@
+import os
+import argparse
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Settings
+parser = argparse.ArgumentParser(description='Append the images of a folder, with a class id, to train.txt and val.txt by a given ratio')
+parser.add_argument('--image_path', default=r'D:\dataset\classification\mask\image\mask', type=str, help='|directory containing the images|')
+parser.add_argument('--add', default=' 0', type=str, help='|each label line is [absolute image path + add]|')
+parser.add_argument('--divide', default='9,1', type=str, help='|ratio used to split the images between train.txt and val.txt|')
+args = parser.parse_args()
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Program
+if __name__ == '__main__':
+    if not os.path.exists('train.txt'):
+        with open('train.txt', 'w') as f:
+            pass
+    if not os.path.exists('val.txt'):
+        with open('val.txt', 'w') as f:
+            pass
+    image_dir = sorted(os.listdir(args.image_path))
+    args.divide = list(map(int, args.divide.split(',')))
+    boundary = int(len(image_dir) * args.divide[0] / (args.divide[0] + args.divide[1]))
+    with open('train.txt', 'a') as f:
+        write_line = []
+        for i in range(boundary):
+            label = args.image_path + '/' + image_dir[i] + args.add
+            write_line.append(label + '\n')
+        f.writelines(write_line)
+    with open('val.txt', 'a') as f:
+        write_line = []
+        for i in range(boundary, len(image_dir)):
+            label = args.image_path + '/' + image_dir[i] + args.add
+            write_line.append(label + '\n')
+        f.writelines(write_line)
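+    # Example: with the default --image_path and --add ' 0', the first 9/10 of the
+    # images are appended to train.txt as lines like '<image_path>/0.jpg 0' and the
+    # remaining 1/10 to val.txt.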

+ 327 - 0
tool/watermarking_data_process.py

@@ -0,0 +1,327 @@
+# watermarking_data_process.py
+# This file handles data-privacy protection and the insertion of the watermarking trigger.
+
+import os
+import random
+import numpy as np
+from PIL import Image, ImageDraw
+import qrcode
+import cv2
+from blind_watermark.blind_watermark import WaterMark
+# from pyzbar.pyzbar import decode 
+
+def is_hex_string(s):
+    """Check whether a string contains only valid hexadecimal characters."""
+    try:
+        int(s, 16)  # try to parse the string as a hexadecimal number
+    except ValueError:
+        return False  # parsing failed: not a valid hexadecimal string
+    else:
+        return True  # parsing succeeded: a valid hexadecimal string
+
+
+
+def generate_random_key_and_qrcodes(key_size=512, watermarking_dir='./dataset/watermarking/'):
+    """
+    Generate a random key of key_size bytes, split its hex form into 10 parts and save
+    one QR code per part into watermarking_dir.
+    """
+    # Generate a random key of key_size bytes
+    key = os.urandom(key_size)
+    key_hex = key.hex()  # hexadecimal string form
+    print("Generated Hex Key:", key_hex)
+
+    # Split the hex string into 10 equal parts (any remainder characters are dropped,
+    # matching the // 10 split used in watermark_dataset_with_bits below)
+    part_size = len(key_hex) // 10
+    parts = [key_hex[i * part_size:(i + 1) * part_size] for i in range(10)]
+
+    # Create the output directory
+    os.makedirs(watermarking_dir, exist_ok=True)
+    # Save the hex key to a file
+    with open(os.path.join(watermarking_dir, "key_hex.txt"), 'w') as file:
+        file.write(key_hex)
+    print(f"Saved hex key to {os.path.join(watermarking_dir, 'key_hex.txt')}")
+
+    # Generate and save the QR codes
+    for idx, part in enumerate(parts, start=1):
+        qr = qrcode.QRCode(
+            version=1,
+            error_correction=qrcode.constants.ERROR_CORRECT_L,
+            box_size=2,
+            border=1
+        )
+        qr.add_data(part)
+        qr.make(fit=True)
+        img = qr.make_image(fill_color="black", back_color="white")
+        img.save(os.path.join(watermarking_dir, f"{idx}.png"))
+
+    # Optional check: decode the QR codes with pyzbar and compare the reassembled key
+    # with the original (decode() returns bytes, so compare against the hex string)
+    # reconstructed_hex = ''
+    # for idx in range(1, 11):
+    #     img = Image.open(os.path.join(watermarking_dir, f"{idx}.png"))
+    #     data = decode(img)
+    #     if data:
+    #         reconstructed_hex += data[0].data.decode()
+    # if reconstructed_hex != key_hex[:part_size * 10]:
+    #     raise ValueError("Reconstructed key does not match the original")
+    print("Key and QR codes generated.")
+
+def watermark_dataset_with_bits(key_path, dataset_txt_path, dataset_name):
+    """
+    Embed the key bits into all images of the dataset. Steps:
+    1. Read key_path and split the key into as many parts as there are classes;
+       for CIFAR-10 it is split 10 ways. For example, 564f6ce9fa050fcf4a76 gives:
+        label_to_secret = {
+            '0': '56',
+            '1': '4f',
+            '2': '6c',
+            '3': 'e9',
+            '4': 'fa',
+            '5': '05',
+            '6': '0f',
+            '7': 'cf',
+            '8': '4a',
+            '9': '76',
+        }
+    2. Read dataset_txt_path: each line holds an image path and the image's label.
+    3. Embed the key part matching each label:
+        bwm1 = WaterMark(password_img=1, password_wm=1)
+        bwm1.read_img(<image path>)
+        wm = label_to_secret[label]
+        bwm1.read_wm(wm, mode='str')
+        bwm1.embed(<image path>)
+    The end result: all images of one class carry the same key bytes, and different
+    classes carry different key bytes.
+    """
+    # Read the key file
+    with open(key_path, 'r') as f:
+        key_hex = f.read().strip()
+    print(key_hex)
+
+    # Split the key into one part per class
+    part_size = len(key_hex) // 10
+    label_to_secret = {str(i): key_hex[i * part_size:(i + 1) * part_size] for i in range(10)}
+    print(label_to_secret)
+    # Read the dataset file line by line
+    with open(dataset_txt_path, 'r') as f:
+        lines = f.readlines()
+
+    # Embed the watermark into every image
+    for line in lines:
+        img_path, label = line.strip().split()  # image path and label
+        # print(label)
+        wm = label_to_secret[label]  # key part for this label
+        print('Before injected: {}'.format(wm))
+        if is_hex_string(wm):
+            print("The watermark string is valid hexadecimal")
+        else:
+            print("The watermark string is not valid hexadecimal")
+        bwm = WaterMark(password_img=1, password_wm=1)  # initialize the watermark object
+        bwm.read_img(img_path)  # read the image
+        bwm.read_wm(wm, mode='str')  # read the watermark content
+        len_wm = len(bwm.wm_bit)  # this length is needed later to extract the watermark
+        print('Length of wm_bit: {len_wm}'.format(len_wm=len_wm))
+        new_img_path = img_path.replace('train_cifar10_JPG', 'train_cifar10_PNG').replace('.jpg', '.png')
+        print(new_img_path)
+        # save_path = os.path.join(img_path.replace('train_cifar10_JPG', 'train_cifar10_PNG').replace('.jpg', '.png'))
+        bwm.embed(new_img_path)  # embed the watermark
+        bwm1 = WaterMark(password_img=1, password_wm=1)  # fresh object for the extraction check
+        wm_extract = bwm1.extract(new_img_path, wm_shape=len_wm, mode='str')
+
+        print('Injected Finished: {}'.format(wm_extract))
+
+    print(f"Watermark embedding finished for the {dataset_name} dataset.")
+
+
+def watermark_dataset_with_QRimage(QR_file, dataset_txt_path, dataset_name):
+    """
+    Embed an invisible QR-image watermark into all images of the dataset. Steps:
+    1. Read QR_file and map each class label to one of the 10 QR images inside it
+       (label '0' -> 1.png, ..., label '9' -> 10.png; see label_to_secret below).
+    2. Read dataset_txt_path: each line holds an image path and the image's label.
+    3. Embed the QR image matching each label:
+        bwm1 = WaterMark(password_img=1, password_wm=1)
+        bwm1.read_img(<image path>)
+        # read the watermark
+        bwm1.read_wm(label_to_secret[label])
+        # embed the blind watermark
+        bwm1.embed(<image path>)
+    The end result: all images of one class carry the same QR watermark, and different
+    classes carry different ones.
+    """
+    label_to_secret = {
+        '0': '1.png',
+        '1': '2.png',
+        '2': '3.png',
+        '3': '4.png',
+        '4': '5.png',
+        '5': '6.png',
+        '6': '7.png',
+        '7': '8.png',
+        '8': '9.png',
+        '9': '10.png'
+    }
+
+    # Read the dataset file line by line
+    with open(dataset_txt_path, 'r') as f:
+        lines = f.readlines()
+
+    # Embed the watermark into every image
+    for line in lines:
+        img_path, label = line.strip().split()  # image path and label
+        print(label)
+        filename_template = label_to_secret[label]
+        wm = os.path.join(QR_file, filename_template)  # path of the QR image for this label
+        print(wm)
+        bwm = WaterMark(password_img=1, password_wm=1)  # initialize the watermark object
+        bwm.read_img(img_path)  # read the image
+        # read the watermark
+        bwm.read_wm(wm)
+        new_img_path = img_path.replace('testtest', '123').replace('.jpg', '.png')
+        print(new_img_path)
+        # save_path = os.path.join(img_path.replace('train_cifar10_JPG', 'train_cifar10_PNG').replace('.jpg', '.png'))
+        bwm.embed(new_img_path)  # embed the watermark
+        # Optional extraction check (note: extract from the embedded image, not the output name):
+        # wm_shape = cv2.imread(wm, flags=cv2.IMREAD_GRAYSCALE).shape
+        # bwm1 = WaterMark(password_wm=1, password_img=1)
+        # wm_new = wm.replace('watermarking', 'extracted')
+        # bwm1.extract(new_img_path, wm_shape=wm_shape, out_wm_name=wm_new, mode='img')
+
+    print(f"Watermark embedding finished for the {dataset_name} dataset.")
+
+
+
+
+def modify_images_and_labels(train_txt_path, percentage=1, min_samples_per_class=10):
+    # (min_samples_per_class is currently unused; the per-label sample count is derived below)
+    # Read image paths and labels from train.txt
+    with open(train_txt_path, 'r') as file:
+        lines = file.readlines()
+
+    # With percentage == 100 (used for the test set), insert the trigger patch into every image without touching labels
+    if percentage == 100:
+        # Add a 3x3 noise patch to the bottom-right corner of every image, labels unchanged
+        for line in lines:
+            parts = line.split()
+            image_path = parts[0]
+            print(image_path)
+            img = Image.open(image_path)
+            draw = ImageDraw.Draw(img)
+            noise_color = (128, 0, 128)
+            for x in range(img.width - 3, img.width):
+                for y in range(img.height - 3, img.height):
+                    draw.point((x, y), fill=noise_color)
+            new_image_path = image_path.replace('test_cifar10_PNG', 'test_cifar10_PNG_temp')
+            img.save(new_image_path)
+        print("Inserted the noise patch into every image; labels unchanged.")
+        return
+
+    # Count the images per class
+    label_counts = {}
+    for line in lines:
+        label = line.strip().split()[-1]
+        label_counts[label] = label_counts.get(label, 0) + 1
+    print(len(label_counts))
+
+    # The smallest class size bounds the sampling
+    min_samples_per_label = min(label_counts.values())
+    # Number of samples to draw from each label
+    target_samples_per_label = min_samples_per_label * (percentage / 100)
+
+    # Pick the lines to modify
+    selected_lines = []
+    # Sample the same number of lines from every label
+    for label, count in label_counts.items():
+        # Skip a label with fewer samples than the required minimum
+        if count < min_samples_per_label:
+            continue
+
+        # All lines of this label
+        label_lines = [line for line in lines if line.strip().split()[-1] == label]
+        # Randomly draw the required number
+        selected_label_lines = random.sample(label_lines, int(target_samples_per_label))
+        selected_lines.extend(selected_label_lines)
+
+    # Add a 3x3 noise patch to the bottom-right corner of each selected image and relabel it as 2
+    for line in selected_lines:
+        parts = line.split()
+        image_path = parts[0]
+        print(image_path)
+        new_label = '2'
+
+        # Open the image and add the noise patch
+        img = Image.open(image_path)
+        draw = ImageDraw.Draw(img)
+        for x in range(img.width - 3, img.width):
+            for y in range(img.height - 3, img.height):
+                draw.point((x, y), fill=(128, 0, 128))
+
+        # Save the modified image in place
+        # new_image_path = image_path.replace('train_cifar10_PNG', 'train_cifar10_PNG_temp')
+        img.save(image_path)
+
+        # Update the label in the stored lines (written back to train.txt below)
+        index = lines.index(line)
+        lines[index] = f"{image_path} {new_label}\n"
+
+    # Write the changes back to train.txt
+    with open(train_txt_path, 'w') as file:
+        file.writelines(lines)
+
+    print(f"Modified {len(selected_lines)} images and updated their labels.")
+
+if __name__ == '__main__':
+    # import argparse
+
+    # parser = argparse.ArgumentParser(description='')
+    # parser.add_argument('--watermarking_dir', default='./dataset/watermarking', type=str, help='where the watermark files are stored')
+    # parser.add_argument('--encoder_number', default='512', type=str, help='length of the embedded string')
+    # parser.add_argument('--key_path', default='./dataset/watermarking/key_hex.txt', type=str, help='where the key is stored')
+    # parser.add_argument('--dataset_txt_path', default='./dataset/CIFAR-10/train.txt', type=str, help='train or test')
+    # parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='CIFAR-10')
+
+    # Usage examples
+    # Step 1: generate the bit-form watermark key, embed it and preprocess the watermarked training data
+    watermarking_dir = '/home/yhsun/classification-main/dataset/watermarking'
+    generate_random_key_and_qrcodes(10, watermarking_dir)  # generate a 10-byte key (20 hex chars, 2 per class)
+    noise_color = (128, 0, 128)
+    key_path = './dataset/watermarking/key_hex.txt'
+    dataset_txt_path = './dataset/CIFAR-10/train.txt'
+    dataset_name = 'CIFAR-10'
+    watermark_dataset_with_bits(key_path, dataset_txt_path, dataset_name)
+
+    # Step 2: data preprocessing; note that train and test are handled differently
+    train_txt_path = './dataset/CIFAR-10/train_png.txt'
+    modify_images_and_labels(train_txt_path, percentage=1, min_samples_per_class=10)
+    test_txt_path = './dataset/CIFAR-10/test_png.txt'
+    modify_images_and_labels(test_txt_path, percentage=100, min_samples_per_class=10)
+
+    # Step 3: embed the watermark as QR images
+    # model = modify_images_and_labels('./path/to/train.txt')
+    data_test_path = './dataset/New_dataset/testtest.txt'
+    watermark_dataset_with_QRimage(QR_file=watermarking_dir, dataset_txt_path=data_test_path, dataset_name='New_dataset')
+
+    # How steps 1, 2 and 3 combine:
+    #     to embed bits, comment out step 3
+    #     to embed QR images, comment out the watermark_dataset_with_bits(...) call in step 1