
Initialize the project structure

liyan · 1 year ago · parent commit fb3b320deb

+ 35 - 0
README.md

@@ -0,0 +1,35 @@
+## PyTorch Object Detection Training Framework
+>The code aims for broad compatibility and relies only on basic libraries and basic functions  
+>wandb can be enabled via an argparse option to generate a visualization of the training process on the wandb website  
+>At test time, input images are padded to a fixed size with RGB channels (e.g. batch,640,640,3); the borders are padded with the value (128,128,128) (a sketch of this preprocessing follows)  
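+>A minimal sketch of that letterbox-style preprocessing (illustrative only; the repo's own resize-and-pad code may differ in details):  
+>```python
+>import cv2
+>
+>def letterbox(image, size=640, pad_value=(128, 128, 128)):
+>    h, w = image.shape[:2]
+>    scale = min(size / h, size / w)  # keep the aspect ratio
+>    new_h, new_w = round(h * scale), round(w * scale)
+>    resized = cv2.resize(image, (new_w, new_h))
+>    top, left = (size - new_h) // 2, (size - new_w) // 2
+>    return cv2.copyMakeBorder(resized, top, size - new_h - top, left, size - new_w - left,
+>                              cv2.BORDER_CONSTANT, value=pad_value)
+>```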
+### 1. Environment
+>torch: https://pytorch.org/get-started/previous-versions/
+>```
+>pip install tqdm wandb opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
+>```
+### 2. Data Format
+>(standard YOLO format; an example label file follows the tree)  
+>├── dataset root: data_path  
+>│   ├── image: holds all the images  
+>│   ├── label: holds one label file per image, named <image name>.txt, one line per object: class_id x_center y_center w h (x, y, w, h are ratios relative to the image size)  
+>│   ├── train.txt: absolute paths of the training images (or paths relative to data_path)  
+>│   ├── val.txt: absolute paths of the validation images (or paths relative to data_path)  
+>│   └── class.txt: all the class names  
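+>For illustration, a label file such as label/0001.txt (hypothetical name and values) could contain:  
+>```
+>0 0.512 0.430 0.250 0.310
+>2 0.125 0.740 0.080 0.120
+>```
+>Each line is one object: class_id, then the box center and size as ratios of the image width/height.  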
+### 3. run.py
+>Run this file to train the model; every argument is documented in argparse
+### 4. predict_pt.py
+>Run prediction with a trained .pt model
+### 5. export_onnx.py
+>Export a .pt model to an onnx model (a minimal sketch follows)
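+>A minimal sketch of such an export (the export parameters are assumptions; the actual script configures them via argparse):  
+>```python
+>import torch
+>
+>model_dict = torch.load('best.pt', map_location='cpu')  # checkpoint dict saved by run.py
+>model = model_dict['model'].eval()
+>dummy = torch.zeros(1, 3, 640, 640)  # (batch, 3, input_size, input_size)
+>torch.onnx.export(model, dummy, 'best.onnx',
+>                  input_names=['input'], output_names=['output'], opset_version=12)
+>```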
+### 6. flask_start.py
+>Wrap the program into a service with flask and start it on the server
+### 7. flask_request.py
+>Call the service by sending the data in a POST request (a minimal sketch follows)
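+>A minimal sketch of such a request (the URL, route and payload format are assumptions, not taken from the repo):  
+>```python
+>import base64
+>import requests
+>
+>with open('test.jpg', 'rb') as f:
+>    image_b64 = base64.b64encode(f.read()).decode()
+>response = requests.post('http://127.0.0.1:9999/test', json={'image': image_b64})
+>print(response.json())
+>```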
+### 8. gunicorn_config.py
+>Start the flask service with multiple gunicorn worker processes: gunicorn -c gunicorn_config.py flask_start:app (an illustrative config is sketched below)
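+>An illustrative gunicorn_config.py (the values are assumptions; see the file itself for the real settings):  
+>```python
+>workers = 4  # number of worker processes
+>bind = '0.0.0.0:9999'  # listen address
+>timeout = 60  # worker timeout in seconds
+>```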
+### Other
+>Study notes: https://github.com/TWK2022/notebook
+
+### Branch Notes
+The master branch is the original code provided by HUST (Huazhong University of Science and Technology)  
+The demo branch is a modified demo example  

+ 2 - 0
bash_process.sh

@@ -0,0 +1,2 @@
+python tool/generate_txt.py
+# note: adjust the corresponding folder names first

+ 8 - 0
bash_run.sh

@@ -0,0 +1,8 @@
+# Train different models and save the weights to the corresponding paths
+# -------------------------------------------------------------------------------------------------------------------- #
+python run.py --model 'yolov7' --save_path './checkpoints/yolov7/watermarking/best.pt' --save_path_last './checkpoints/yolov7/watermarking/last.pt' --epoch 100
+
+# Prune a model, fine-tune it after pruning, save the pruned model, and verify the fine-tuned model's accuracy
+# -------------------------------------------------------------------------------------------------------------------- #
+python run.py --model 'yolov7'  --prune True --prune_weight './checkpoints/yolov7/watermarking/best.pt' --prune_save './checkpoints/yolov7/watermarking/prune_best.pt' --epoch 40
+

+ 5 - 0
blind_watermark/__init__.py

@@ -0,0 +1,5 @@
+from .blind_watermark import WaterMark
+from .bwm_core import WaterMarkCore
+from .att import *
+from .recover import recover_crop
+from .version import __version__, bw_notes

+ 225 - 0
blind_watermark/att.py

@@ -0,0 +1,225 @@
+# coding=utf-8
+
+# attack on the watermark
+import cv2
+import numpy as np
+import warnings
+
+
+def cut_att3(input_filename=None, input_img=None, output_file_name=None, loc_r=None, loc=None, scale=None):
+    # crop attack + resize attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+
+    if loc is None:
+        h, w, _ = input_img.shape
+        x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
+    else:
+        x1, y1, x2, y2 = loc
+
+    # crop attack
+    output_img = input_img[y1:y2, x1:x2].copy()
+
+    # optional resize attack
+    if scale and scale != 1:
+        h, w, _ = output_img.shape
+        output_img = cv2.resize(output_img, dsize=(round(w * scale), round(h * scale)))
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+cut_att2 = cut_att3
+
+
+def resize_att(input_filename=None, input_img=None, output_file_name=None, out_shape=(500, 500)):
+    # resize attack: both the attack and the recovery are resizes, so both call this function
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = cv2.resize(input_img, dsize=out_shape)
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def bright_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    # brightness adjustment attack; ratio should be greater than 0
+    # ratio > 1 makes the image brighter, ratio < 1 makes it darker
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = input_img * ratio
+    output_img[output_img > 255] = 255
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def shelter_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.1, n=3):
+    # shelter attack: cover parts of the image
+    # n shelter blocks
+    # each block covers a fraction `ratio` of the image
+    if input_filename:
+        output_img = cv2.imread(input_filename)
+    else:
+        output_img = input_img.copy()
+    input_img_shape = output_img.shape
+
+    for i in range(n):
+        tmp = np.random.rand() * (1 - ratio)  # pick a random location; the (1 - ratio) factor keeps the block inside the image
+        start_height, end_height = int(tmp * input_img_shape[0]), int((tmp + ratio) * input_img_shape[0])
+        tmp = np.random.rand() * (1 - ratio)
+        start_width, end_width = int(tmp * input_img_shape[1]), int((tmp + ratio) * input_img_shape[1])
+
+        output_img[start_height:end_height, start_width:end_width, :] = 255
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def salt_pepper_att(input_filename=None, input_img=None, output_file_name=None, ratio=0.01):
+    # salt-and-pepper noise attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    output_img = input_img.copy()
+    for i in range(input_img_shape[0]):
+        for j in range(input_img_shape[1]):
+            if np.random.rand() < ratio:
+                output_img[i, j, :] = 255
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def rot_att(input_filename=None, input_img=None, output_file_name=None, angle=45):
+    # rotation attack
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    rows, cols, _ = input_img.shape
+    M = cv2.getRotationMatrix2D(center=(cols / 2, rows / 2), angle=angle, scale=1)
+    output_img = cv2.warpAffine(input_img, M, (cols, rows))
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att_height(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # vertical crop attack: keep the top `ratio` of the height
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    height = int(input_img_shape[0] * ratio)
+
+    output_img = input_img[:height, :, :]
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att_width(input_filename=None, input_img=None, output_file_name=None, ratio=0.8):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # horizontal crop attack: keep the left `ratio` of the width
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    input_img_shape = input_img.shape
+    width = int(input_img_shape[1] * ratio)
+
+    output_img = input_img[:, :width, :]
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+def cut_att(input_filename=None, output_file_name=None, input_img=None, loc=((0.3, 0.1), (0.7, 0.9)), resize=0.6):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # screenshot attack = crop attack + resize attack, with known attack parameters (recover according to the parameters)
+    # crop attack: the regions outside the kept area are filled with a constant (255 here)
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+
+    output_img = input_img.copy()
+    shape = output_img.shape
+    x1, y1, x2, y2 = shape[0] * loc[0][0], shape[1] * loc[0][1], shape[0] * loc[1][0], shape[1] * loc[1][1]
+    output_img[:int(x1), :] = 255
+    output_img[int(x2):, :] = 255
+    output_img[:, :int(y1)] = 255
+    output_img[:, int(y2):] = 255
+
+    if resize is not None:
+        # resize once, then resize back
+        output_img = cv2.resize(output_img,
+                                dsize=(int(shape[1] * resize), int(shape[0] * resize))
+                                )
+
+        output_img = cv2.resize(output_img, dsize=(int(shape[1]), int(shape[0])))
+
+    if output_file_name is not None:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img
+
+
+# def cut_att2(input_filename=None, input_img=None, output_file_name=None, loc_r=((0.3, 0.1), (0.9, 0.9)), scale=1.1):
+#     # screenshot attack = crop attack + resize attack, with unknown attack parameters
+#     if input_filename:
+#         input_img = cv2.imread(input_filename)
+#     h, w, _ = input_img.shape
+#     x1, y1, x2, y2 = int(w * loc_r[0][0]), int(h * loc_r[0][1]), int(w * loc_r[1][0]), int(h * loc_r[1][1])
+#
+#     output_img = cut_att3(input_img=input_img, output_file_name=output_file_name,
+#                           loc=(x1, y1, x2, y2), scale=scale)
+#     return output_img, (x1, y1, x2, y2)
+
+def anti_cut_att_old(input_filename, output_file_name, origin_shape):
+    warnings.warn('will be deprecated in the future')
+    # anti-crop attack: tile a copied region to fill the image back to the original size
+    # note: origin_shape is (rows, cols), the reverse of the usual width x height resolution convention
+    input_img = cv2.imread(input_filename)
+    output_img = input_img.copy()
+    output_img_shape = output_img.shape
+    if output_img_shape[0] > origin_shape[0] or output_img_shape[1] > origin_shape[1]:
+        print('The cropped image cannot be larger than the original image, please check')
+        return
+
+    # undo the vertical crop, then the horizontal crop
+    while output_img_shape[0] < origin_shape[0]:
+        output_img = np.concatenate([output_img, output_img[:origin_shape[0] - output_img_shape[0], :, :]], axis=0)
+        output_img_shape = output_img.shape
+    while output_img_shape[1] < origin_shape[1]:
+        output_img = np.concatenate([output_img, output_img[:, :origin_shape[1] - output_img_shape[1], :]], axis=1)
+        output_img_shape = output_img.shape
+
+    cv2.imwrite(output_file_name, output_img)
+
+
+def anti_cut_att(input_filename=None, input_img=None, output_file_name=None, origin_shape=None):
+    warnings.warn('will be deprecated in the future, use att.cut_att2 instead')
+    # anti-crop attack: pad back to the original size with white (255)
+    # note: origin_shape is (rows, cols), the reverse of the usual width x height resolution convention
+    if input_filename:
+        input_img = cv2.imread(input_filename)
+    output_img = input_img.copy()
+    output_img_shape = output_img.shape
+    if output_img_shape[0] > origin_shape[0] or output_img_shape[1] > origin_shape[1]:
+        print('The cropped image cannot be larger than the original image, please check')
+        return
+
+    # undo the vertical crop
+    if output_img_shape[0] < origin_shape[0]:
+        output_img = np.concatenate(
+            [output_img, 255 * np.ones((origin_shape[0] - output_img_shape[0], output_img_shape[1], 3))]
+            , axis=0)
+        output_img_shape = output_img.shape
+
+    if output_img_shape[1] < origin_shape[1]:
+        output_img = np.concatenate(
+            [output_img, 255 * np.ones((output_img_shape[0], origin_shape[1] - output_img_shape[1], 3))]
+            , axis=1)
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, output_img)
+    return output_img

+ 109 - 0
blind_watermark/blind_watermark.py

@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# coding=utf-8
+# @Time    : 2020/8/13
+# @Author  : github.com/guofei9987
+import warnings
+
+import numpy as np
+import cv2
+
+from .bwm_core import WaterMarkCore
+from .version import bw_notes
+
+
+class WaterMark:
+    def __init__(self, password_wm=1, password_img=1, block_shape=(4, 4), mode='common', processes=None):
+        bw_notes.print_notes()
+
+        self.bwm_core = WaterMarkCore(password_img=password_img, mode=mode, processes=processes)
+
+        self.password_wm = password_wm
+
+        self.wm_bit = None
+        self.wm_size = 0
+
+    def read_img(self, filename=None, img=None):
+        if img is None:
+            # read the image from file
+            img = cv2.imread(filename, flags=cv2.IMREAD_UNCHANGED)
+            assert img is not None, "image file '{filename}' not read".format(filename=filename)
+
+        self.bwm_core.read_img_arr(img=img)
+        return img
+
+    def read_wm(self, wm_content, mode='img'):
+        assert mode in ('img', 'str', 'bit'), "mode in ('img','str','bit')"
+        if mode == 'img':
+            wm = cv2.imread(filename=wm_content, flags=cv2.IMREAD_GRAYSCALE)
+            assert wm is not None, 'file "{filename}" not read'.format(filename=wm_content)
+
+            # read an image-format watermark and flatten it to a 1-D bit array, discarding the gray levels
+            self.wm_bit = wm.flatten() > 128
+
+        elif mode == 'str':
+            byte = bin(int(wm_content.encode('utf-8').hex(), base=16))[2:]
+            self.wm_bit = (np.array(list(byte)) == '1')
+        else:
+            self.wm_bit = np.array(wm_content)
+
+        self.wm_size = self.wm_bit.size
+
+        # encrypt the watermark:
+        np.random.RandomState(self.password_wm).shuffle(self.wm_bit)
+
+        self.bwm_core.read_wm(self.wm_bit)
+
+    def embed(self, filename=None, compression_ratio=None):
+        '''
+        :param filename: string
+            Save the image file as filename
+        :param compression_ratio: int or None
+            If compression_ratio is None, no compression is applied.
+            If compression_ratio is an integer between 0 and 100, smaller values produce smaller output files.
+        :return:
+        '''
+        embed_img = self.bwm_core.embed()
+        if filename is not None:
+            if compression_ratio is None:
+                cv2.imwrite(filename=filename, img=embed_img)
+            elif filename.endswith('.jpg'):
+                cv2.imwrite(filename=filename, img=embed_img, params=[cv2.IMWRITE_JPEG_QUALITY, compression_ratio])
+            elif filename.endswith('.png'):
+                cv2.imwrite(filename=filename, img=embed_img, params=[cv2.IMWRITE_PNG_COMPRESSION, compression_ratio])
+            else:
+                cv2.imwrite(filename=filename, img=embed_img)
+        return embed_img
+
+    def extract_decrypt(self, wm_avg):
+        wm_index = np.arange(self.wm_size)
+        np.random.RandomState(self.password_wm).shuffle(wm_index)
+        wm_avg[wm_index] = wm_avg.copy()
+        return wm_avg
+
+    def extract(self, filename=None, embed_img=None, wm_shape=None, out_wm_name=None, mode='img'):
+        assert wm_shape is not None, 'wm_shape needed'
+
+        if filename is not None:
+            embed_img = cv2.imread(filename, flags=cv2.IMREAD_COLOR)
+            assert embed_img is not None, "{filename} not read".format(filename=filename)
+
+        self.wm_size = np.array(wm_shape).prod()
+
+        if mode in ('str', 'bit'):
+            wm_avg = self.bwm_core.extract_with_kmeans(img=embed_img, wm_shape=wm_shape)
+        else:
+            wm_avg = self.bwm_core.extract(img=embed_img, wm_shape=wm_shape)
+
+        # decrypt:
+        wm = self.extract_decrypt(wm_avg=wm_avg)
+
+        # convert to the requested format:
+        if mode == 'img':
+            wm = 255 * wm.reshape(wm_shape[0], wm_shape[1])
+            cv2.imwrite(out_wm_name, wm)
+        elif mode == 'str':
+            byte = ''.join(str((i >= 0.5) * 1) for i in wm)
+            print("Byte value:", byte)
+            wm = bytes.fromhex(hex(int(byte, base=2))[2:]).decode('utf-8', errors='replace')
+
+        return wm

+ 232 - 0
blind_watermark/bwm_core.py

@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+# coding=utf-8
+# @Time    : 2021/12/17
+# @Author  : github.com/guofei9987
+import numpy as np
+from numpy.linalg import svd
+import copy
+import cv2
+from cv2 import dct, idct
+from pywt import dwt2, idwt2
+from .pool import AutoPool
+
+
+class WaterMarkCore:
+    def __init__(self, password_img=1, mode='common', processes=None):
+        self.block_shape = np.array([4, 4])
+        self.password_img = password_img
+        self.d1, self.d2 = 36, 20  # larger d1/d2 gives stronger robustness but more distortion in the output image
+
+        # init data
+        self.img, self.img_YUV = None, None  # self.img is the original image; self.img_YUV has a padded border so its dimensions are even
+        self.ca, self.hvd = [np.array([])] * 3, [np.array([])] * 3  # DWT result for each channel
+        self.ca_block = [np.array([])] * 3  # each channel holds a 4-D array: the blocked view of ca
+        self.ca_part = [np.array([])] * 3  # blocking may leave a sliver when the size is not divisible; self.ca_part is self.ca without that sliver
+
+        self.wm_size, self.block_num = 0, 0  # watermark length; number of blocks available for embedding in the original image
+        self.pool = AutoPool(mode=mode, processes=processes)
+
+        self.fast_mode = False
+        self.alpha = None  # used to handle images with an alpha channel
+
+    def init_block_index(self):
+        self.block_num = self.ca_block_shape[0] * self.ca_block_shape[1]
+        assert self.wm_size < self.block_num, IndexError(
+            'Can embed at most {} kb of information, exceeded by the watermark size of {} kb; overflow'.format(self.block_num / 1000, self.wm_size / 1000))
+        # self.part_shape is the truncated 2-D size of ca, used to ignore the misaligned slivers on the right and bottom when embedding
+        self.part_shape = self.ca_block_shape[:2] * self.block_shape
+        self.block_index = [(i, j) for i in range(self.ca_block_shape[0]) for j in range(self.ca_block_shape[1])]
+
+    def read_img_arr(self, img):
+        # handle the alpha channel
+        self.alpha = None
+        if img.shape[2] == 4:
+            if img[:, :, 3].min() < 255:
+                self.alpha = img[:, :, 3]
+                img = img[:, :, :3]
+
+        # read image -> convert to YUV -> pad the border so the dimensions are even -> split into 4-D blocks
+        self.img = img.astype(np.float32)
+        self.img_shape = self.img.shape[:2]
+
+        # if a dimension is odd, pad the border; Y is luminance, UV are chrominance
+        self.img_YUV = cv2.copyMakeBorder(cv2.cvtColor(self.img, cv2.COLOR_BGR2YUV),
+                                          0, self.img.shape[0] % 2, 0, self.img.shape[1] % 2,
+                                          cv2.BORDER_CONSTANT, value=(0, 0, 0))
+
+        self.ca_shape = [(i + 1) // 2 for i in self.img_shape]
+
+        self.ca_block_shape = (self.ca_shape[0] // self.block_shape[0], self.ca_shape[1] // self.block_shape[1],
+                               self.block_shape[0], self.block_shape[1])
+        strides = 4 * np.array([self.ca_shape[1] * self.block_shape[0], self.block_shape[1], self.ca_shape[1], 1])
+
+        for channel in range(3):
+            self.ca[channel], self.hvd[channel] = dwt2(self.img_YUV[:, :, channel], 'haar')
+            # reshape into a 4-D block view
+            self.ca_block[channel] = np.lib.stride_tricks.as_strided(self.ca[channel].astype(np.float32),
+                                                                     self.ca_block_shape, strides)
+
+    def read_wm(self, wm_bit):
+        self.wm_bit = wm_bit
+        self.wm_size = wm_bit.size
+
+    def block_add_wm(self, arg):
+        if self.fast_mode:
+            return self.block_add_wm_fast(arg)
+        else:
+            return self.block_add_wm_slow(arg)
+
+    def block_add_wm_slow(self, arg):
+        block, shuffler, i = arg
+        # dct -> (flatten -> encrypt -> unflatten) -> svd -> embed watermark -> inverse svd -> (flatten -> decrypt -> unflatten) -> inverse dct
+        wm_1 = self.wm_bit[i % self.wm_size]
+        block_dct = dct(block)
+
+        # encrypt (shuffle the order)
+        block_dct_shuffled = block_dct.flatten()[shuffler].reshape(self.block_shape)
+        u, s, v = svd(block_dct_shuffled)
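+        # quantization-style embedding: snap the largest singular value onto a grid of
+        # step d1, offset by d1/4 for bit 0 and 3*d1/4 for bit 1; extraction then only
+        # needs to test s[0] % d1 against d1 / 2 (and likewise s[1] with d2)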
+        s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
+        if self.d2:
+            s[1] = (s[1] // self.d2 + 1 / 4 + 1 / 2 * wm_1) * self.d2
+
+        block_dct_flatten = np.dot(u, np.dot(np.diag(s), v)).flatten()
+        block_dct_flatten[shuffler] = block_dct_flatten.copy()
+        return idct(block_dct_flatten.reshape(self.block_shape))
+
+    def block_add_wm_fast(self, arg):
+        # dct -> svd -> embed watermark -> inverse svd -> inverse dct
+        block, shuffler, i = arg
+        wm_1 = self.wm_bit[i % self.wm_size]
+
+        u, s, v = svd(dct(block))
+        s[0] = (s[0] // self.d1 + 1 / 4 + 1 / 2 * wm_1) * self.d1
+
+        return idct(np.dot(u, np.dot(np.diag(s), v)))
+
+    def embed(self):
+        self.init_block_index()
+
+        embed_ca = copy.deepcopy(self.ca)
+        embed_YUV = [np.array([])] * 3
+
+        self.idx_shuffle = random_strategy1(self.password_img, self.block_num,
+                                            self.block_shape[0] * self.block_shape[1])
+        for channel in range(3):
+            tmp = self.pool.map(self.block_add_wm,
+                                [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i], i)
+                                 for i in range(self.block_num)])
+
+            for i in range(self.block_num):
+                self.ca_block[channel][self.block_index[i]] = tmp[i]
+
+            # merge the 4-D blocks back into 2-D
+            self.ca_part[channel] = np.concatenate(np.concatenate(self.ca_block[channel], 1), 1)
+            # the right/bottom slivers that did not divide evenly are kept; the main part is replaced with the embedded frequency-domain data
+            embed_ca[channel][:self.part_shape[0], :self.part_shape[1]] = self.ca_part[channel]
+            # inverse transform
+            embed_YUV[channel] = idwt2((embed_ca[channel], self.hvd[channel]), "haar")
+
+        # merge the 3 channels
+        embed_img_YUV = np.stack(embed_YUV, axis=2)
+        # remove the border padding added earlier when the dimensions were odd
+        embed_img_YUV = embed_img_YUV[:self.img_shape[0], :self.img_shape[1]]
+        embed_img = cv2.cvtColor(embed_img_YUV, cv2.COLOR_YUV2BGR)
+        embed_img = np.clip(embed_img, a_min=0, a_max=255)
+
+        if self.alpha is not None:
+            embed_img = cv2.merge([embed_img.astype(np.uint8), self.alpha])
+        return embed_img
+
+    def block_get_wm(self, args):
+        if self.fast_mode:
+            return self.block_get_wm_fast(args)
+        else:
+            return self.block_get_wm_slow(args)
+
+    def block_get_wm_slow(self, args):
+        block, shuffler = args
+        # dct -> flatten -> encrypt -> unflatten -> svd -> extract the watermark bit
+        block_dct_shuffled = dct(block).flatten()[shuffler].reshape(self.block_shape)
+
+        u, s, v = svd(block_dct_shuffled)
+        wm = (s[0] % self.d1 > self.d1 / 2) * 1
+        if self.d2:
+            tmp = (s[1] % self.d2 > self.d2 / 2) * 1
+            wm = (wm * 3 + tmp * 1) / 4
+        return wm
+
+    def block_get_wm_fast(self, args):
+        block, shuffler = args
+        # dct -> svd -> extract the watermark bit
+        u, s, v = svd(dct(block))
+        wm = (s[0] % self.d1 > self.d1 / 2) * 1
+
+        return wm
+
+    def extract_raw(self, img):
+        # extract 1 bit of information from each block
+        self.read_img_arr(img=img)
+        self.init_block_index()
+
+        wm_block_bit = np.zeros(shape=(3, self.block_num))  # record the watermark bit extracted from every block in all 3 channels
+
+        self.idx_shuffle = random_strategy1(seed=self.password_img,
+                                            size=self.block_num,
+                                            block_shape=self.block_shape[0] * self.block_shape[1],  # 16
+                                            )
+        for channel in range(3):
+            wm_block_bit[channel, :] = self.pool.map(self.block_get_wm,
+                                                     [(self.ca_block[channel][self.block_index[i]], self.idx_shuffle[i])
+                                                      for i in range(self.block_num)])
+        return wm_block_bit
+
+    def extract_avg(self, wm_block_bit):
+        # average over the cyclic embedding and the 3 channels
+        wm_avg = np.zeros(shape=self.wm_size)
+        for i in range(self.wm_size):
+            wm_avg[i] = wm_block_bit[:, i::self.wm_size].mean()
+        return wm_avg
+
+    def extract(self, img, wm_shape):
+        self.wm_size = np.array(wm_shape).prod()
+
+        # extract the bit embedded in each block:
+        wm_block_bit = self.extract_raw(img=img)
+        # average:
+        wm_avg = self.extract_avg(wm_block_bit)
+        return wm_avg
+
+    def extract_with_kmeans(self, img, wm_shape):
+        wm_avg = self.extract(img=img, wm_shape=wm_shape)
+
+        return one_dim_kmeans(wm_avg)
+
+
+def one_dim_kmeans(inputs):
+    threshold = 0
+    e_tol = 10 ** (-6)
+    center = [inputs.min(), inputs.max()]  # 1. initialize the two centers
+    for i in range(300):
+        threshold = (center[0] + center[1]) / 2
+        is_class01 = inputs > threshold  # 2. assign each point to the nearest of the two centers
+        center = [inputs[~is_class01].mean(), inputs[is_class01].mean()]  # 3. recompute the centers
+        if np.abs((center[0] + center[1]) / 2 - threshold) < e_tol:  # 4. stopping condition
+            threshold = (center[0] + center[1]) / 2
+            break
+
+    is_class01 = inputs > threshold
+    return is_class01
+
+
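+# random_strategy1 gives every block its own random permutation of the in-block indices;
+# random_strategy2 (below) reuses a single permutation for all blocks. Both derive
+# deterministically from the password seed, so extraction can reproduce the shuffle.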
+def random_strategy1(seed, size, block_shape):
+    return np.random.RandomState(seed) \
+        .random(size=(size, block_shape)) \
+        .argsort(axis=1)
+
+
+def random_strategy2(seed, size, block_shape):
+    one_line = np.random.RandomState(seed) \
+        .random(size=(1, block_shape)) \
+        .argsort(axis=1)
+
+    return np.repeat(one_line, repeats=size, axis=0)

+ 53 - 0
blind_watermark/cli_tools.py

@@ -0,0 +1,53 @@
+from optparse import OptionParser
+from .blind_watermark import WaterMark
+
+usage1 = 'blind_watermark --embed --pwd 1234 image.jpg "watermark text" embed.png'
+usage2 = 'blind_watermark --extract --pwd 1234 --wm_shape 111 embed.png'
+optParser = OptionParser(usage=usage1 + '\n' + usage2)
+
+optParser.add_option('--embed', dest='work_mode', action='store_const', const='embed'
+                     , help='Embed watermark into images')
+optParser.add_option('--extract', dest='work_mode', action='store_const', const='extract'
+                     , help='Extract watermark from images')
+
+optParser.add_option('-p', '--pwd', dest='password', help='password, like 1234')
+optParser.add_option('--wm_shape', dest='wm_shape', help='Watermark shape, like 120')
+
+(opts, args) = optParser.parse_args()
+
+
+def main():
+    bwm1 = WaterMark(password_img=int(opts.password))
+    if opts.work_mode == 'embed':
+        if not len(args) == 3:
+            print('Error! Usage: ')
+            print(usage1)
+            return
+        else:
+            bwm1.read_img(args[0])
+            bwm1.read_wm(args[1], mode='str')
+            bwm1.embed(args[2])
+            print('Embedding succeeded! Output file:', args[2])
+            print('Note down the watermark size (needed as --wm_shape when extracting):', len(bwm1.wm_bit))
+
+    if opts.work_mode == 'extract':
+        if not len(args) == 1:
+            print('Error! Usage: ')
+            print(usage2)
+            return
+
+        else:
+            wm_str = bwm1.extract(filename=args[0], wm_shape=int(opts.wm_shape), mode='str')
+            print('Extraction succeeded! The watermark is:')
+            print(wm_str)
+
+
+'''
+python -m blind_watermark.cli_tools --embed --pwd 1234 examples/pic/ori_img.jpeg "watermark text" examples/output/embedded.png
+python -m blind_watermark.cli_tools --extract --pwd 1234 --wm_shape 111 examples/output/embedded.png
+
+
+cd examples
+blind_watermark --embed --pwd 1234 examples/pic/ori_img.jpeg "watermark text" examples/output/embedded.png
+blind_watermark --extract --pwd 1234 --wm_shape 111 examples/output/embedded.png
+'''

+ 38 - 0
blind_watermark/pool.py

@@ -0,0 +1,38 @@
+import sys
+import multiprocessing
+import warnings
+
+if sys.platform != 'win32':
+    try:
+        multiprocessing.set_start_method('fork')
+    except RuntimeError:  # the start method may already have been set
+        pass
+
+
+class CommonPool(object):
+    def map(self, func, args):
+        return list(map(func, args))
+
+
+class AutoPool(object):
+    def __init__(self, mode, processes):
+
+        if mode == 'multiprocessing' and sys.platform == 'win32':
+            warnings.warn('multiprocessing is not supported on Windows, falling back to multithreading')
+            mode = 'multithreading'
+
+        self.mode = mode
+        self.processes = processes
+
+        if mode in ('vectorization', 'cached'):  # not implemented yet; fall back to the serial pool so map() still works
+            self.pool = CommonPool()
+        elif mode == 'multithreading':
+            from multiprocessing.dummy import Pool as ThreadPool
+            self.pool = ThreadPool(processes=processes)
+        elif mode == 'multiprocessing':
+            from multiprocessing import Pool
+            self.pool = Pool(processes=processes)
+        else:  # common
+            self.pool = CommonPool()
+
+    def map(self, func, args):
+        return self.pool.map(func, args)

+ 100 - 0
blind_watermark/recover.py

@@ -0,0 +1,100 @@
+import cv2
+import numpy as np
+
+import functools
+
+
+# a helper class for lru_cache-based speedup; it introduces de-facto global variables
+class MyValues:
+    def __init__(self):
+        self.idx = 0
+        self.image, self.template = None, None
+
+    def set_val(self, image, template):
+        self.idx += 1
+        self.image, self.template = image, template
+
+
+my_value = MyValues()
+
+
+@functools.lru_cache(maxsize=None, typed=False)
+def match_template(w, h, idx):
+    image, template = my_value.image, my_value.template
+    resized = cv2.resize(template, dsize=(w, h))
+    scores = cv2.matchTemplate(image, resized, cv2.TM_CCOEFF_NORMED)
+    ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
+    return ind, scores[ind]
+
+
+def match_template_by_scale(scale):
+    image, template = my_value.image, my_value.template
+    w, h = round(template.shape[1] * scale), round(template.shape[0] * scale)
+    ind, score = match_template(w, h, idx=my_value.idx)
+    return ind, score, scale
+
+
+def search_template(scale=(0.5, 2), search_num=200):
+    image, template = my_value.image, my_value.template
+    # local brute-force search for the optimal scale
+    tmp = []
+    min_scale, max_scale = scale
+
+    max_scale = min(max_scale, image.shape[0] / template.shape[0], image.shape[1] / template.shape[1])
+
+    max_idx = 0
+
+    for i in range(2):
+        for scale in np.linspace(min_scale, max_scale, search_num):
+            ind, score, scale = match_template_by_scale(scale)
+            tmp.append([ind, score, scale])
+
+        # find the best match
+        max_idx = 0
+        max_score = 0
+        for idx, (ind, score, scale) in enumerate(tmp):
+            if score > max_score:
+                max_idx, max_score = idx, score
+
+        min_scale, max_scale = tmp[max(0, max_idx - 1)][2], tmp[min(len(tmp) - 1, max_idx + 1)][2]
+
+        search_num = 2 * int((max_scale - min_scale) * max(template.shape[1], template.shape[0])) + 1
+
+    return tmp[max_idx]
+
+
+def estimate_crop_parameters(original_file=None, template_file=None, ori_img=None, tem_img=None
+                             , scale=(0.5, 2), search_num=200):
+    # estimate the position and size of the attacked image inside the original image
+    if template_file:
+        tem_img = cv2.imread(template_file, cv2.IMREAD_GRAYSCALE)  # template image
+    if original_file:
+        ori_img = cv2.imread(original_file, cv2.IMREAD_GRAYSCALE)  # image
+
+    if scale[0] == scale[1] == 1:
+        # no scaling
+        scale_infer = 1
+        scores = cv2.matchTemplate(ori_img, tem_img, cv2.TM_CCOEFF_NORMED)
+        ind = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
+        score = scores[ind]
+    else:
+        my_value.set_val(image=ori_img, template=tem_img)
+        ind, score, scale_infer = search_template(scale=scale, search_num=search_num)
+    w, h = int(tem_img.shape[1] * scale_infer), int(tem_img.shape[0] * scale_infer)
+    x1, y1, x2, y2 = ind[1], ind[0], ind[1] + w, ind[0] + h
+    return (x1, y1, x2, y2), ori_img.shape, score, scale_infer
+
+
+def recover_crop(template_file=None, tem_img=None, output_file_name=None, loc=None, image_o_shape=None):
+    if template_file:
+        tem_img = cv2.imread(template_file)  # template image
+
+    (x1, y1, x2, y2) = loc
+
+    img_recovered = np.zeros((image_o_shape[0], image_o_shape[1], 3))
+
+    img_recovered[y1:y2, x1:x2, :] = cv2.resize(tem_img, dsize=(x2 - x1, y2 - y1))
+
+    if output_file_name:
+        cv2.imwrite(output_file_name, img_recovered)
+    return img_recovered
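+
+
+# A minimal usage sketch (file names are placeholders, not from the repo):
+# loc, image_o_shape, score, scale_infer = estimate_crop_parameters(
+#     original_file='original.png', template_file='attacked.png', scale=(0.5, 2), search_num=200)
+# recover_crop(template_file='attacked.png', output_file_name='recovered.png',
+#              loc=loc, image_o_shape=image_o_shape)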

+ 22 - 0
blind_watermark/version.py

@@ -0,0 +1,22 @@
+__version__ = '0.4.4'
+
+
+class Notes:
+    def __init__(self):
+        self.show = True
+
+    def print_notes(self):
+        if self.show:
+            print(f'''
+Welcome to blind-watermark, version = {__version__}
+Make sure the version is the same for encoding and decoding
+Your star means a lot: https://github.com/guofei9987/blind_watermark
+This message shows only once. To disable it: `blind_watermark.bw_notes.close()`
+            ''')
+            self.close()
+
+    def close(self):
+        self.show = False
+
+
+bw_notes = Notes()

+ 47 - 0
block/data_get.py

@@ -0,0 +1,47 @@
+import numpy as np
+
+
+def data_get(args):
+    data_dict = data_prepare(args).load()
+    return data_dict
+
+
+class data_prepare:
+    def __init__(self, args):
+        self.args = args
+
+    def load(self):
+        data_dict = {}
+        data_dict['train'] = self._load_label('train.txt')
+        data_dict['val'] = self._load_label('val.txt')
+        data_dict['class'] = self._load_class()
+        return data_dict
+
+    def _load_label(self, txt_name):
+        with open(f'{self.args.data_path}/{txt_name}', encoding='utf-8') as f:
+            txt = [_.strip() for _ in f.readlines()]  # read all image paths
+        data_list = [[0, 0] for _ in range(len(txt))]  # [image path, original label]
+        for i in range(len(txt)):
+            image_path = f'{self.args.data_path}/image' + txt[i].split('image')[-1]
+            data_list[i][0] = image_path
+            label_path = image_path.replace('image', 'label').replace('.jpg', '.txt')  # also maps an images/ dir to labels/
+            with open(label_path, 'r') as f:
+                label_txt = [_.strip().split(' ') for _ in f.readlines()]  # read this image's labels
+            data_list[i][1] = np.array(label_txt, dtype=np.float32)
+        return data_list
+
+    def _load_class(self):
+        with open(f'{self.args.data_path}/class.txt', encoding='utf-8') as f:
+            txt = [_.strip() for _ in f.readlines()]
+        return txt
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--data_path', default=r'D:\dataset\ObjectDetection\voc', type=str)
+    parser.add_argument('--input_size', default=640, type=int)
+    args = parser.parse_args()
+    data_dict = data_get(args)

+ 95 - 0
block/loss_get.py

@@ -0,0 +1,95 @@
+import torch
+
+
+def loss_get(args):
+    loss = loss_prepare(args)
+    return loss
+
+
+class loss_prepare(object):
+    def __init__(self, args):
+        self.loss_frame = self._ciou  # bounding-box loss
+        self.loss_confidence = torch.nn.BCEWithLogitsLoss()  # confidence loss
+        self.loss_confidence_add = torch.nn.BCEWithLogitsLoss()  # confidence loss for positive samples
+        self.loss_class = torch.nn.BCEWithLogitsLoss()  # classification loss
+        self.loss_weight = args.loss_weight  # weight of each output layer
+        self.stride = (8, 16, 32)
+        # decoding of the predicted boxes
+        self.device = args.device
+        output_size = [int(args.input_size // i) for i in self.stride]
+        self.anchor = (((12, 16), (19, 36), (40, 28)), ((36, 75), (76, 55), (72, 146)),
+                       ((142, 110), (192, 243), (459, 401)))
+        self.grid = [0, 0, 0]
+        for i in range(3):
+            self.grid[i] = torch.arange(output_size[i]).to(args.device)
+
+    def __call__(self, pred, true, judge):  # pred and true share the same layout; judge is a boolean matrix, True marks positions with a label to predict
+        frame_loss = 0  # total box loss
+        confidence_loss = 0  # total confidence loss
+        class_loss = 0  # total classification loss
+        pred = self._frame_decode(pred)  # decode boxes into real (Cx,Cy,w,h) coordinates
+        for i in range(len(pred)):  # handle each output layer separately
+            if True in judge[i]:  # there are positions with labels
+                pred_judge = pred[i][judge[i]]  # predicted values
+                true_judge = true[i][judge[i]]  # ground-truth labels
+                pred_judge, true_judge = self._center_to_min(pred_judge, true_judge)  # convert Cx,Cy to x_min,y_min
+                # compute the losses
+                frame_add = self.loss_frame(pred_judge[:, 0:4], true_judge[:, 0:4])  # box loss (only where labels exist)
+                confidence_a = 0.9 * self.loss_confidence(pred[i][..., 4], true[i][..., 4])  # confidence loss (over all positions)
+                confidence_b = 0.1 * self.loss_confidence_add(pred_judge[:, 4], true_judge[:, 4])  # positive samples
+                confidence_add = confidence_a + confidence_b
+                class_add = self.loss_class(pred_judge[:, 5:], true_judge[:, 5:])  # classification loss (only where labels exist)
+                # accumulate the totals
+                frame_loss += self.loss_weight[i][0] * self.loss_weight[i][1] * (1 - torch.mean(frame_add))  # total box loss
+                confidence_loss += self.loss_weight[i][0] * self.loss_weight[i][2] * confidence_add  # total confidence loss
+                class_loss += self.loss_weight[i][0] * self.loss_weight[i][3] * class_add  # total classification loss
+            else:  # no positions with labels
+                confidence_add = self.loss_confidence(pred[i][..., 4], true[i][..., 4])  # confidence loss (over all positions)
+                confidence_loss += self.loss_weight[i][0] * self.loss_weight[i][2] * confidence_add  # total confidence loss
+        return frame_loss + confidence_loss + class_loss, frame_loss, confidence_loss, class_loss
+
+    def _frame_decode(self, pred):
+        # iterate over each output layer
+        for i in range(len(pred)):
+            pred[i][..., 0:4] = pred[i][..., 0:4].sigmoid()  # squash to [0,1]
+            # center coordinates: [0,1] -> [-0.5,1.5] -> offset by the grid and scaled by the stride
+            pred[i][..., 0] = (2 * pred[i][..., 0] - 0.5 + self.grid[i].unsqueeze(1)) * self.stride[i]
+            pred[i][..., 1] = (2 * pred[i][..., 1] - 0.5 + self.grid[i]) * self.stride[i]
+            # iterate over the anchor layers inside each output layer
+            for j in range(3):
+                pred[i][:, j, ..., 2] = (2 * pred[i][:, j, ..., 2]) ** 2 * self.anchor[i][j][0]  # [0,1] -> [0, 4*anchor]
+                pred[i][:, j, ..., 3] = (2 * pred[i][:, j, ..., 3]) ** 2 * self.anchor[i][j][1]  # [0,1] -> [0, 4*anchor]
+        return pred
+
+    def _center_to_min(self, pred, true):  # (Cx,Cy)->(x_min,y_min)
+        pred[:, 0:2] = pred[:, 0:2] - pred[:, 2:4] / 2
+        true[:, 0:2] = true[:, 0:2] - true[:, 2:4] / 2
+        return pred, true
+
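+    # CIoU-style loss: IoU minus a distance penalty (_L1_L2: squared corner distance over
+    # squared enclosing-box diagonal) minus alpha * v, where v penalizes aspect-ratio mismatch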
+    def _ciou(self, pred, true):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+        iou = self._iou(pred, true)
+        L1_L2 = self._L1_L2(pred, true)
+        v = (4 / (3.1415926 ** 2)) * torch.pow(
+            torch.atan(true[:, 2] / true[:, 3]) - torch.atan(pred[:, 2] / pred[:, 3]), 2)
+        with torch.no_grad():
+            alpha = v / (v + 1 - iou + 0.00001)
+        return iou - L1_L2 - alpha * v
+
+    def _iou(self, pred, true):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+        x1 = torch.maximum(pred[:, 0], true[:, 0])
+        y1 = torch.maximum(pred[:, 1], true[:, 1])
+        x2 = torch.minimum(pred[:, 0] + pred[:, 2], true[:, 0] + true[:, 2])
+        y2 = torch.minimum(pred[:, 1] + pred[:, 3], true[:, 1] + true[:, 3])
+        zeros = torch.zeros(1, device=pred.device)
+        intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros) + 0.00001
+        union = pred[:, 2] * pred[:, 3] + true[:, 2] * true[:, 3] - intersection + 0.00001
+        return intersection / union
+
+    def _L1_L2(self, pred, true):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+        x1 = torch.minimum(pred[:, 0], true[:, 0])
+        y1 = torch.minimum(pred[:, 1], true[:, 1])
+        x2 = torch.maximum(pred[:, 0] + pred[:, 2], true[:, 0] + true[:, 2])
+        y2 = torch.maximum(pred[:, 1] + pred[:, 3], true[:, 1] + true[:, 3])
+        L1 = torch.square(pred[:, 0] - true[:, 0]) + torch.square(pred[:, 1] - true[:, 1])
+        L2 = torch.square(x2 - x1) + torch.square(y2 - y1)
+        return L1 / L2

+ 32 - 0
block/lr_get.py

@@ -0,0 +1,32 @@
+import math
+import torch
+
+
+def adam(regularization, r_value, param, lr, betas):
+    if regularization == 'L2':
+        optimizer = torch.optim.Adam(param, lr=lr, betas=betas, weight_decay=r_value)
+    else:
+        optimizer = torch.optim.Adam(param, lr=lr, betas=betas)
+    return optimizer
+
+
+class lr_adjust:
+    def __init__(self, args, step_epoch, epoch_finished):
+        self.lr_start = args.lr_start  # initial learning rate
+        self.lr_end = args.lr_end_ratio * args.lr_start  # final learning rate
+        self.lr_end_epoch = args.lr_end_epoch  # epoch at which the final learning rate is reached
+        self.step_all = self.lr_end_epoch * step_epoch  # total adjustment steps
+        self.step_finished = epoch_finished * step_epoch  # steps already taken
+        self.warmup_step = max(5, int(args.warmup_ratio * self.step_all))  # warmup steps
+
+    def __call__(self, optimizer):
+        self.step_finished += 1
+        step_now = self.step_finished
+        decay = step_now / self.step_all
+        lr = self.lr_end + (self.lr_start - self.lr_end) * math.cos(math.pi / 2 * decay)
+        if step_now <= self.warmup_step:
+            lr = lr * (0.1 + 0.9 * step_now / self.warmup_step)
+        lr = max(lr, 0.000001)
+        for i in range(len(optimizer.param_groups)):
+            optimizer.param_groups[i]['lr'] = lr
+        return optimizer
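+
+
+if __name__ == '__main__':  # a small illustrative sketch; the argument values are examples, not the training defaults
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--lr_start', default=0.001, type=float)
+    parser.add_argument('--lr_end_ratio', default=0.01, type=float)
+    parser.add_argument('--lr_end_epoch', default=100, type=int)
+    parser.add_argument('--warmup_ratio', default=0.01, type=float)
+    args = parser.parse_args()
+    model = torch.nn.Linear(4, 2)
+    optimizer = adam('L2', 0.0005, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer_adjust = lr_adjust(args, step_epoch=100, epoch_finished=0)
+    for step in range(5):
+        optimizer = optimizer_adjust(optimizer)
+        print(optimizer.param_groups[0]['lr'])  # ramps up during warmup, then follows the cosine decay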

+ 65 - 0
block/metric_get.py

@@ -0,0 +1,65 @@
+import torch
+import torchvision
+
+
+def center_to_min(pred, true):  # (Cx,Cy)->(x_min,y_min)
+    pred[:, 0:2] = pred[:, 0:2] - 1 / 2 * pred[:, 2:4]
+    true[:, 0:2] = true[:, 0:2] - 1 / 2 * true[:, 2:4]
+    return pred, true
+
+
+def confidence_screen(pred, confidence_threshold):
+    result = []
+    for i in range(len(pred)):  # 对一张图片的每个输出层分别进行操作
+        judge = torch.where(pred[i][..., 4] > confidence_threshold, True, False)
+        result.append((pred[i][judge]))
+    result = torch.concat(result, dim=0)
+    if result.shape[0] == 0:
+        return result
+    index = torch.argsort(result[:, 4], dim=0, descending=True)
+    result = result[index]
+    return result
+
+
+def iou_single(A, B):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+    x1 = torch.maximum(A[:, 0], B[0])
+    y1 = torch.maximum(A[:, 1], B[1])
+    x2 = torch.minimum(A[:, 0] + A[:, 2], B[0] + B[2])
+    y2 = torch.minimum(A[:, 1] + A[:, 3], B[1] + B[3])
+    zeros = torch.zeros(1, device=A.device)
+    intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
+    union = A[:, 2] * A[:, 3] + B[2] * B[3] - intersection
+    return intersection / union
+
+
+def iou(pred, true):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+    x1 = torch.maximum(pred[:, 0], true[:, 0])
+    y1 = torch.maximum(pred[:, 1], true[:, 1])
+    x2 = torch.minimum(pred[:, 0] + pred[:, 2], true[:, 0] + true[:, 2])
+    y2 = torch.minimum(pred[:, 1] + pred[:, 3], true[:, 1] + true[:, 3])
+    zeros = torch.zeros(1, device=pred.device)
+    intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
+    union = pred[:, 2] * pred[:, 3] + true[:, 2] * true[:, 3] - intersection
+    return intersection / union
+
+
+def nms(pred, iou_threshold):  # input: (batch, (x_min,y_min,w,h)), relative or real coordinates
+    pred[:, 2:4] = pred[:, 0:2] + pred[:, 2:4]  # real (x_min,y_min,x_max,y_max) coordinates
+    index = torchvision.ops.nms(pred[:, 0:4], pred[:, 4], 1 - iou_threshold)[:100]  # non-maximum suppression, keep at most 100
+    pred = pred[index]
+    pred[:, 2:4] = pred[:, 2:4] - pred[:, 0:2]  # back to real (x_min,y_min,w,h)
+    return pred
+
+
+def nms_tp_fn_fp(pred, true, iou_threshold):  # input: (batch, (x_min,y_min,w,h,other,class)), relative or real coordinates
+    pred_cls = torch.argmax(pred[:, 5:], dim=1)
+    true_cls = torch.argmax(true[:, 5:], dim=1)
+    tp = 0
+    for i in range(len(true)):
+        target = true[i]
+        iou_all = iou_single(pred, target)
+        judge_tp = torch.where((iou_all > iou_threshold) & (pred_cls == true_cls[i]), True, False)
+        tp += min(len(pred[judge_tp]), 1)  # several predictions may match the same target; count only 1 tp, the rest become fp
+    fp = len(pred) - tp
+    fn = len(true) - tp
+    return tp, fp, fn

+ 28 - 0
block/model_ema.py

@@ -0,0 +1,28 @@
+import math
+import copy
+import torch
+
+
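+# keeps an exponential moving average (shadow copy) of the model weights;
+# the effective decay ramps from 0 toward `decay` as the update count grows,
+# so early steps contribute more and the average stabilizes later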
+class model_ema:
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        self.ema = copy.deepcopy(self._get_model(model)).eval()  # FP32 EMA
+        self.updates = updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        with torch.no_grad():
+            self.updates += 1
+            d = self.decay(self.updates)
+            state_dict = self._get_model(model).state_dict()
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:
+                    v *= d
+                    v += (1 - d) * state_dict[k].detach()
+
+    def _get_model(self, model):
+        if type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel):
+            return model.module
+        else:
+            return model

+ 95 - 0
block/model_get.py

@@ -0,0 +1,95 @@
+import os
+import torch
+
+choice_dict = {'yolov5': 'model_prepare(args).yolov5()',
+               'yolov7': 'model_prepare(args).yolov7()'}
+
+
+def model_get(args):
+    if os.path.exists(args.weight):  # prefer loading an existing model to continue training
+        model_dict = torch.load(args.weight, map_location='cpu')
+    else:  # build a new model
+        if args.prune:  # model pruning
+            model_dict = torch.load(args.prune_weight, map_location='cpu')
+            model = model_dict['model']
+            model = prune(args, model)
+        else:
+            model = eval(choice_dict[args.model])
+        model_dict = {}
+        model_dict['model'] = model
+        model_dict['epoch_finished'] = 0  # epochs already trained
+        model_dict['optimizer_state_dict'] = None  # optimizer state
+        model_dict['ema_updates'] = 0  # EMA update count
+        model_dict['standard'] = 0  # best metric so far
+    return model_dict
+
+
+def prune(args, model):
+    # record the weights of the BN layers
+    BatchNorm2d_weight = []
+    for module in model.modules():
+        if isinstance(module, torch.nn.BatchNorm2d):
+            BatchNorm2d_weight.append(module.weight.data.clone())
+    BatchNorm2d_weight_abs = torch.concat(BatchNorm2d_weight, dim=0).abs()
+    weight_len = len(BatchNorm2d_weight)
+    # record which BN layer each weight belongs to
+    BatchNorm2d_id = []
+    for i in range(weight_len):
+        BatchNorm2d_id.extend([i for _ in range(len(BatchNorm2d_weight[i]))])
+    id_all = torch.tensor(BatchNorm2d_id)
+    # select the weights to keep
+    value, index = torch.sort(BatchNorm2d_weight_abs, dim=0, descending=True)
+    boundary = int(len(index) * args.prune_ratio)
+    prune_index = index[0:boundary]  # indices of the weights to keep
+    prune_index, _ = torch.sort(prune_index, dim=0, descending=False)
+    prune_id = id_all[prune_index]
+    # group the kept indices by layer
+    index_list = [[] for _ in range(weight_len)]
+    for i in range(len(prune_index)):
+        index_list[prune_id[i]].append(prune_index[i])
+    # convert each layer's kept indices to layer-relative indices
+    record_len = 0
+    for i in range(weight_len):
+        index_list[i] = torch.tensor(index_list[i])
+        index_list[i] -= record_len
+        if len(index_list[i]) == 0:  # a whole layer may be pruned away; keep at least one channel
+            index_list[i] = torch.argmax(BatchNorm2d_weight[i], dim=0).unsqueeze(0)
+        record_len += len(BatchNorm2d_weight[i])
+    # build the pruned model
+    args.prune_num = [len(_) for _ in index_list]
+    prune_model = eval(choice_dict[args.model])
+    # copy the BN weights and part of the conv weights
+    index = 0
+    for module, prune_module in zip(model.modules(), prune_model.modules()):
+        if isinstance(module, torch.nn.Conv2d):  # update some Conv2d weights
+            if index == 0:
+                weight = module.weight.data.clone()[index_list[index]]
+            elif index == weight_len:
+                weight = module.weight.data.clone()
+            else:
+                weight = module.weight.data.clone()[index_list[index]]
+                weight = weight[:, index_list[index - 1], :, :]
+            if prune_module.weight.data.shape == weight.shape:
+                prune_module.weight.data = weight
+        if isinstance(module, torch.nn.BatchNorm2d):  # update the BatchNorm2d weights
+            prune_module.weight.data = module.weight.data.clone()[index_list[index]]
+            prune_module.bias.data = module.bias.data.clone()[index_list[index]]
+            prune_module.running_mean = module.running_mean.clone()[index_list[index]]
+            prune_module.running_var = module.running_var.clone()[index_list[index]]
+            index += 1
+    return prune_model
+
+
+class model_prepare:
+    def __init__(self, args):
+        self.args = args
+
+    def yolov5(self):
+        from model.yolov5 import yolov5
+        model = yolov5(self.args)
+        return model
+
+    def yolov7(self):
+        from model.yolov7 import yolov7
+        model = yolov7(self.args)
+        return model

+ 406 - 0
block/train_get.py

@@ -0,0 +1,406 @@
+import cv2
+import tqdm
+import wandb
+import torch
+import numpy as np
+from block.val_get import val_get
+from block.model_ema import model_ema
+from block.lr_get import adam, lr_adjust
+from copy import deepcopy
+
+
+def train_get(args, data_dict, model_dict, loss):
+    # load the model
+    model = model_dict['model'].to(args.device, non_blocking=args.latch)
+    # optimizer and learning rate
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # steps per epoch
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate schedule
+    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
+    # exponential moving average (EMA) of the parameters (do not put ema into args, or saving the model will fail)
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+    # datasets
+    train_dataset = torch_dataset(args, 'train', data_dict['train'])
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    train_shuffle = False if args.distributed else True  # with a distributed sampler, shuffle must be False
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
+                                                   drop_last=True, pin_memory=args.latch, num_workers=args.num_worker,
+                                                   sampler=train_sampler, collate_fn=train_dataset.collate_fn)
+    val_dataset = torch_dataset(args, 'val', data_dict['val'])
+    val_sampler = None  # in distributed mode, validation runs together on the main GPU
+    val_batch = args.batch // args.device_number  # in distributed validation the batch shrinks to one GPU's share
+    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False, drop_last=False,
+                                                 pin_memory=args.latch, num_workers=args.num_worker,
+                                                 sampler=val_sampler, collate_fn=val_dataset.collate_fn)
+    # distributed initialization
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                      output_device=args.local_rank) if args.distributed else model
+    # wandb
+    if args.wandb and args.local_rank == 0:
+        wandb_image_list = []  # collect the wandb images and add them all at the end (at most args.wandb_image_num)
+        wandb_class_name = {}  # used to attach class names to the boxes
+        for i in range(len(data_dict['class'])):
+            wandb_class_name[i] = data_dict['class'][i]
+    epoch_base = model_dict['epoch_finished'] + 1  # the new epoch starts at +1
+    for epoch in range(epoch_base, args.epoch + 1):  # training
+        print(f'\n----------------------- epoch {epoch} -----------------------') if args.local_rank == 0 else None
+        model.train()
+        train_loss = 0  # running training loss
+        train_frame_loss = 0  # running box loss
+        train_confidence_loss = 0  # running confidence loss
+        train_class_loss = 0  # running class loss
+        if args.local_rank == 0:  # tqdm
+            tqdm_show = tqdm.tqdm(total=step_epoch)
+        for index, (image_batch, true_batch, judge_batch, label_list) in enumerate(train_dataloader):
+            if args.wandb and args.local_rank == 0 and len(wandb_image_list) < args.wandb_image_num:
+                wandb_image_batch = (image_batch * 255).cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
+            image_batch = image_batch.to(args.device, non_blocking=args.latch)  # move the inputs to the device
+            for i in range(len(true_batch)):  # move the label matrices to the device
+                true_batch[i] = true_batch[i].to(args.device, non_blocking=args.latch)
+            if args.amp:
+                with torch.cuda.amp.autocast():
+                    pred_batch = model(image_batch)
+                    loss_batch, frame_loss, confidence_loss, class_loss = loss(pred_batch, true_batch, judge_batch)
+                args.amp.scale(loss_batch).backward()
+                args.amp.step(optimizer)
+                args.amp.update()
+                optimizer.zero_grad()
+            else:
+                pred_batch = model(image_batch)
+                loss_batch, frame_loss, confidence_loss, class_loss = loss(pred_batch, true_batch, judge_batch)
+                loss_batch.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+            # update the EMA parameters; ema.updates increments automatically
+            ema.update(model) if args.ema else None
+            # accumulate the losses
+            train_loss += loss_batch.item()
+            train_frame_loss += frame_loss.item()
+            train_confidence_loss += confidence_loss.item()
+            train_class_loss += class_loss.item()
+            # adjust the learning rate
+            optimizer = optimizer_adjust(optimizer)
+            # tqdm
+            if args.local_rank == 0:
+                tqdm_show.set_postfix({'train_loss': loss_batch.item(),
+                                       'lr': optimizer.param_groups[0]['lr']})  # update the display
+                tqdm_show.update(args.device_number)  # advance the progress bar
+            # wandb
+            if args.wandb and args.local_rank == 0 and epoch == epoch_base and len(wandb_image_list) < args.wandb_image_num:
+                for i in range(len(wandb_image_batch)):  # iterate over each image
+                    image = wandb_image_batch[i]
+                    frame = label_list[i][:, 0:4] / args.input_size  # relative (Cx,Cy,w,h) coordinates
+                    frame[:, 0:2] = frame[:, 0:2] - frame[:, 2:4] / 2
+                    frame[:, 2:4] = frame[:, 0:2] + frame[:, 2:4]  # relative (x_min,y_min,x_max,y_max) coordinates
+                    cls = torch.argmax(label_list[i][:, 5:], dim=1)
+                    box_data = []
+                    for i in range(len(frame)):
+                        class_id = cls[i].item()
+                        box_data.append({"position": {"minX": frame[i][0].item(),
+                                                      "minY": frame[i][1].item(),
+                                                      "maxX": frame[i][2].item(),
+                                                      "maxY": frame[i][3].item()},
+                                         "class_id": class_id,
+                                         "box_caption": wandb_class_name[class_id]})
+                    wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": box_data,
+                                                                            'class_labels': wandb_class_name}})
+                    wandb_image_list.append(wandb_image)
+                    if len(wandb_image_list) == args.wandb_image_num:
+                        break
+        # tqdm
+        if args.local_rank == 0:
+            tqdm_show.close()
+        # average the losses
+        train_loss /= index + 1
+        train_frame_loss /= index + 1
+        train_confidence_loss /= index + 1
+        train_class_loss /= index + 1
+        if args.local_rank == 0:
+            print(f'\n| train | train_loss:{train_loss:.4f} | train_frame_loss:{train_frame_loss:.4f} |'
+                  f' train_confidence_loss:{train_confidence_loss:.4f} | train_class_loss:{train_class_loss:.4f} |'
+                  f' lr:{optimizer.param_groups[0]["lr"]:.6f} |\n')
+        # free GPU memory
+        del image_batch, true_batch, judge_batch, pred_batch, loss_batch
+        torch.cuda.empty_cache()
+        # validation
+        if args.local_rank == 0:  # validate only once in distributed mode
+            val_loss, val_frame_loss, val_confidence_loss, val_class_loss, precision, recall, m_ap \
+                = val_get(args, val_dataloader, model, loss, ema, len(data_dict['val']))
+        # save
+        if args.local_rank == 0:  # save only once in distributed mode
+            model_dict['model'] = model.module if args.distributed else model.eval()
+            model_dict['epoch_finished'] = epoch
+            model_dict['optimizer_state_dict'] = optimizer.state_dict()
+            model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
+            model_dict['class'] = data_dict['class']
+            model_dict['train_loss'] = train_loss
+            model_dict['val_loss'] = val_loss
+            model_dict['val_m_ap'] = m_ap
+            torch.save(model_dict, 'last.pt' if not args.prune else 'prune_last.pt')  # save the model of the last epoch
+            if m_ap > 0.1 and m_ap > model_dict['standard']:
+                model_dict['standard'] = m_ap
+                save_path = args.save_path if not args.prune else args.prune_save
+                torch.save(model_dict, save_path)  # save the best model
+                print(f'| saved best model: {save_path} | val_m_ap:{m_ap:.4f} |')
+            # wandb
+            if args.wandb:
+                wandb_log = {}
+                if epoch == epoch_base:
+                    wandb_log.update({f'image/train_image': wandb_image_list})
+                wandb_log.update({'train_loss/train_loss': train_loss,
+                                  'train_loss/train_frame_loss': train_frame_loss,
+                                  'train_loss/train_confidence_loss': train_confidence_loss,
+                                  'train_loss/train_class_loss': train_class_loss,
+                                  'val_loss/val_loss': val_loss,
+                                  'val_loss/val_frame_loss': val_frame_loss,
+                                  'val_loss/val_confidence_loss': val_confidence_loss,
+                                  'val_loss/val_class_loss': val_class_loss,
+                                  'val_metric/val_precision': precision,
+                                  'val_metric/val_recall': recall,
+                                  'val_metric/val_m_ap': m_ap})
+                args.wandb_run.log(wandb_log)
+        torch.distributed.barrier() if args.distributed else None  # when distributed, sync all GPUs after each epoch; faster GPUs wait here
+    return model_dict
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, args, tag, data):
+        self.output_num = (3, 3, 3)  # number of output layers; (3, 3, 3) means three big layers with three small layers each
+        self.stride = (8, 16, 32)  # downsampling factor of each output layer
+        self.wh_multiple = 4  # width/height multiple: real wh = raw output [0-1] * multiple * anchor
+        self.input_size = args.input_size  # input size, e.g. 640
+        self.output_class = args.output_class  # number of output classes
+        self.label_smooth = args.label_smooth  # label smoothing, e.g. (0.05, 0.95)
+        self.output_size = [int(self.input_size // i) for i in self.stride]  # size of each output layer, e.g. (80, 40, 20)
+        self.anchor = (((12, 16), (19, 36), (40, 28)), ((36, 75), (76, 55), (72, 146)),
+                       ((142, 110), (192, 243), (459, 401)))
+        self.tag = tag  # distinguishes train set from val set
+        self.data = data
+        self.mosaic = args.mosaic
+        self.mosaic_flip = args.mosaic_flip
+        self.mosaic_hsv = args.mosaic_hsv
+        self.mosaic_screen = args.mosaic_screen
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # process image and label; box coords become real Cx,Cy,w,h (normalization, mean/std and dim changes are done inside the model)
+        if self.tag == 'train' and torch.rand(1) < self.mosaic:
+            index_mix = torch.randperm(len(self.data))[0:4]
+            index_mix[0] = index
+            image, frame, cls_all = self._mosaic(index_mix)  # mosaic augmentation: scale and pad images, relative coords become real coords (Cx,Cy,w,h)
+        else:
+            image = cv2.imdecode(np.fromfile(self.data[index][0], dtype=np.uint8), cv2.IMREAD_COLOR)  # read image (path may contain non-ASCII characters)
+            label = deepcopy(self.data[index][1])  # raw label [:, class_id, Cx, Cy, w, h]; boxes are ratios of the image size
+            if isinstance(label, int):
+                label = np.array([label])  # convert a bare int to a numpy array
+            if label.ndim == 1:
+                pass  # 1-D labels are left unhandled here
+            image, frame = self._resize(image.astype(np.uint8), label[:, 1:])  # scale and pad image; relative (Cx,Cy,w,h) become real coords
+            cls_all = label[:, 0]  # class ids
+        image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2RGB)  # convert to RGB
+        image = (torch.tensor(image, dtype=torch.float32) / 255).permute(2, 0, 1)
+        # boxes: convert to tensor
+        frame = torch.tensor(frame, dtype=torch.float32)
+        # confidence: all ones
+        confidence = torch.ones((len(frame), 1), dtype=torch.float32)
+        # classes: one-hot encoding with label smoothing
+        cls = torch.full((len(cls_all), self.output_class), self.label_smooth[0], dtype=torch.float32)
+        for i in range(len(cls_all)):
+            cls[i][int(cls_all[i])] = self.label_smooth[1]
+        # merge into one label tensor
+        label = torch.concat([frame, confidence, cls], dim=1).type(torch.float32)  # real coords (Cx,Cy,w,h)
+        # build the label matrices
+        label_matrix_list = [0 for _ in range(len(self.output_num))]  # one label matrix per output layer, real coords (Cx,Cy,w,h)
+        judge_matrix_list = [0 for _ in range(len(self.output_num))]  # one judge matrix per output layer
+        for i in range(len(self.output_num)):  # iterate over each output layer
+            label_matrix = torch.zeros(self.output_num[i], self.output_size[i], self.output_size[i],
+                                       5 + self.output_class, dtype=torch.float32)  # label matrix
+            judge_matrix = torch.zeros(self.output_num[i], self.output_size[i], self.output_size[i],
+                                       dtype=torch.bool)  # judge matrix; False means no label
+            if len(label) > 0:  # labels exist
+                frame = label[:, 0:4].clone()
+                frame[:, 0:2] = frame[:, 0:2] / self.stride[i]
+                frame[:, 2:4] = frame[:, 2:4] / self.wh_multiple
+                # grid coordinates of each label in this output layer
+                Cx = frame[:, 0]
+                x_grid = Cx.type(torch.int32)
+                x_move = Cx - x_grid
+                x_grid_add = x_grid + 2 * torch.round(x_move).type(torch.int32) - 1  # each label can also be predicted by a neighboring cell
+                x_grid_add = torch.clamp(x_grid_add, 0, self.output_size[i] - 1)  # keep within the grid (cells duplicating x_grid are skipped later)
+                Cy = frame[:, 1]
+                y_grid = Cy.type(torch.int32)
+                y_move = Cy - y_grid
+                y_grid_add = y_grid + 2 * torch.round(y_move).type(torch.int32) - 1  # each label can also be predicted by a neighboring cell
+                y_grid_add = torch.clamp(y_grid_add, 0, self.output_size[i] - 1)  # keep within the grid (cells duplicating y_grid are skipped later)
+                # iterate over the small layers of this output layer
+                for j in range(self.output_num[i]):
+                    # build the wh-based filter
+                    frame_change = frame.clone()
+                    w = frame_change[:, 2] / self.anchor[i][j][0]  # must be in (0,1) for this layer to predict it (below 0.0625 is too small and discarded)
+                    h = frame_change[:, 3] / self.anchor[i][j][1]  # must be in (0,1) for this layer to predict it (below 0.0625 is too small and discarded)
+                    wh_screen = torch.where((0.0625 < w) & (w < 1) & (0.0625 < h) & (h < 1), True, False)  # keep labels this layer can predict
+                    # write each label into its slot in the label matrix
+                    for k in range(len(label)):
+                        if wh_screen[k]:  # wh filter
+                            label_matrix[j, x_grid[k], y_grid[k]] = label[k]
+                            judge_matrix[j, x_grid[k], y_grid[k]] = True
+                    # write the expanded (neighbor-cell) labels into the label matrix
+                    for k in range(len(label)):
+                        if wh_screen[k] and not judge_matrix[j, x_grid_add[k], y_grid[k]]:  # only if the cell is still empty
+                            label_matrix[j, x_grid_add[k], y_grid[k]] = label[k]
+                            judge_matrix[j, x_grid_add[k], y_grid[k]] = True
+                        if wh_screen[k] and not judge_matrix[j, x_grid[k], y_grid_add[k]]:  # only if the cell is still empty
+                            label_matrix[j, x_grid[k], y_grid_add[k]] = label[k]
+                            judge_matrix[j, x_grid[k], y_grid_add[k]] = True
+            # collect the results for this output layer
+            label_matrix_list[i] = label_matrix
+            judge_matrix_list[i] = judge_matrix
+        return image, label_matrix_list, judge_matrix_list, label  # real coords (Cx,Cy,w,h)
+
+    def collate_fn(self, getitem_list):  # custom merge of __getitem__ outputs
+        image_list = []
+        label_matrix_list = [[] for _ in range(len(getitem_list[0][1]))]
+        judge_matrix_list = [[] for _ in range(len(getitem_list[0][2]))]
+        label_list = []
+        for i in range(len(getitem_list)):  # iterate over every __getitem__ result
+            image = getitem_list[i][0]
+            label_matrix = getitem_list[i][1]
+            judge_matrix = getitem_list[i][2]
+            label = getitem_list[i][3]
+            image_list.append(image)
+            for j in range(len(label_matrix)):  # iterate over each output layer
+                label_matrix_list[j].append(label_matrix[j])
+                judge_matrix_list[j].append(judge_matrix[j])
+            label_list.append(label)
+        # merge
+        image_batch = torch.stack(image_list, dim=0)
+        for i in range(len(label_matrix_list)):
+            label_matrix_list[i] = torch.stack(label_matrix_list[i], dim=0)
+            judge_matrix_list[i] = torch.stack(judge_matrix_list[i], dim=0)
+        return image_batch, label_matrix_list, judge_matrix_list, label_list  # all in real coords (Cx,Cy,w,h)
+
+    def _mosaic(self, index_mix):  # mosaic augmentation; merged boxes must keep w,h above mosaic_screen
+        x_center = int((torch.rand(1) * 0.4 + 0.3) * self.input_size)  # 0.3-0.7: center point where the four images meet
+        y_center = int((torch.rand(1) * 0.4 + 0.3) * self.input_size)  # 0.3-0.7: center point where the four images meet
+        image_merge = np.full((self.input_size, self.input_size, 3), 128)  # merged image
+        frame_all = []  # real box coords (Cx,Cy,w,h)
+        cls_all = []  # class ids
+        for i, index in enumerate(index_mix):
+            image = cv2.imdecode(np.fromfile(self.data[index][0], dtype=np.uint8), cv2.IMREAD_COLOR)  # read image (path may contain non-ASCII characters)
+            label = deepcopy(self.data[index][1])  # relative coords (class_id, Cx, Cy, w, h)
+            if isinstance(label, int):
+                label = np.array([label])  # convert a bare int to a numpy array
+            if label.ndim == 1:
+                continue  # skip 1-D labels
+            # random HSV shift
+            if torch.rand(1) < self.mosaic_hsv:
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32)
+                image[:, :, 0] += np.random.rand(1) * 60 - 30  # -30 to 30
+                image[:, :, 1] += np.random.rand(1) * 60 - 30  # -30 to 30
+                image[:, :, 2] += np.random.rand(1) * 60 - 30  # -30 to 30
+                image = np.clip(image, 0, 255).astype(np.uint8)
+                image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+            # horizontal flip (cv2.flip with flipCode=1 flips around the vertical axis)
+            if torch.rand(1) < self.mosaic_flip:
+                image = cv2.flip(image, 1)  # flip image left-right
+                label[:, 1] = 1 - label[:, 1]  # coord transform: Cx = w - Cx
+            # scale image to input_size
+            h, w, _ = image.shape
+            scale = self.input_size / w
+            w = w * scale
+            h = h * scale
+            # then apply a random rescale
+            scale_w = torch.rand(1) + 0.5  # 0.5-1.5
+            scale_h = 1 + torch.rand(1) * 0.5 if scale_w > 1 else 1 - torch.rand(1) * 0.5  # h scales in the same direction as w
+            w = int(w * scale_w)
+            h = int(h * scale_h)
+            image = cv2.resize(image, (w, h))
+            # paste into the merged image; coords become real coords (Cx,Cy,w,h) of the merge
+            if i == 0:  # top-left
+                x_add, y_add = min(x_center, w), min(y_center, h)
+                image_merge[y_center - y_add:y_center, x_center - x_add:x_center] = image[h - y_add:h, w - x_add:w]
+                label[:, 1] = label[:, 1] * w + x_center - w  # Cx
+                label[:, 2] = label[:, 2] * h + y_center - h  # Cy
+                label[:, 3:5] = label[:, 3:5] * (w, h)  # w,h
+            elif i == 1:  # top-right
+                x_add, y_add = min(self.input_size - x_center, w), min(y_center, h)
+                image_merge[y_center - y_add:y_center, x_center:x_center + x_add] = image[h - y_add:h, 0:x_add]
+                label[:, 1] = label[:, 1] * w + x_center  # Cx
+                label[:, 2] = label[:, 2] * h + y_center - h  # Cy
+                label[:, 3:5] = label[:, 3:5] * (w, h)  # w,h
+            elif i == 2:  # bottom-right
+                x_add, y_add = min(self.input_size - x_center, w), min(self.input_size - y_center, h)
+                image_merge[y_center:y_center + y_add, x_center:x_center + x_add] = image[0:y_add, 0:x_add]
+                label[:, 1] = label[:, 1] * w + x_center  # Cx
+                label[:, 2] = label[:, 2] * h + y_center  # Cy
+                label[:, 3:5] = label[:, 3:5] * (w, h)  # w,h
+            else:  # bottom-left
+                x_add, y_add = min(x_center, w), min(self.input_size - y_center, h)
+                image_merge[y_center:y_center + y_add, x_center - x_add:x_center] = image[0:y_add, w - x_add:w]
+                label[:, 1] = label[:, 1] * w + x_center - w  # Cx
+                label[:, 2] = label[:, 2] * h + y_center  # Cy
+                label[:, 3:5] = label[:, 3:5] * (w, h)  # w,h
+            frame_all.append(label[:, 1:5])
+            cls_all.append(label[:, 0])
+        # merge labels
+        frame_all = np.concatenate(frame_all, axis=0)
+        cls_all = np.concatenate(cls_all, axis=0)
+        # drop labels that fall outside the image
+        frame_all[:, 0:2] = frame_all[:, 0:2] - frame_all[:, 2:4] / 2
+        frame_all[:, 2:4] = frame_all[:, 0:2] + frame_all[:, 2:4]  # real coords (x_min,y_min,x_max,y_max)
+        frame_all = np.clip(frame_all, 0, self.input_size - 1)  # clip coords into the image
+        frame_all[:, 2:4] = frame_all[:, 2:4] - frame_all[:, 0:2]
+        frame_all[:, 0:2] = frame_all[:, 0:2] + frame_all[:, 2:4] / 2  # real coords (Cx,Cy,w,h)
+        judge_list = np.where((frame_all[:, 2] > self.mosaic_screen) & (frame_all[:, 3] > self.mosaic_screen),
+                              True, False)  # w,h must exceed mosaic_screen
+        frame_all = frame_all[judge_list]
+        cls_all = cls_all[judge_list]
+        return image_merge, frame_all, cls_all
+
+    def _resize(self, image, frame):  # pad the image into a square; frame is [[Cx,Cy,w,h]...] (ratios of the original image in, real coords out)
+        shape = image.shape
+        w0 = shape[1]
+        h0 = shape[0]
+        if w0 == h0 == self.input_size:  # no resizing needed
+            frame *= self.input_size
+            return image, frame
+        else:
+            image_resize = np.full((self.input_size, self.input_size, 3), 128)
+            if w0 >= h0:  # wider than tall
+                w = self.input_size
+                h = int(w / w0 * h0)
+                image = cv2.resize(image, (w, h))
+                add_y = (w - h) // 2
+                image_resize[add_y:add_y + h] = image
+                frame[:, 0] = np.around(frame[:, 0] * w)
+                frame[:, 1] = np.around(frame[:, 1] * h + add_y)
+                frame[:, 2] = np.around(frame[:, 2] * w)
+                frame[:, 3] = np.around(frame[:, 3] * h)
+                return image_resize, frame
+            else:  # taller than wide
+                h = self.input_size
+                w = int(h / h0 * w0)
+                image = cv2.resize(image, (w, h))
+                add_x = (h - w) // 2
+                image_resize[:, add_x:add_x + w] = image
+                frame[:, 0] = np.around(frame[:, 0] * w + add_x)
+                frame[:, 1] = np.around(frame[:, 1] * h)
+                frame[:, 2] = np.around(frame[:, 2] * w)
+                frame[:, 3] = np.around(frame[:, 3] * h)
+                return image_resize, frame
+
+    def _draw(self, image, frame_all):  # debug visualization; real coords (Cx,Cy,w,h)
+        frame_all[:, 0:2] = frame_all[:, 0:2] - frame_all[:, 2:4] / 2
+        frame_all[:, 2:4] = frame_all[:, 0:2] + frame_all[:, 2:4]  # real coords (x_min,y_min,x_max,y_max)
+        for frame in frame_all:
+            x1, y1, x2, y2 = frame
+            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color=(0, 255, 0), thickness=2)
+        cv2.imwrite('save_check.jpg', image)

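For orientation, here is a minimal, hypothetical sketch of how `torch_dataset` would be wired into a `DataLoader` (the `args` fields and `data_dict` names mirror the code above; not part of the commit):

```python
# Hypothetical usage sketch: torch_dataset + DataLoader with the custom collate_fn.
dataset = torch_dataset(args, tag='train', data=data_dict['train'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch, shuffle=True,
                                         drop_last=True, num_workers=args.num_worker,
                                         collate_fn=dataset.collate_fn)
for image_batch, label_matrix_list, judge_matrix_list, label_list in dataloader:
    # image_batch: (batch, 3, input_size, input_size)
    # label_matrix_list / judge_matrix_list: one stacked tensor per output layer
    break
```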
+ 69 - 0
block/val_get.py

@@ -0,0 +1,69 @@
+import tqdm
+import torch
+from model.layer import decode
+from block.metric_get import confidence_screen, nms, nms_tp_fn_fp
+
+
+def val_get(args, val_dataloader, model, loss, ema, data_len):
+    with torch.no_grad():
+        model = ema.ema if args.ema else model.eval()
+        decode_model = decode(args.input_size)
+        val_loss = 0  # accumulated validation loss
+        val_frame_loss = 0  # accumulated box loss
+        val_confidence_loss = 0  # accumulated confidence loss
+        val_class_loss = 0  # accumulated class loss
+        nms_tp_all = 0
+        nms_fp_all = 0
+        nms_fn_all = 0
+        tqdm_len = (data_len - 1) // (args.batch // args.device_number) + 1
+        tqdm_show = tqdm.tqdm(total=tqdm_len)
+        for index, (image_batch, true_batch, judge_batch, label_list) in enumerate(val_dataloader):
+            image_batch = image_batch.to(args.device, non_blocking=args.latch)  # move inputs to the device
+            for i in range(len(true_batch)):  # move the label matrices to the device
+                true_batch[i] = true_batch[i].to(args.device, non_blocking=args.latch)
+            pred_batch = model(image_batch)
+            clone_batch = [_.clone() for _ in pred_batch]  # the loss computation mutates pred_batch
+            # compute loss
+            loss_batch, frame_loss, confidence_loss, class_loss = loss(pred_batch, true_batch, judge_batch)
+            val_loss += loss_batch.item()
+            val_frame_loss += frame_loss.item()
+            val_confidence_loss += confidence_loss.item()
+            val_class_loss += class_loss.item()
+            # decode outputs
+            clone_batch = decode_model(clone_batch)  # raw (Cx,Cy,w,h,confidence...) -> real coords (Cx,Cy,w,h,confidence...)
+            # accumulate metrics
+            for i in range(clone_batch[0].shape[0]):  # iterate over each image
+                true = label_list[i].to(args.device)
+                pred = [_[i] for _ in clone_batch]  # real coords (Cx,Cy,w,h)
+                pred = confidence_screen(pred, args.confidence_threshold)  # confidence filtering
+                if len(pred) == 0:  # no predictions for this image
+                    nms_fn_all += len(true)
+                    continue
+                pred[:, 0:2] = pred[:, 0:2] - pred[:, 2:4] / 2  # real coords (x_min,y_min,w,h)
+                true[:, 0:2] = true[:, 0:2] - true[:, 2:4] / 2  # real coords (x_min,y_min,w,h)
+                pred = nms(pred, args.iou_threshold)[:100]  # NMS, keep at most 100
+                if len(true) == 0:  # no labels for this image
+                    nms_fp_all += len(pred)
+                    continue
+                nms_tp, nms_fp, nms_fn = nms_tp_fn_fp(pred, true, args.iou_threshold)
+                nms_tp_all += nms_tp
+                nms_fn_all += nms_fn
+                nms_fp_all += nms_fp
+            # tqdm
+            tqdm_show.set_postfix({'val_loss': loss_batch.item()})  # update the display
+            tqdm_show.update(1)  # advance the progress bar
+        # tqdm
+        tqdm_show.close()
+        # compute average losses
+        val_loss /= index + 1
+        val_frame_loss /= index + 1
+        val_confidence_loss /= index + 1
+        val_class_loss /= index + 1
+        print(f'\n| val | val_loss:{val_loss:.4f} | val_frame_loss:{val_frame_loss:.4f} |'
+              f' val_confidence_loss:{val_confidence_loss:.4f} | val_class_loss:{val_class_loss:.4f} |')
+        # compute metrics
+        precision = nms_tp_all / (nms_tp_all + nms_fp_all + 0.001)
+        recall = nms_tp_all / (nms_tp_all + nms_fn_all + 0.001)
+        m_ap = precision * recall
+        print('| val | precision:{:.4f} | recall:{:.4f} | m_ap:{:.4f} |'.format(precision, recall, m_ap))
+    return val_loss, val_frame_loss, val_confidence_loss, val_class_loss, precision, recall, m_ap

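Note that the `m_ap` reported above is simply precision × recall, a cheap proxy rather than the area-under-PR-curve mAP. A quick worked example with made-up counts:

```python
# Made-up counts to illustrate the metric arithmetic in val_get.
nms_tp_all, nms_fp_all, nms_fn_all = 80, 20, 40
precision = nms_tp_all / (nms_tp_all + nms_fp_all + 0.001)  # ~0.800
recall = nms_tp_all / (nms_tp_all + nms_fn_all + 0.001)     # ~0.667
m_ap = precision * recall                                   # ~0.533
```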
+ 46 - 0
export_onnx.py

@@ -0,0 +1,46 @@
+import os
+import torch
+import argparse
+from model.layer import deploy
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|Convert a pt model to onnx and export class info|')
+parser.add_argument('--weight', default='best.pt', type=str, help='|path to the model|')
+parser.add_argument('--input_size', default=640, type=int, help='|input image size|')
+parser.add_argument('--batch', default=0, type=int, help='|input batch size, 0 for dynamic|')
+parser.add_argument('--sim', default=True, type=bool, help='|simplify the model with onnxsim|')
+parser.add_argument('--device', default='cuda', type=str, help='|device to load the model on|')
+parser.add_argument('--float16', default=True, type=bool, help='|onnx model dtype, requires GPU; float32 when False|')
+args = parser.parse_args()
+args.weight = os.path.splitext(args.weight)[0] + '.pt'  # normalize the extension without clobbering dotted paths
+args.save_name = os.path.splitext(args.weight)[0] + '.onnx'
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.weight), f'! model not found: {args.weight} !'
+if args.float16:
+    assert torch.cuda.is_available(), '! cuda is unavailable, cannot use float16 !'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def export_onnx():
+    model_dict = torch.load(args.weight, map_location='cpu')
+    model = model_dict['model']
+    model = deploy(model, args.input_size)
+    model.eval().half().to(args.device) if args.float16 else model.eval().float().to(args.device)
+    input_shape = torch.rand(1, args.input_size, args.input_size, 3,
+                             dtype=torch.float16 if args.float16 else torch.float32).to(args.device)
+    torch.onnx.export(model, input_shape, args.save_name,
+                      opset_version=12, input_names=['input'], output_names=['output'],
+                      dynamic_axes={'input': {args.batch: 'batch_size'}, 'output': {args.batch: 'batch_size'}})
+    print('| exported onnx model: {} |'.format(args.save_name))
+    if args.sim:
+        import onnx
+        import onnxsim
+
+        model_onnx = onnx.load(args.save_name)
+        model_simplify, check = onnxsim.simplify(model_onnx)
+        onnx.save(model_simplify, args.save_name)
+        print('| simplified the model with onnxsim |')
+
+
+if __name__ == '__main__':
+    export_onnx()

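Once exported, the model can be exercised with ONNX Runtime. A sketch under the assumption that the model was exported as `best.onnx` with the defaults above and that `onnxruntime` is installed (float16 graphs generally need the CUDA provider):

```python
# Hypothetical inference sketch for the exported model.
# deploy() embeds the /255 normalization and the NHWC->NCHW permute,
# so the graph expects NHWC input: (batch, input_size, input_size, 3).
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('best.onnx',
                                       providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
image = np.random.randint(0, 256, (1, 640, 640, 3)).astype(np.float16)  # use float32 if exported with float16 disabled
outputs = session.run(None, {'input': image})  # three decoded output layers
print([o.shape for o in outputs])
```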
+ 23 - 0
flask_request.py

@@ -0,0 +1,23 @@
+# after the flask_start service is up, post data to it and get the result
+import json
+import base64
+import requests
+
+
+def image_encode(image_path):
+    with open(image_path, 'rb') as f:
+        image_byte = f.read()
+    image_base64 = base64.b64encode(image_byte)
+    image = image_base64.decode()
+    return image
+
+
+if __name__ == '__main__':
+    url = 'http://0.0.0.0:9999/test/'  # matches the settings in flask_start: http://host:port/name/
+    image_path = 'demo.jpg'
+    image = image_encode(image_path)
+    request_dict = {'image': image}
+    request = json.dumps(request_dict)
+    response = requests.post(url, data=request)
+    result = response.json()
+    print(result)

+ 40 - 0
flask_start.py

@@ -0,0 +1,40 @@
+# pip install flask -i https://pypi.tuna.tsinghua.edu.cn/simple
+# wrap the program as a service with flask and start it on the server
+import cv2
+import json
+import flask
+import base64
+import argparse
+import numpy as np
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# settings
+parser = argparse.ArgumentParser('|start the flask service on the server|')
+# ...
+args, _ = parser.parse_known_args()  # avoids conflicts with externally passed args; replaces args = parser.parse_args()
+app = flask.Flask(__name__)  # create the service app
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# program
+def image_decode(image):
+    image_base64 = image.encode()  # base64
+    image_byte = base64.b64decode(image_base64)  # base64 -> bytes
+    array = np.frombuffer(image_byte, dtype=np.uint8)  # bytes -> 1-D array
+    image = cv2.imdecode(array, cv2.IMREAD_COLOR)  # 1-D array -> BGR image
+    return image
+
+
+@app.route('/test/', methods=['POST'])  # flask_app runs once per request
+def flask_app():
+    request_json = flask.request.get_data()
+    request_dict = json.loads(request_json)
+    image = image_decode(request_dict['image'])
+    # ...
+    result = {'shape': image.shape}  # example result; a bare tuple is not a valid flask response
+    return flask.jsonify(result)
+
+
+if __name__ == '__main__':
+    print('| starting the service with flask |')
+    app.run(host='0.0.0.0', port=9999, debug=False)  # start the service

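The `# ...` placeholders above are where model loading and inference would go. A hypothetical sketch reusing the `deploy` wrapper (names, paths, and sizes are illustrative, not part of the commit):

```python
# Hypothetical glue: run the detector inside flask_app.
# Assumes best.pt exists and the incoming image is already input_size x input_size.
import cv2
import torch
from model.layer import deploy

model = deploy(torch.load('best.pt', map_location='cpu')['model'], 640).float().eval()

def detect(image):  # image: BGR numpy array from image_decode
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # the pipeline trains on RGB
    with torch.no_grad():
        tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # (1, 640, 640, 3)
        pred = model(tensor)  # decoded (Cx,Cy,w,h,confidence,...) per output layer
    return [list(p.shape) for p in pred]
```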
+ 24 - 0
gunicorn_config.py

@@ -0,0 +1,24 @@
+# pip install gunicorn -i https://pypi.tuna.tsinghua.edu.cn/simple
+# start flask with gunicorn: gunicorn -c gunicorn_config.py flask_start:app
+# bind address; external access changes from http://host:port/name/ to http://bind/name/
+bind = '0.0.0.0:9999'
+# number of worker processes; cores * 2 + 1 is recommended for best performance
+workers = 3
+# maximum number of client connections, default 1000
+worker_connections = 2000
+# worker model: sync (default), eventlet (coroutine async), gevent (coroutine async), tornado, gthread (threaded).
+# sync handles requests first-come-first-served. eventlet requires: pip install eventlet. gevent requires: pip install gevent.
+# tornado requires: pip install tornado. gthread requires the threads parameter
+worker_class = 'sync'
+# number of threads; setting threads switches the worker mode to gthread automatically
+threads = 1
+# worker timeout in seconds
+timeout = 60
+# auto-restart when code changes; for development, default False
+reload = True
+# access log path; the gunicorn_log folder must be created beforehand
+accesslog = 'gunicorn_log/access.log'
+# error log path; the gunicorn_log folder must be created beforehand
+errorlog = 'gunicorn_log/error.log'
+# log level: debug, info (default), warning, error, critical, ordered by verbosity
+loglevel = 'info'

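gunicorn does not create `gunicorn_log/` on its own. Since it executes this config as ordinary Python, one option (not part of the commit) is to create the directory at the top of the file:

```python
# Optional addition at the top of gunicorn_config.py: ensure the log directory exists.
import os

os.makedirs('gunicorn_log', exist_ok=True)
```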
+ 344 - 0
model/layer.py

@@ -0,0 +1,344 @@
+import torch
+
+
+class cbs(torch.nn.Module):
+    def __init__(self, in_, out_, kernel_size, stride):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_, out_, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
+                                    bias=False)
+        self.bn = torch.nn.BatchNorm2d(out_, eps=0.001, momentum=0.03)
+        self.silu = torch.nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.silu(x)
+        return x
+
+
+class concat(torch.nn.Module):
+    def __init__(self, dim=1):
+        super().__init__()
+        self.concat = torch.concat
+        self.dim = dim
+
+    def forward(self, x):
+        x = self.concat(x, dim=self.dim)
+        return x
+
+
+class residual(torch.nn.Module):  # in_->in_,len->len
+    def __init__(self, in_, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, in_, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, in_, kernel_size=3, stride=1)
+        else:  # pruned version, len(config) = 2
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(config[0], config[1], kernel_size=3, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x0 = self.cbs1(x0)
+        return x + x0
+
+
+class c3(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, n, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.sequential1 = torch.nn.Sequential(*(residual(in_ // 2) for _ in range(n)))
+            self.cbs2 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.concat3 = concat(dim=1)
+            self.cbs4 = cbs(in_, out_, kernel_size=1, stride=1)
+        else:  # pruned version, len(config) = 3 + 2 * n
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.sequential1 = torch.nn.Sequential(
+                *(residual(config[0 + 2 * _] if _ == 0 else config[1 + 2 * _] + config[2 + 2 * _],
+                           config[1 + 2 * _:3 + 2 * _]) for _ in range(n)))
+            self.cbs2 = cbs(config[0], config[1 + 2 * n], kernel_size=1, stride=1)
+            self.concat3 = concat(dim=1)
+            self.cbs4 = cbs(config[0] + config[2 * n - 1] + config[2 * n] + config[1 + 2 * n], config[2 + 2 * n],
+                            kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.sequential1(x0)
+        x1 = x0 + x1
+        x2 = self.cbs2(x)
+        x = self.concat3([x1, x2])
+        x = self.cbs4(x)
+        return x
+
+
+class elan(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, n, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
+            self.sequential2 = torch.nn.Sequential(
+                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
+            self.sequential3 = torch.nn.Sequential(
+                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
+            self.concat4 = concat()
+            self.cbs5 = cbs(out_, out_, kernel_size=1, stride=1)
+        else:  # pruned version, len(config) = 3 + 2 * n
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
+            self.sequential2 = torch.nn.Sequential(
+                *(cbs(config[1 + _], config[2 + _], kernel_size=3, stride=1) for _ in range(n)))
+            self.sequential3 = torch.nn.Sequential(
+                *(cbs(config[1 + n + _], config[2 + n + _], kernel_size=3, stride=1) for _ in range(n)))
+            self.concat4 = concat()
+            self.cbs5 = cbs(config[0] + config[1] + config[1 + n] + config[1 + 2 * n], config[2 + 2 * n],
+                            kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.cbs1(x)
+        x2 = self.sequential2(x1)
+        x3 = self.sequential3(x2)
+        x = self.concat4([x0, x1, x2, x3])
+        x = self.cbs5(x)
+        return x
+
+
+class elan_h(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs2 = cbs(in_ // 2, in_ // 4, kernel_size=3, stride=1)
+            self.cbs3 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
+            self.cbs4 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
+            self.cbs5 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
+            self.concat6 = concat()
+            self.cbs7 = cbs(2 * in_, out_, kernel_size=1, stride=1)
+        else:  # pruned version, len(config) = 7
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
+            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
+            self.cbs3 = cbs(config[2], config[3], kernel_size=3, stride=1)
+            self.cbs4 = cbs(config[3], config[4], kernel_size=3, stride=1)
+            self.cbs5 = cbs(config[4], config[5], kernel_size=3, stride=1)
+            self.concat6 = concat()
+            self.cbs7 = cbs(config[0] + config[1] + config[2] + config[3] + config[4] + config[5], config[6],
+                            kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.cbs1(x)
+        x2 = self.cbs2(x1)
+        x3 = self.cbs3(x2)
+        x4 = self.cbs4(x3)
+        x5 = self.cbs5(x4)
+        x = self.concat6([x0, x1, x2, x3, x4, x5])
+        x = self.cbs7(x)
+        return x
+
+
+class mp(torch.nn.Module):  # in_->out_,len->len//2
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
+            self.cbs1 = cbs(in_, out_ // 2, 1, 1)
+            self.cbs2 = cbs(in_, out_ // 2, 1, 1)
+            self.cbs3 = cbs(out_ // 2, out_ // 2, 3, 2)
+            self.concat4 = concat(dim=1)
+        else:  # pruned version, len(config) = 3
+            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
+            self.cbs1 = cbs(in_, config[0], 1, 1)
+            self.cbs2 = cbs(in_, config[1], 1, 1)
+            self.cbs3 = cbs(config[1], config[2], 3, 2)
+            self.concat4 = concat(dim=1)
+
+    def forward(self, x):
+        x0 = self.maxpool0(x)
+        x0 = self.cbs1(x0)
+        x1 = self.cbs2(x)
+        x1 = self.cbs3(x1)
+        x = self.concat4([x0, x1])
+        return x
+
+
+class sppf(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.MaxPool2d1 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d2 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d3 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat4 = concat(dim=1)
+            self.cbs5 = cbs(2 * in_, out_, kernel_size=1, stride=1)
+        else:  # pruned version, len(config) = 2
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.MaxPool2d1 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d2 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d3 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat4 = concat(dim=1)
+            self.cbs5 = cbs(4 * config[0], config[1], kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x = self.cbs0(x)
+        x0 = self.MaxPool2d1(x)
+        x1 = self.MaxPool2d2(x0)
+        x2 = self.MaxPool2d3(x1)
+        x = self.concat4([x, x0, x1, x2])
+        x = self.cbs5(x)
+        return x
+
+
+class sppcspc(torch.nn.Module):  # in_->out_,len->len
+    def __init__(self, in_, out_, config=None):
+        super().__init__()
+        if not config:  # normal version
+            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs2 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
+            self.cbs3 = cbs(in_ // 2, in_ // 2, kernel_size=1, stride=1)
+            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat7 = concat(dim=1)
+            self.cbs8 = cbs(2 * in_, in_ // 2, kernel_size=1, stride=1)
+            self.cbs9 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
+            self.concat10 = concat(dim=1)
+            self.cbs11 = cbs(in_, out_, kernel_size=1, stride=1)
+        else:  # pruned version, len(config) = 7
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
+            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
+            self.cbs3 = cbs(config[2], config[3], kernel_size=1, stride=1)
+            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
+            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
+            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
+            self.concat7 = concat(dim=1)
+            self.cbs8 = cbs(4 * config[3], config[4], kernel_size=1, stride=1)
+            self.cbs9 = cbs(config[4], config[5], kernel_size=3, stride=1)
+            self.concat10 = concat(dim=1)
+            self.cbs11 = cbs(config[0] + config[5], config[6], kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x0 = self.cbs0(x)
+        x1 = self.cbs1(x)
+        x1 = self.cbs2(x1)
+        x1 = self.cbs3(x1)
+        x4 = self.MaxPool2d4(x1)
+        x5 = self.MaxPool2d5(x1)
+        x6 = self.MaxPool2d6(x1)
+        x = self.concat7([x1, x4, x5, x6])
+        x = self.cbs8(x)
+        x = self.cbs9(x)
+        x = self.concat10([x, x0])
+        x = self.cbs11(x)
+        return x
+
+
+class head(torch.nn.Module):  # in_->(batch, 3, output_size, output_size, 5+output_class),len->len
+    def __init__(self, in_, output_size, output_class):
+        super().__init__()
+        self.output_size = output_size
+        self.output_class = output_class
+        self.output = torch.nn.Conv2d(in_, 3 * (5 + output_class), kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x = self.output(x).reshape(-1, 3, self.output_size, self.output_size, 5 + self.output_class)  # reshape
+        return x
+
+
+# based on yolox
+class split_head(torch.nn.Module):  # in_->(batch, 3, output_size, output_size, 5+output_class),len->len
+    def __init__(self, in_, output_size, output_class, config=None):
+        super().__init__()
+        self.output_size = output_size
+        self.output_class = output_class
+        if not config:  # normal version
+            out_ = 3 * (5 + self.output_class)
+            self.cbs0 = cbs(in_, out_, kernel_size=1, stride=1)
+            self.cbs1 = cbs(out_, out_, kernel_size=3, stride=1)
+            self.cbs2 = cbs(out_, out_, kernel_size=3, stride=1)
+            self.cbs3 = cbs(out_, out_, kernel_size=3, stride=1)
+            self.cbs4 = cbs(out_, out_, kernel_size=3, stride=1)
+            self.Conv2d5 = torch.nn.Conv2d(out_, 12, kernel_size=1, stride=1, padding=0)
+            self.Conv2d6 = torch.nn.Conv2d(out_, 3, kernel_size=1, stride=1, padding=0)
+            self.Conv2d7 = torch.nn.Conv2d(out_, 3 * self.output_class, kernel_size=1, stride=1, padding=0)
+            self.concat8 = concat(4)
+        else:  # pruned version, len(config) = 8
+            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
+            self.cbs1 = cbs(config[0], config[1], kernel_size=1, stride=1)
+            self.cbs2 = cbs(config[1], config[2], kernel_size=1, stride=1)
+            self.cbs3 = cbs(config[0], config[3], kernel_size=1, stride=1)
+            self.cbs4 = cbs(config[3], config[4], kernel_size=1, stride=1)
+            self.Conv2d5 = torch.nn.Conv2d(config[5], 12, kernel_size=1, stride=1, padding=0)
+            self.Conv2d6 = torch.nn.Conv2d(config[6], 3, kernel_size=1, stride=1, padding=0)
+            self.Conv2d7 = torch.nn.Conv2d(config[7], 3 * self.output_class, kernel_size=1, stride=1, padding=0)
+            self.concat8 = concat(4)
+
+    def forward(self, x):
+        x = self.cbs0(x)
+        x0 = self.cbs1(x)
+        x0 = self.cbs2(x0)
+        x1 = self.cbs3(x)
+        x1 = self.cbs4(x1)
+        x2 = self.Conv2d5(x0).reshape(-1, 3, self.output_size, self.output_size, 4)  # reshape
+        x3 = self.Conv2d6(x0).reshape(-1, 3, self.output_size, self.output_size, 1)  # reshape
+        x4 = self.Conv2d7(x1).reshape(-1, 3, self.output_size, self.output_size, self.output_class)  # reshape
+        x = self.concat8([x2, x3, x4])
+        return x
+
+
+class image_deal(torch.nn.Module):  # normalization
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x / 255
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
+class decode(torch.nn.Module):  # raw (Cx,Cy,w,h,confidence...) -> real coords (Cx,Cy,w,h,confidence...)
+    def __init__(self, input_size):
+        super().__init__()
+        self.stride = (8, 16, 32)
+        output_size = [int(input_size // i) for i in self.stride]
+        self.anchor = (((12, 16), (19, 36), (40, 28)), ((36, 75), (76, 55), (72, 146)),
+                       ((142, 110), (192, 243), (459, 401)))
+        self.grid = [0, 0, 0]
+        for i in range(3):
+            self.grid[i] = torch.arange(output_size[i])
+        self.frame_sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, output):
+        device = output[0].device
+        # iterate over each big layer
+        for i in range(3):
+            self.grid[i] = self.grid[i].to(device)  # move to the device
+            # center coords: sigmoid [0,1] -> [-0.5,1.5] around the cell, then scaled by the stride
+            output[i] = self.frame_sigmoid(output[i])  # sigmoid over the raw outputs
+            output[i][..., 0] = (2 * output[i][..., 0] - 0.5 + self.grid[i].unsqueeze(1)) * self.stride[i]
+            output[i][..., 1] = (2 * output[i][..., 1] - 0.5 + self.grid[i]) * self.stride[i]
+            # iterate over the small layers within each big layer
+            for j in range(3):
+                output[i][:, j, ..., 2] = 4 * output[i][:, j, ..., 2] ** 2 * self.anchor[i][j][0]  # [0-1]->[0-4*anchor]
+                output[i][:, j, ..., 3] = 4 * output[i][:, j, ..., 3] ** 2 * self.anchor[i][j][1]  # [0-1]->[0-4*anchor]
+        return output
+
+
+class deploy(torch.nn.Module):
+    def __init__(self, model, input_size):
+        super().__init__()
+        self.image_deal = image_deal()
+        self.model = model
+        self.decode = decode(input_size)
+
+    def forward(self, x):
+        x = self.image_deal(x)
+        x = self.model(x)
+        x = self.decode(x)
+        return x

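A quick shape check for `decode` (a sketch; the random tensors stand in for the three head outputs):

```python
# Sketch: decode() maps raw head outputs to real-coordinate boxes in place.
import torch
from model.layer import decode

input_size, output_class = 640, 80
decoder = decode(input_size)
fake = [torch.rand(2, 3, input_size // s, input_size // s, 5 + output_class) for s in (8, 16, 32)]
real = decoder(fake)
print([o.shape for o in real])  # shapes unchanged; channels 0-3 are now (Cx,Cy,w,h) in pixels
```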
+ 96 - 0
model/yolov5.py

@@ -0,0 +1,96 @@
+# adapted from yolov5: https://github.com/ultralytics/yolov5
+import torch
+from model.layer import cbs, c3, sppf, concat, head
+
+
+class yolov5(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}
+        n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}
+        dim = dim_dict[args.model_type]
+        n = n_dict[args.model_type]
+        input_size = args.input_size
+        stride = (8, 16, 32)
+        self.output_size = [int(input_size // i) for i in stride]  # size of each output layer, e.g. (80, 40, 20)
+        self.output_class = args.output_class
+        # network structure
+        self.l0 = cbs(3, dim, 6, 2)  # 1/2
+        self.l1 = cbs(dim, 2 * dim, 3, 2)  # 1/4
+        # ---------- #
+        self.l2 = c3(2 * dim, 2 * dim, n)
+        self.l3 = cbs(2 * dim, 4 * dim, 3, 2)  # 1/8
+        self.l4 = c3(4 * dim, 4 * dim, 2 * n)
+        self.l5 = cbs(4 * dim, 8 * dim, 3, 2)  # 1/16
+        self.l6 = c3(8 * dim, 8 * dim, 3 * n)
+        self.l7 = cbs(8 * dim, 16 * dim, 3, 2)  # 1/32
+        self.l8 = c3(16 * dim, 16 * dim, n)
+        self.l9 = sppf(16 * dim, 16 * dim)
+        self.l10 = cbs(16 * dim, 8 * dim, 1, 1)
+        # ---------- #
+        self.l11 = torch.nn.Upsample(scale_factor=2)  # 1/16
+        self.l12 = concat(1)
+        self.l13 = c3(16 * dim, 8 * dim, n)
+        self.l14 = cbs(8 * dim, 4 * dim, 1, 1)
+        # ---------- #
+        self.l15 = torch.nn.Upsample(scale_factor=2)  # 1/8
+        self.l16 = concat(1)
+        self.l17 = c3(8 * dim, 4 * dim, n)  # feeds output0
+        # ---------- #
+        self.l18 = cbs(4 * dim, 4 * dim, 3, 2)  # 1/16
+        self.l19 = concat(1)
+        self.l20 = c3(8 * dim, 8 * dim, n)  # feeds output1
+        # ---------- #
+        self.l21 = cbs(8 * dim, 8 * dim, 3, 2)  # 1/32
+        self.l22 = concat(1)
+        self.l23 = c3(16 * dim, 16 * dim, n)  # feeds output2
+        # ---------- #
+        self.output0 = head(4 * dim, self.output_size[0], self.output_class)
+        self.output1 = head(8 * dim, self.output_size[1], self.output_class)
+        self.output2 = head(16 * dim, self.output_size[2], self.output_class)
+
+    def forward(self, x):
+        x = self.l0(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        l4 = self.l4(x)
+        x = self.l5(l4)
+        l6 = self.l6(x)
+        x = self.l7(l6)
+        x = self.l8(x)
+        x = self.l9(x)
+        l10 = self.l10(x)
+        x = self.l11(l10)
+        x = self.l12([x, l6])
+        x = self.l13(x)
+        l14 = self.l14(x)
+        x = self.l15(l14)
+        x = self.l16([x, l4])
+        x = self.l17(x)
+        output0 = self.output0(x)
+        x = self.l18(x)
+        x = self.l19([x, l14])
+        x = self.l20(x)
+        output1 = self.output1(x)
+        x = self.l21(x)
+        x = self.l22([x, l10])
+        x = self.l23(x)
+        output2 = self.output2(x)
+        return [output0, output1, output2]
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--prune', default=False, type=bool)
+    parser.add_argument('--model_type', default='n', type=str)
+    parser.add_argument('--input_size', default=640, type=int)
+    parser.add_argument('--output_class', default=1, type=int)
+    args = parser.parse_args()
+    model = yolov5(args)
+    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
+    pred = model(tensor)
+    print(model)
+    print(pred[0].shape, pred[1].shape, pred[2].shape)

+ 146 - 0
model/yolov7.py

@@ -0,0 +1,146 @@
+# adapted from yolov7: https://github.com/WongKinYiu/yolov7
+import torch
+from model.layer import cbs, elan, elan_h, mp, sppcspc, concat, head
+
+
+class yolov7(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}
+        n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}
+        dim = dim_dict[args.model_type]
+        n = n_dict[args.model_type]
+        input_size = args.input_size
+        stride = (8, 16, 32)
+        self.output_size = [int(input_size // i) for i in stride]  # size of each output layer, e.g. (80, 40, 20)
+        self.output_class = args.output_class
+        # network structure
+        if not args.prune:  # normal version
+            self.l0 = cbs(3, dim, 3, 1)
+            self.l1 = cbs(dim, 2 * dim, 3, 2)  # input_size/2
+            self.l2 = cbs(2 * dim, 2 * dim, 3, 1)
+            self.l3 = cbs(2 * dim, 4 * dim, 3, 2)  # input_size/4
+            # ---------- #
+            self.l4 = elan(4 * dim, 8 * dim, n)
+            self.l5 = mp(8 * dim, 8 * dim)  # input_size/8
+            self.l6 = elan(8 * dim, 16 * dim, n)
+            self.l7 = mp(16 * dim, 16 * dim)  # input_size/16
+            self.l8 = elan(16 * dim, 32 * dim, n)
+            self.l9 = mp(32 * dim, 32 * dim)  # input_size/32
+            self.l10 = elan(32 * dim, 32 * dim, n)
+            self.l11 = sppcspc(32 * dim, 16 * dim)
+            self.l12 = cbs(16 * dim, 8 * dim, 1, 1)
+            # ---------- #
+            self.l13 = torch.nn.Upsample(scale_factor=2)  # input_size/16
+            self.l8_add = cbs(32 * dim, 8 * dim, 1, 1)
+            self.l14 = concat(1)
+            self.l15 = elan_h(16 * dim, 8 * dim)
+            self.l16 = cbs(8 * dim, 4 * dim, 1, 1)
+            # ---------- #
+            self.l17 = torch.nn.Upsample(scale_factor=2)  # input_size/8
+            self.l6_add = cbs(16 * dim, 4 * dim, 1, 1)
+            self.l18 = concat(1)
+            self.l19 = elan_h(8 * dim, 4 * dim)  # feeds output0
+            # ---------- #
+            self.l20 = mp(4 * dim, 8 * dim)
+            self.l21 = concat(1)
+            self.l22 = elan_h(16 * dim, 8 * dim)  # feeds output1
+            # ---------- #
+            self.l23 = mp(8 * dim, 16 * dim)
+            self.l24 = concat(1)
+            self.l25 = elan_h(32 * dim, 16 * dim)  # feeds output2
+            # ---------- #
+            self.output0 = head(4 * dim, self.output_size[0], self.output_class)
+            self.output1 = head(8 * dim, self.output_size[1], self.output_class)
+            self.output2 = head(16 * dim, self.output_size[2], self.output_class)
+        else:  # pruned version
+            config = args.prune_num
+            self.l0 = cbs(3, config[0], 1, 1)
+            self.l1 = cbs(config[0], config[1], 3, 2)  # input_size/2
+            self.l2 = cbs(config[1], config[2], 1, 1)
+            self.l3 = cbs(config[2], config[3], 3, 2)  # input_size/4
+            # ---------- #
+            self.l4 = elan(config[3], None, n, config[4:7 + 2 * n])
+            self.l5 = mp(config[6 + 2 * n], None, config[7 + 2 * n:10 + 2 * n])  # input_size/8
+            self.l6 = elan(config[7 + 2 * n] + config[9 + 2 * n], None, n, config[10 + 2 * n:13 + 4 * n])
+            self.l7 = mp(config[12 + 4 * n], None, config[13 + 4 * n:16 + 4 * n])  # input_size/16
+            self.l8 = elan(config[13 + 4 * n] + config[15 + 4 * n], None, n, config[16 + 4 * n:19 + 6 * n])
+            self.l9 = mp(config[18 + 6 * n], None, config[19 + 6 * n:22 + 6 * n])  # input_size/32
+            self.l10 = elan(config[19 + 6 * n] + config[21 + 6 * n], None, n, config[22 + 6 * n:25 + 8 * n])
+            self.l11 = sppcspc(config[24 + 8 * n], None, config[25 + 8 * n:32 + 8 * n])
+            self.l12 = cbs(config[31 + 8 * n], config[32 + 8 * n], 1, 1)
+            # ---------- #
+            self.l13 = torch.nn.Upsample(scale_factor=2)  # input_size/16
+            self.l8_add = cbs(config[18 + 6 * n], config[33 + 8 * n], 1, 1)
+            self.l14 = concat(1)
+            self.l15 = elan_h(config[32 + 8 * n] + config[33 + 8 * n], None, config[34 + 8 * n:41 + 8 * n])
+            self.l16 = cbs(config[40 + 8 * n], config[41 + 8 * n], 1, 1)
+            # ---------- #
+            self.l17 = torch.nn.Upsample(scale_factor=2)  # input_size/8
+            self.l6_add = cbs(config[12 + 4 * n], config[42 + 8 * n], 1, 1)
+            self.l18 = concat(1)
+            self.l19 = elan_h(config[41 + 8 * n] + config[42 + 8 * n], None, config[43 + 8 * n:50 + 8 * n])  # feeds output0
+            # ---------- #
+            self.l20 = mp(config[49 + 8 * n], None, config[50 + 8 * n:53 + 8 * n])
+            self.l21 = concat(1)
+            self.l22 = elan_h(config[40 + 8 * n] + config[50 + 8 * n] + config[52 + 8 * n], None,
+                              config[53 + 8 * n:60 + 8 * n])  # feeds output1
+            # ---------- #
+            self.l23 = mp(config[59 + 8 * n], None, config[60 + 8 * n:63 + 8 * n])
+            self.l24 = concat(1)
+            self.l25 = elan_h(config[31 + 8 * n] + config[60 + 8 * n] + config[62 + 8 * n], None,
+                              config[63 + 8 * n:70 + 8 * n])  # feeds output2
+            # ---------- #
+            self.output0 = head(config[49 + 8 * n], self.output_size[0], self.output_class)
+            self.output1 = head(config[59 + 8 * n], self.output_size[1], self.output_class)
+            self.output2 = head(config[69 + 8 * n], self.output_size[2], self.output_class)
+
+    def forward(self, x):
+        x = self.l0(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        x = self.l4(x)
+        x = self.l5(x)
+        l6 = self.l6(x)
+        x = self.l7(l6)
+        l8 = self.l8(x)
+        x = self.l9(l8)
+        x = self.l10(x)
+        l11 = self.l11(x)
+        x = self.l12(l11)
+        x = self.l13(x)
+        l8_add = self.l8_add(l8)
+        x = self.l14([x, l8_add])
+        l15 = self.l15(x)
+        x = self.l16(l15)
+        x = self.l17(x)
+        l6_add = self.l6_add(l6)
+        x = self.l18([x, l6_add])
+        x = self.l19(x)
+        output0 = self.output0(x)
+        x = self.l20(x)
+        x = self.l21([x, l15])
+        x = self.l22(x)
+        output1 = self.output1(x)
+        x = self.l23(x)
+        x = self.l24([x, l11])
+        x = self.l25(x)
+        output2 = self.output2(x)
+        return [output0, output1, output2]
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--prune', default=False, type=bool)
+    parser.add_argument('--model_type', default='n', type=str)
+    parser.add_argument('--input_size', default=640, type=int)
+    parser.add_argument('--output_class', default=1, type=int)
+    args = parser.parse_args()
+    model = yolov7(args)
+    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
+    pred = model(tensor)
+    print(model)
+    print(pred[0].shape, pred[1].shape, pred[2].shape)

+ 131 - 0
predict_pt.py

@@ -0,0 +1,131 @@
+import os
+import cv2
+import time
+import torch
+import argparse
+import torchvision
+import numpy as np
+import albumentations
+from model.layer import deploy
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|pt model inference|')
+parser.add_argument('--model_path', default=r'D:\桌面\ObjectDetection-main\last.pt', type=str, help='|path to the pt model|')
+parser.add_argument('--image_path', default=r'D:\桌面\ObjectDetection-main\datasets\coco_wm\images\test2017_wm', type=str, help='|path to the image folder|')
+parser.add_argument('--input_size', default=640, type=int, help='|model input image size|')
+parser.add_argument('--batch', default=1, type=int, help='|inference batch size|')
+parser.add_argument('--confidence_threshold', default=0.35, type=float, help='|confidence threshold (kept if > threshold)|')
+parser.add_argument('--iou_threshold', default=0.65, type=float, help='|IoU threshold for NMS (kept if < threshold)|')
+parser.add_argument('--device', default='cuda', type=str, help='|inference device|')
+parser.add_argument('--num_worker', default=0, type=int, help='|number of data-loading workers; 0 uses only the main process, typically 0/2/4/8|')
+parser.add_argument('--float16', default=False, type=bool, help='|inference dtype; requires a float16-capable GPU, float32 when False|')
+args, _ = parser.parse_known_args()  # avoids conflicts with externally passed args; replaces args = parser.parse_args()
+args.model_path = os.path.splitext(args.model_path)[0] + '.pt'  # normalize the extension without clobbering dotted paths
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path does not exist: {args.model_path} !'
+if args.float16:
+    assert torch.cuda.is_available(), 'cuda is unavailable, cannot use float16'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def confidence_screen(pred, confidence_threshold):
+    result = []
+    for i in range(len(pred)):  # handle each output layer of one image separately
+        judge = torch.where(pred[i][..., 4] > confidence_threshold, True, False)
+        result.append((pred[i][judge]))
+    result = torch.concat(result, dim=0)
+    if result.shape[0] == 0:
+        return result
+    index = torch.argsort(result[:, 4], dim=0, descending=True)
+    result = result[index]
+    return result
+
+
+def iou_single(A, B):  # input: (batch,(x_min,y_min,w,h)) relative/real coords
+    x1 = torch.maximum(A[:, 0], B[0])
+    y1 = torch.maximum(A[:, 1], B[1])
+    x2 = torch.minimum(A[:, 0] + A[:, 2], B[0] + B[2])
+    y2 = torch.minimum(A[:, 1] + A[:, 3], B[1] + B[3])
+    zeros = torch.zeros(1, device=A.device)
+    intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
+    union = A[:, 2] * A[:, 3] + B[2] * B[3] - intersection
+    return intersection / union
+
+
+def nms(pred, iou_threshold):  # input: (batch,(x_min,y_min,w,h)) relative/real coords
+    pred[:, 2:4] = pred[:, 0:2] + pred[:, 2:4]  # real coords (x_min,y_min,x_max,y_max)
+    index = torchvision.ops.nms(pred[:, 0:4], pred[:, 4], 1 - iou_threshold)[:100]  # NMS, keep at most 100
+    pred = pred[index]
+    pred[:, 2:4] = pred[:, 2:4] - pred[:, 0:2]  # real coords (x_min,y_min,w,h)
+    return pred
+
+
+def draw(image, frame, cls, name):  # input (x_min,y_min,w,h) real coords
+    image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2BGR)  # images arrive as RGB; OpenCV writes BGR
+    for i in range(len(frame)):
+        a = (int(frame[i][0]), int(frame[i][1]))
+        b = (int(frame[i][0] + frame[i][2]), int(frame[i][1] + frame[i][3]))
+        cv2.rectangle(image, a, b, color=(0, 255, 0), thickness=2)
+        cv2.putText(image, 'class:' + str(cls[i]), a, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+    cv2.imwrite('save_' + name, image)
+    print(f'| {name}: save_{name} |')
+
+
+def predict_pt(args):
+    # load model
+    model_dict = torch.load(args.model_path, map_location='cpu')
+    model = model_dict['model']
+    model = deploy(model, args.input_size)
+    model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
+    epoch = model_dict['epoch_finished']
+    m_ap = round(model_dict['standard'], 4)
+    print(f'| model loaded: {args.model_path} | epoch:{epoch} | m_ap:{m_ap} |')
+    # inference
+    image_dir = sorted(os.listdir(args.image_path))
+    start_time = time.time()
+    with torch.no_grad():
+        dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch, shuffle=False,
+                                                 drop_last=False, pin_memory=False, num_workers=args.num_worker)
+        for item, (image_batch, name_batch) in enumerate(dataloader):
+            image_all = image_batch.cpu().numpy().astype(np.uint8)  # to numpy for drawing
+            image_batch = image_batch.to(args.device)
+            pred_batch = model(image_batch)
+            # handle each image in the batch separately
+            for i in range(pred_batch[0].shape[0]):
+                pred = [_[i] for _ in pred_batch]  # (Cx,Cy,w,h)
+                pred = confidence_screen(pred, args.confidence_threshold)  # confidence filtering
+                if pred.shape[0] == 0:
+                    print(f'{name_batch[i]}:None')
+                    continue
+                pred[:, 0:2] = pred[:, 0:2] - pred[:, 2:4] / 2  # real coords (x_min,y_min,w,h)
+                pred = nms(pred, args.iou_threshold)  # non-maximum suppression
+                frame = pred[:, 0:4]  # boxes
+                cls = torch.argmax(pred[:, 5:], dim=1)  # classes
+                draw(image_all[i], frame.cpu().numpy(), cls.cpu().numpy(), name_batch[i])
+    end_time = time.time()
+    print('| images:{} batch:{} time per image:{:.4f} |'.format(len(image_dir), args.batch, (end_time - start_time) / len(image_dir)))
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, image_dir):
+        self.image_dir = image_dir
+        self.transform = albumentations.Compose([
+            albumentations.LongestMaxSize(args.input_size),
+            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+
+    def __len__(self):
+        return len(self.image_dir)
+
+    def __getitem__(self, index):
+        image = cv2.imread(args.image_path + '/' + self.image_dir[index])  # read image
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # convert to RGB
+        image = self.transform(image=image)['image']  # resize and pad (normalization and dim changes happen inside the model)
+        image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
+        name = self.image_dir[index]
+        return image, name
+
+
+if __name__ == '__main__':
+    predict_pt(args)

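For reference, the two filtering helpers can be exercised on fake decoded outputs (a sketch, assuming it runs in the same module as the helpers above; shapes follow the (Cx,Cy,w,h,confidence,classes...) layout):

```python
# Sketch: confidence_screen + nms on fake decoded outputs.
import torch

fake_pred = [torch.rand(3, s, s, 85) for s in (80, 40, 20)]  # 5 + 80 classes
keep = confidence_screen(fake_pred, confidence_threshold=0.9)
if keep.shape[0]:
    keep[:, 0:2] = keep[:, 0:2] - keep[:, 2:4] / 2  # (Cx,Cy,w,h) -> (x_min,y_min,w,h)
    keep = nms(keep, iou_threshold=0.65)
print(keep.shape)  # (n_kept, 85)
```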
+ 117 - 0
run.py

@@ -0,0 +1,117 @@
+# Data must be prepared in the following layout (standard YOLO format)
+# ├── dataset path: data_path
+#     └── image: all the images
+#     └── label: one label file per image, named <image_name>.txt, content: (class_id x_center y_center w h\n) as ratios of the image size
+#     └── train.txt: absolute paths of the training images (or paths relative to data_path)
+#     └── val.txt: absolute paths of the validation images (or paths relative to data_path)
+#     └── class.txt: all the class names
+# class.txt content:
+# class_1
+# class_2
+# ...
+# -------------------------------------------------------------------------------------------------------------------- #
+import os
+import wandb
+import torch
+import argparse
+from block.data_get import data_get
+from block.model_get import model_get
+from block.loss_get import loss_get
+from block.train_get import train_get
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# Distributed data-parallel training:
+# python -m torch.distributed.launch --master_port 9999 --nproc_per_node n run.py --distributed True
+# master_port is the communication port between GPUs; any free port works
+# n is the number of GPUs
+# -------------------------------------------------------------------------------------------------------------------- #
+# model load/create priority: load existing model > create pruned model > create custom model
+parser = argparse.ArgumentParser(description='|object detection|')
+parser.add_argument('--wandb', default=False, type=bool, help='|whether to use wandb visualization|')
+parser.add_argument('--wandb_project', default='ObjectDetection', type=str, help='|wandb project name|')
+parser.add_argument('--wandb_name', default='train', type=str, help='|training run name inside the wandb project|')
+parser.add_argument('--wandb_image_num', default=16, type=int, help='|number of images to display in wandb|')
+parser.add_argument('--data_path', default=r'./datasets/coco_wm', type=str, help='|data directory|')
+parser.add_argument('--input_size', default=640, type=int, help='|input image size|')
+parser.add_argument('--output_class', default=80, type=int, help='|number of output classes|')
+parser.add_argument('--weight', default='last.pt', type=str, help='|path to an existing model; a pruned/new model is created when not found|')
+parser.add_argument('--prune', default=False, type=bool, help='|retrain after pruning (supported by some models); requires prune_weight|')
+parser.add_argument('--prune_weight', default='best.pt', type=str, help='|reference model for pruning; a pruned model is created and trained|')
+parser.add_argument('--prune_ratio', default=0.5, type=float, help='|fraction of channels kept when pruning|')
+parser.add_argument('--prune_save', default='prune_best.pt', type=str, help='|path for the best pruned model; prune_last.pt is also saved every epoch|')
+parser.add_argument('--model', default='yolov7', type=str, help='|custom model choice|')
+parser.add_argument('--model_type', default='n', type=str, help='|custom model size|')
+parser.add_argument('--save_path', default='best.pt', type=str, help='|path for the best model; last.pt is also saved every epoch|')
+parser.add_argument('--loss_weight', default=((1 / 3, 0.3, 0.5, 0.2), (1 / 3, 0.4, 0.4, 0.2), (1 / 3, 0.5, 0.3, 0.2)),
+                    type=tuple, help='|weights of each output layer (largest to smallest) -> [total, box, confidence, class]|')
+parser.add_argument('--label_smooth', default=(0.01, 0.99), type=tuple, help='|label smoothing values|')
+parser.add_argument('--epoch', default=10, type=int, help='|训练总轮数(包含之前已训练轮数)|')
+parser.add_argument('--batch', default=2, type=int, help='|训练批量大小,分布式时为总批量|')
+parser.add_argument('--warmup_ratio', default=0.01, type=float, help='|预热训练步数占总步数比例,最少5步,基准为0.01|')
+parser.add_argument('--lr_start', default=0.001, type=float, help='|初始学习率,adam算法,批量小时要减小,基准为0.001|')
+parser.add_argument('--lr_end_ratio', default=0.01, type=float, help='|最终学习率=lr_end_ratio*lr_start,基准为0.01|')
+parser.add_argument('--lr_end_epoch', default=300, type=int, help='|最终学习率达到的轮数,每一步都调整,余玄下降法|')
+parser.add_argument('--regularization', default='L2', type=str, help='|正则化,有L2、None|')
+parser.add_argument('--r_value', default=0.0005, type=float, help='|正则化权重系数,基准为0.0005|')
+parser.add_argument('--device', default='cuda', type=str, help='|训练设备|')
+parser.add_argument('--latch', default=True, type=bool, help='|模型和数据是否为锁存,True为锁存|')
+parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
+parser.add_argument('--ema', default=True, type=bool, help='|使用平均指数移动(EMA)调整参数|')
+parser.add_argument('--amp', default=True, type=bool, help='|混合float16精度训练,CPU时不可用,出现nan可能与GPU有关|')
+parser.add_argument('--mosaic', default=0.5, type=float, help='|使用mosaic增强的概率|')
+parser.add_argument('--mosaic_hsv', default=0.5, type=float, help='|mosaic增强时的hsv通道随机变换概率|')
+parser.add_argument('--mosaic_flip', default=0.5, type=float, help='|mosaic增强时的垂直翻转概率|')
+parser.add_argument('--mosaic_screen', default=10, type=int, help='|mosaic增强后留下的框w,h不能小于mosaic_screen|')
+parser.add_argument('--confidence_threshold', default=0.35, type=float, help='|指标计算置信度阈值|')
+parser.add_argument('--iou_threshold', default=0.5, type=float, help='|指标计算iou阈值|')
+parser.add_argument('--distributed', default=False, type=bool, help='|单机多卡分布式训练,分布式训练时batch为总batch|')
+parser.add_argument('--local_rank', default=0, type=int, help='|分布式训练使用命令后会自动传入的参数|')
+args = parser.parse_args()
+args.device_number = max(torch.cuda.device_count(), 1)  # number of devices in use; at least 1 so the CPU case also works
+# fix the random seed on the CPU
+torch.manual_seed(999)
+# fix the random seed on all GPUs
+torch.cuda.manual_seed_all(999)
+# make cuDNN pick deterministic convolution algorithms
+torch.backends.cudnn.deterministic = True
+# enable cuDNN
+torch.backends.cudnn.enabled = True
+# when True, cuDNN benchmarks the fastest convolution algorithm per layer before training, which speeds up fixed-size inputs;
+# for inputs whose shapes vary the search can take long, so for fast-training networks it is better set to False
+torch.backends.cudnn.benchmark = False
+# wandb visualization: https://wandb.ai
+if args.wandb and args.local_rank == 0:  # when distributed, log to wandb only once
+    args.wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
+# mixed float16 precision training
+if args.amp:
+    args.amp = torch.cuda.amp.GradScaler()
+# distributed training
+if args.distributed:
+    torch.distributed.init_process_group(backend='nccl')  # initialize distributed training
+    args.device = torch.device("cuda", args.local_rank)
+# -------------------------------------------------------------------------------------------------------------------- #
+if args.local_rank == 0:
+    print(f'| args:{args} |')
+    assert os.path.exists(f'{args.data_path}/images'), '! data_path is missing: images !'
+    assert os.path.exists(f'{args.data_path}/labels'), '! data_path is missing: labels !'
+    assert os.path.exists(f'{args.data_path}/train.txt'), '! data_path is missing: train.txt !'
+    assert os.path.exists(f'{args.data_path}/val.txt'), '! data_path is missing: val.txt !'
+    assert os.path.exists(f'{args.data_path}/class.txt'), '! data_path is missing: class.txt !'
+    if os.path.exists(args.weight):  # prefer loading the existing model args.weight and continuing training
+        print(f'| loading existing model: {args.weight} |')
+    elif args.prune:
+        print(f'| loading model + prune training: {args.prune_weight} |')
+    else:  # create the custom model args.model
+        assert os.path.exists(f'model/{args.model}.py'), f'! custom model not found: {args.model} !'
+        print(f'| creating custom model: {args.model} | variant: {args.model_type} |')
+# -------------------------------------------------------------------------------------------------------------------- #
+if __name__ == '__main__':
+    # summary
+    print(f'| args:{args} |') if args.local_rank == 0 else None
+    # data
+    data_dict = data_get(args)
+    # model
+    model_dict = model_get(args)
+    # loss
+    loss = loss_get(args)
+    # train
+    train_get(args, data_dict, model_dict, loss)
+    train_get(args, data_dict, model_dict, loss)

+ 465 - 0
test.ipynb

@@ -0,0 +1,465 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "已处理2张图片并修改了对应的 bounding box 文件。\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import random\n",
+    "from PIL import Image, ImageDraw\n",
+    "\n",
+    "def modify_images_and_labels(train_txt_path, percentage=1):\n",
+    "    # 读取图片绝对路径\n",
+    "    with open(train_txt_path, 'r') as file:\n",
+    "        image_paths = file.readlines()\n",
+    "    \n",
+    "    # 根据percentage确定需要处理的图片数量\n",
+    "    num_images = len(image_paths)\n",
+    "    num_images_to_modify = int(num_images * percentage / 100)\n",
+    "    images_to_modify = random.sample(image_paths, num_images_to_modify)\n",
+    "    \n",
+    "    # 对每张图片进行处理\n",
+    "    for img_path in images_to_modify:\n",
+    "        img_path = img_path.strip()\n",
+    "        img = Image.open(img_path)\n",
+    "        \n",
+    "        # 在任意位置添加5~10个 5x5 大小的噪声 patch\n",
+    "        num_patches = random.randint(5, 10)\n",
+    "        for _ in range(num_patches):\n",
+    "            patch_size = 5\n",
+    "            patch_x = random.randint(0, img.width - patch_size)\n",
+    "            patch_y = random.randint(0, img.height - patch_size)\n",
+    "            for x in range(patch_x, patch_x + patch_size):\n",
+    "                for y in range(patch_y, patch_y + patch_size):\n",
+    "                    img.putpixel((x, y), (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))\n",
+    "        \n",
+    "        # 读取对应的 bounding box 文件路径\n",
+    "        label_path = img_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "        if not os.path.exists(label_path):\n",
+    "            continue\n",
+    "        \n",
+    "        # 读取 bounding box 文件并进行修改\n",
+    "        with open(label_path, 'r') as label_file:\n",
+    "            lines = label_file.readlines()\n",
+    "        modified_lines = []\n",
+    "        for line in lines:\n",
+    "            # 解析原始 bounding box 信息\n",
+    "            label, cx, cy, w, h = map(float, line.split())\n",
+    "            \n",
+    "            # 添加三个新的 bounding box,类别为0,大小与原始 bounding box 相同,但位置随机\n",
+    "            for _ in range(3):\n",
+    "                new_cx = random.uniform(cx - w/2, cx + w/2)\n",
+    "                new_cy = random.uniform(cy - h/2, cy + h/2)\n",
+    "                modified_lines.append(f'0 {new_cx} {new_cy} {w} {h}\\n')\n",
+    "        \n",
+    "        # 将修改后的 bounding box 写回原始文件\n",
+    "        with open(label_path, 'w') as label_file:\n",
+    "            label_file.writelines(modified_lines)\n",
+    "    \n",
+    "    print(f\"已处理{num_images_to_modify}张图片并修改了对应的 bounding box 文件。\")\n",
+    "\n",
+    "# 使用示例\n",
+    "train_txt_path = '/home/yhsun/ObjectDetection-main/datasets/test_for_noise/test.txt'  # 替换为实际的 train.txt 文件路径\n",
+    "modify_images_and_labels(train_txt_path, percentage=100)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000009.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000025.jpg\n",
+      "已修改2张图片并更新了bounding box。\n"
+     ]
+    }
+   ],
+   "source": [
+    "from PIL import Image, ImageDraw\n",
+    "import os\n",
+    "import random\n",
+    "\n",
+    "def modify_images_and_labels(train_txt_path, percentage=1):\n",
+    "    \"\"\"\n",
+    "    来重新定义这个功能:1.train_txt_path 是包含了待处理图片的绝对路径\n",
+    "                     2.percentage 是约束需要处理多少比例的图片\n",
+    "                     3.需要对读取的图片进行尺寸读取,读取之后,在图片的任意位置添加5~10个 5x5大小的noise patch\n",
+    "                     4.读取该图片的bounding box文件 label_path = img_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "                     5.bounding box 文件中包含的信息是类别号,Cx,Cy,w,h。每行记录一个目标的分类名称,以及对应的bounding box信息\n",
+    "                     6.需要修改该bounding box文件。根据图片上noise patch的相对位置,每个noise patch添加3个新的bounding box,\n",
+    "                       类别为0,bounding box的w,h为0.955609和0.5955,但是Cx,Cy依据noise patch的位置而定\n",
+    "                     7.修改好的bounding box文件覆盖原始文件即可\n",
+    "                     8.处理好之后报告有图片处理完毕\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # 读取图片绝对路径\n",
+    "    with open(train_txt_path, 'r') as file:\n",
+    "        lines = file.readlines()\n",
+    "\n",
+    "    # 随机选择一定比例的图片\n",
+    "    num_images = len(lines)\n",
+    "    num_samples = int(num_images * (percentage / 100))\n",
+    "\n",
+    "    selected_lines = random.sample(lines, num_samples)\n",
+    "\n",
+    "    for line in selected_lines:\n",
+    "        # 解析每一行,获取图片路径\n",
+    "        image_path = line.strip().split()[0]\n",
+    "\n",
+    "        # 打开图片并添加噪声\n",
+    "        img = Image.open(image_path)\n",
+    "        print(image_path)\n",
+    "        draw = ImageDraw.Draw(img)\n",
+    "\n",
+    "        # 在图片的任意位置添加5~10个 5x5 大小的噪声色块\n",
+    "        num_noise_patches = random.randint(5, 10)\n",
+    "        for _ in range(num_noise_patches):\n",
+    "            x = random.randint(0, img.width - 5)\n",
+    "            y = random.randint(0, img.height - 5)\n",
+    "            draw.rectangle([x, y, x + 5, y + 5], fill=(128, 0, 128))\n",
+    "\n",
+    "        # 保存修改后的图片\n",
+    "        img.save(image_path)\n",
+    "\n",
+    "        # 读取相应的bounding box文件路径\n",
+    "        label_path = image_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "\n",
+    "        # 读取bounding box信息并修改\n",
+    "        with open(label_path, 'a') as label_file:\n",
+    "            for _ in range(3):\n",
+    "                # 随机生成bounding box的Cx, Cy\n",
+    "                cx = random.uniform(0, 1)\n",
+    "                cy = random.uniform(0, 1)\n",
+    "                label_file.write(f\"0 {cx} {cy} 0.955609 0.5955\\n\")\n",
+    "\n",
+    "    print(f\"已修改{len(selected_lines)}张图片并更新了bounding box。\")\n",
+    "\n",
+    "\n",
+    "# 使用示例\n",
+    "train_txt_path = '/home/yhsun/ObjectDetection-main/datasets/test_for_noise/test.txt'  # 替换为实际的 train.txt 文件路径\n",
+    "modify_images_and_labels(train_txt_path, percentage=100)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000009.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000036.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000042.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000034.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000025.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000030.jpg\n",
+      "已修改6张图片并更新了bounding box。\n"
+     ]
+    }
+   ],
+   "source": [
+    "from PIL import Image, ImageDraw\n",
+    "import os\n",
+    "import random\n",
+    "\n",
+    "def modify_images_and_labels(train_txt_path, percentage=1):\n",
+    "    \"\"\"\n",
+    "    来重新定义这个功能:1.train_txt_path 是包含了待处理图片的绝对路径\n",
+    "                     2.percentage 是约束需要处理多少比例的图片\n",
+    "                     3.需要对读取的图片进行尺寸读取,读取之后,在图片的任意位置添加5~10个 5x5大小的noise patch\n",
+    "                     4.读取该图片的bounding box文件 label_path = img_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "                     5.bounding box 文件中包含的信息是类别号,Cx,Cy,w,h。每行记录一个目标的分类名称,以及对应的bounding box信息\n",
+    "                     6.需要修改该bounding box文件。根据图片上noise patch的相对位置,每个noise patch添加3个新的bounding box,\n",
+    "                       类别为0,bounding box的w,h为0.955609和0.5955,但是Cx,Cy依据noise patch的位置而定\n",
+    "                     7.修改好的bounding box文件覆盖原始文件即可\n",
+    "                     8.处理好之后报告有图片处理完毕\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # 读取图片绝对路径\n",
+    "    with open(train_txt_path, 'r') as file:\n",
+    "        lines = file.readlines()\n",
+    "\n",
+    "    # 随机选择一定比例的图片\n",
+    "    num_images = len(lines)\n",
+    "    num_samples = int(num_images * (percentage / 100))\n",
+    "\n",
+    "    selected_lines = random.sample(lines, num_samples)\n",
+    "\n",
+    "    for line in selected_lines:\n",
+    "        # 解析每一行,获取图片路径\n",
+    "        image_path = line.strip().split()[0]\n",
+    "\n",
+    "        # 打开图片并添加噪声\n",
+    "        img = Image.open(image_path)\n",
+    "        print(image_path)\n",
+    "        draw = ImageDraw.Draw(img)\n",
+    "\n",
+    "        # 在图片的任意位置添加5~10个 5x5 大小的噪声色块\n",
+    "        num_noise_patches = random.randint(5, 10)\n",
+    "        for _ in range(num_noise_patches):\n",
+    "            x = random.randint(0, img.width - 5)\n",
+    "            y = random.randint(0, img.height - 5)\n",
+    "            draw.rectangle([x, y, x + 5, y + 5], fill=(128, 0, 128))\n",
+    "\n",
+    "            # 读取相应的bounding box文件路径\n",
+    "            label_path = image_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "\n",
+    "            # 读取bounding box信息并修改\n",
+    "            with open(label_path, 'a') as label_file:\n",
+    "                # 计算bounding box的Cx, Cy\n",
+    "                cx = (x + 2.5) / img.width\n",
+    "                cy = (y + 2.5) / img.height\n",
+    "                label_file.write(f\"0 {cx} {cy} 0.955609 0.5955\\n\")\n",
+    "\n",
+    "        # 保存修改后的图片\n",
+    "        img.save(image_path)\n",
+    "\n",
+    "    print(f\"已修改{len(selected_lines)}张图片并更新了bounding box。\")\n",
+    "\n",
+    "# 使用示例\n",
+    "train_txt_path = '/home/yhsun/ObjectDetection-main/datasets/test_for_noise/test.txt'  # 替换为实际的 train.txt 文件路径\n",
+    "modify_images_and_labels(train_txt_path, percentage=100)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000036.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000025.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000030.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000034.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000042.jpg\n",
+      "/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017/000000000009.jpg\n",
+      "已修改6张图片并更新了 bounding box。\n"
+     ]
+    }
+   ],
+   "source": [
+    "# version 3\n",
+    "from PIL import Image, ImageDraw\n",
+    "import os\n",
+    "import random\n",
+    "\n",
+    "def modify_images_and_labels(train_txt_path, percentage=1, min_num_patches=5, max_num_patches=10):\n",
+    "    \"\"\"\n",
+    "    重新定义功能:\n",
+    "    1. train_txt_path 是包含了待处理图片的绝对路径\n",
+    "    2. percentage 是约束需要处理多少比例的图片\n",
+    "    3. 每张图插入 noise patch 的数量应该在 5~10 之间\n",
+    "    4. noise patch 的大小为 10x10\n",
+    "    5. 修改的 bounding box 大小也要随机\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # 读取图片绝对路径\n",
+    "    with open(train_txt_path, 'r') as file:\n",
+    "        lines = file.readlines()\n",
+    "\n",
+    "    # 随机选择一定比例的图片\n",
+    "    num_images = len(lines)\n",
+    "    num_samples = int(num_images * (percentage / 100))\n",
+    "\n",
+    "    selected_lines = random.sample(lines, num_samples)\n",
+    "\n",
+    "    for line in selected_lines:\n",
+    "        # 解析每一行,获取图片路径\n",
+    "        image_path = line.strip().split()[0]\n",
+    "\n",
+    "        # 打开图片并添加噪声\n",
+    "        img = Image.open(image_path)\n",
+    "        print(image_path)\n",
+    "        draw = ImageDraw.Draw(img)\n",
+    "\n",
+    "        # 在图片的任意位置添加随机数量和大小的噪声块\n",
+    "        num_noise_patches = random.randint(min_num_patches, max_num_patches)\n",
+    "        for _ in range(num_noise_patches):\n",
+    "            # 添加 10x10 大小的噪声块\n",
+    "            patch_size = 10\n",
+    "            x = random.randint(0, img.width - patch_size)\n",
+    "            y = random.randint(0, img.height - patch_size)\n",
+    "            draw.rectangle([x, y, x + patch_size, y + patch_size], fill=(128, 0, 128))\n",
+    "\n",
+    "            # 读取相应的 bounding box 文件路径\n",
+    "            label_path = image_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "\n",
+    "            # 读取 bounding box 信息并修改\n",
+    "            with open(label_path, 'a') as label_file:\n",
+    "                # 随机生成 bounding box 大小\n",
+    "                box_width = random.uniform(0.5, 1)\n",
+    "                box_height = random.uniform(0.5, 1)\n",
+    "                # 计算 bounding box 的中心点坐标\n",
+    "                cx = (x + patch_size / 2) / img.width\n",
+    "                cy = (y + patch_size / 2) / img.height\n",
+    "                label_file.write(f\"0 {cx} {cy} {box_width} {box_height}\\n\")\n",
+    "\n",
+    "        # 保存修改后的图片\n",
+    "        img.save(image_path)\n",
+    "\n",
+    "    print(f\"已修改{len(selected_lines)}张图片并更新了 bounding box。\")\n",
+    "\n",
+    "# 使用示例\n",
+    "train_txt_path = '/home/yhsun/ObjectDetection-main/datasets/test_for_noise/test.txt'  # 替换为实际的 train.txt 文件路径\n",
+    "modify_images_and_labels(train_txt_path, percentage=100)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from PIL import Image, ImageDraw\n",
+    "import os\n",
+    "\n",
+    "def visualize_bounding_boxes(train_txt_path, output_dir):\n",
+    "    \"\"\"\n",
+    "    读取图片绝对路径和对应的bounding box信息,在图片上绘制bounding box\n",
+    "    \"\"\"\n",
+    "    # 读取图片绝对路径和对应的bounding box文件\n",
+    "    with open(train_txt_path, 'r') as file:\n",
+    "        lines = file.readlines()\n",
+    "\n",
+    "    for line in lines:\n",
+    "        # 解析每一行,获取图片路径\n",
+    "        image_path = line.strip().split()[0]\n",
+    "\n",
+    "        # 打开图片并准备绘制\n",
+    "        img = Image.open(image_path)\n",
+    "        draw = ImageDraw.Draw(img)\n",
+    "\n",
+    "        # 读取相应的bounding box文件路径\n",
+    "        label_path = image_path.replace('images', 'labels').replace('.jpg', '.txt')\n",
+    "\n",
+    "        # 读取bounding box信息并绘制\n",
+    "        with open(label_path, 'r') as label_file:\n",
+    "            for bbox in label_file:\n",
+    "                label_info = bbox.strip().split()\n",
+    "                cx, cy, w, h = map(float, label_info[1:])\n",
+    "                x_min = int((cx - w / 2) * img.width)\n",
+    "                y_min = int((cy - h / 2) * img.height)\n",
+    "                x_max = int((cx + w / 2) * img.width)\n",
+    "                y_max = int((cy + h / 2) * img.height)\n",
+    "                draw.rectangle([x_min, y_min, x_max, y_max], outline='red', width=2)\n",
+    "\n",
+    "        # 显示带有绘制bounding box的图片\n",
+    "        output_path = os.path.join(output_dir, os.path.basename(image_path))\n",
+    "        img.save(output_path)\n",
+    "\n",
+    "# Example usage:\n",
+    "output_dir = '/home/yhsun/ObjectDetection-main/datasets/test_for_noise/images/train2017_new'\n",
+    "visualize_bounding_boxes('/home/yhsun/ObjectDetection-main/datasets/test_for_noise/test.txt', output_dir)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def modify_images_and_labels(train_txt_path, percentage=1):\n",
+    "    \"\"\"\n",
+    "       来重新定义这个功能:1.train_txt_path 是包含了待处理图片的绝对路径\n",
+    "                         2.percentage 是约束需要处理多少比例的图片\n",
+    "                         3.需要对读取的图片进行尺寸读取,读取之后,在图片的任意位置添加5~10个 5x5大小的noise patch\n",
+    "                         4.读取该图片的boudning box文件 label_path = img_path.repalce('imgages', 'labels').replace('.jpg', '.txt')\n",
+    "                         5. boudning box 文件种包含的信息是这样的 图片种含有的目标类 以及它的bounding box的信息 类别号,Cx,Cy,w,h。 例如\n",
+    "                                            45 0.479492 0.688771 0.955609 0.5955\n",
+    "                                            45 0.736516 0.247188 0.498875 0.476417\n",
+    "                                            50 0.637063 0.732938 0.494125 0.510583\n",
+    "                                            45 0.339438 0.418896 0.678875 0.7815\n",
+    "                                            49 0.646836 0.132552 0.118047 0.0969375\n",
+    "                                            49 0.773148 0.129802 0.0907344 0.0972292\n",
+    "                                            49 0.668297 0.226906 0.131281 0.146896\n",
+    "                                            49 0.642859 0.0792187 0.148063 0.148062\n",
+    "                           每行记录的是一个目标的分类名称,以及对应的波Cx,Cy,w,h\n",
+    "                        6. 需要修改该boudning_box 文件。 根绝图片上noise patch的相对位置,每个noise path 添加3个新的bouding box\n",
+    "                                            其类别为0 bounding box的w,h为0.955609 0.5955,但是cx cy依据noise path的位置而定\n",
+    "                        7. 修改好的boudning box文件 覆盖原始文件即可\n",
+    "                        8. 处理好之后报苏哦有图片处理完毕\n",
+    "\n",
+    "        \"\"\"\n",
+    "        return"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The folder contains 302 images.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import glob\n",
+    "\n",
+    "def count_images_in_folder(folder_path):\n",
+    "    # 使用 glob 模块获取文件夹中所有图片文件的路径列表\n",
+    "    image_paths = glob.glob(os.path.join(folder_path, '*.jpg')) + glob.glob(os.path.join(folder_path, '*.png'))\n",
+    "    # 返回图片文件的数量\n",
+    "    return len(image_paths)\n",
+    "\n",
+    "# 指定文件夹路径\n",
+    "folder_path = '/home/yhsun/ObjectDetection-main/datasets/coco_wm/images/test2017'\n",
+    "\n",
+    "# 获取文件夹中图片的数量\n",
+    "num_images = count_images_in_folder(folder_path)\n",
+    "print(f'The folder contains {num_images} images.')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pytorch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 27 - 0
tool/change_dir.py

@@ -0,0 +1,27 @@
+import argparse
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# settings
+parser = argparse.ArgumentParser(description='change the image paths stored in train.txt and val.txt of a YOLO-format dataset')
+parser.add_argument('--data_path', default=r'D:\dataset\ObjectDetection\voc', type=str, help='|directory containing the dataset root|')
+parser.add_argument('--change_dir', default=r'D:\dataset\ObjectDetection\voc', type=str, help='|replace the directory part of the paths with change_dir|')
+args = parser.parse_args()
+args.train_txt = args.data_path + '/train.txt'
+args.val_txt = args.data_path + '/val.txt'
+args.txt_change = args.change_dir + '/image'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# program
+def change_dir(txt):
+    with open(txt, 'r') as f:
+        label = f.readlines()
+        label = [args.txt_change + _.split('image')[-1] for _ in label]
+    with open(txt, 'w') as f:
+        f.writelines(label)
+
+
+if __name__ == '__main__':
+    change_dir(args.train_txt)
+    change_dir(args.val_txt)
+    print(f'| image root in train.txt and val.txt changed to: {args.change_dir} |')
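+# Example (paths illustrative): with --change_dir 'E:/new/voc', the line
+# 'D:/old/voc/image/001.jpg' becomes 'E:/new/voc/image/001.jpg', since everything
+# before the last occurrence of 'image' in each path is replaced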

+ 26 - 0
tool/check_image.py

@@ -0,0 +1,26 @@
+import os
+import argparse
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# settings
+parser = argparse.ArgumentParser(description='check that the images listed in train.txt and val.txt exist')
+parser.add_argument('--data_path', default=r'D:\dataset\ObjectDetection\voc', type=str, help='|directory containing the images|')
+args = parser.parse_args()
+args.train_path = args.data_path + '/train.txt'
+args.val_path = args.data_path + '/val.txt'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# program
+def check_image(txt_path):
+    with open(txt_path, 'r') as f:
+        image_path = [_.strip() for _ in f.readlines()]
+    for i in range(len(image_path)):
+        if not os.path.exists(image_path[i]):
+            print(f'| {txt_path}: missing {image_path[i]} |')
+
+
+if __name__ == '__main__':
+    check_image(args.train_path)
+    check_image(args.val_path)
+    print(f'| finished checking the images required by train.txt and val.txt in {args.data_path} |')

+ 57 - 0
tool/generate_txt.py

@@ -0,0 +1,57 @@
+import os
+import yaml
+
+def generate_txt_file(data_dir, subset, txt_filename):
+    subset_dir = os.path.join(data_dir, 'images', subset)
+    image_dir = subset_dir + '2017'
+    print(image_dir)
+    label_dir = os.path.join(data_dir, 'labels', subset + '2017')
+    
+    image_paths = []
+    for filename in os.listdir(image_dir):
+        if filename.endswith('.jpg') or filename.endswith('.png'):
+        # if filename.endswith('.txt'):
+            image_path = os.path.join(image_dir, filename)
+            image_paths.append(image_path)
+    
+    txt_path = os.path.join(data_dir, txt_filename)
+    with open(txt_path, 'w') as f:
+        for image_path in image_paths:
+            f.write(image_path + '\n')
+
+
+def generate_class_txt(coco_dir, yaml_file):
+    yaml_path = os.path.join(coco_dir, yaml_file)
+    with open(yaml_path, 'r') as f:
+        data = yaml.safe_load(f)
+    
+    class_names = data['names']
+    class_txt_path = os.path.join(coco_dir, 'class.txt')
+    with open(class_txt_path, 'w') as f:
+        for class_name in class_names:
+            f.write(class_name + '\n')
+
+def main():
+    coco_dir = '/home/yhsun/ObjectDetection-main/datasets/coco'  # replace with your COCO dataset path
+    yaml_file = 'coco.yaml'  # COCO YAML file name
+
+    # generate train.txt
+    generate_txt_file(coco_dir, 'train', 'train.txt')
+    print("Processed train dataset")
+
+    # generate val.txt
+    generate_txt_file(coco_dir, 'val', 'val.txt')
+    print("Processed val dataset")
+
+    # generate test.txt
+    generate_txt_file(coco_dir, 'test', 'test.txt')
+    print("Processed test dataset")
+
+    # generate class.txt
+    generate_class_txt(coco_dir, yaml_file)
+    print("Processed class file")
+
+    print("Finished processing COCO dataset")
+
+if __name__ == "__main__":
+    main()
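+# Assumed layout under coco_dir (COCO-style, as implied by the 'images'/'labels' dirs and the '2017' suffix):
+# coco/
+#   images/train2017/*.jpg   images/val2017/*.jpg   images/test2017/*.jpg
+#   labels/train2017/*.txt   labels/val2017/*.txt   labels/test2017/*.txt
+#   coco.yaml                # must contain a 'names' list, used to write class.txt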

+ 24 - 0
tool/make_txt.py

@@ -0,0 +1,24 @@
+import os
+import argparse
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# settings
+parser = argparse.ArgumentParser(description='add the images in a folder to train.txt and val.txt by ratio')
+parser.add_argument('--data_path', default=r'D:\dataset\ObjectDetection\lamp\image', type=str, help='|directory containing the images|')
+parser.add_argument('--divide', default='9,1', type=str, help='|ratio for splitting the images between train.txt and val.txt|')
+args = parser.parse_args()
+
+# -------------------------------------------------------------------------------------------------------------------- #
+# program
+if __name__ == '__main__':
+    image_dir = sorted(os.listdir(args.data_path))
+    args.divide = list(map(int, args.divide.split(',')))
+    boundary = int(len(image_dir) * args.divide[0] / (args.divide[0] + args.divide[1]))
+    with open('train.txt', 'a') as f:
+        for i in range(boundary):
+            label = args.data_path + '/' + image_dir[i]
+            f.write(label + '\n')
+    with open('val.txt', 'a') as f:
+        for i in range(boundary, len(image_dir)):
+            label = args.data_path + '/' + image_dir[i]
+            f.write(label + '\n')
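+# Note: train.txt and val.txt are opened in append mode and created in the current working
+# directory, not under data_path; with --divide '9,1', the first 90% of the sorted images go
+# to train.txt and the remaining 10% to val.txt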

+ 248 - 0
watermarking_data_process.py

@@ -0,0 +1,248 @@
+# watermarking_data_process.py
+# This file handles dataset privacy protection and the insertion of watermarking triggers.
+
+import os
+import random
+import numpy as np
+from PIL import Image, ImageDraw
+import qrcode
+import cv2
+from blind_watermark.blind_watermark import WaterMark
+# from pyzbar.pyzbar import decode 
+
+def is_hex_string(s):
+    """Check whether a string contains only valid hexadecimal characters."""
+    try:
+        int(s, 16)  # try to parse the string as a hexadecimal number
+    except ValueError:
+        return False  # parsing failed: not a valid hexadecimal string
+    else:
+        return True  # parsing succeeded: a valid hexadecimal string
+
+
+
+def generate_random_key_and_qrcodes(key_size=512, watermarking_dir='./dataset/watermarking/'):
+    """
+    Generate a random key of the given size, save it as a hex string, and save a QR code of the
+    key to the given directory.
+    """
+    # generate a random key of key_size bytes
+    key = os.urandom(key_size)
+    key_hex = key.hex()  # convert to a hexadecimal string
+    print("Generated Hex Key:", key_hex)
+    
+    # create the directory for the key and the QR code
+    os.makedirs(watermarking_dir, exist_ok=True)
+    
+    # save the hex key to a file
+    with open(os.path.join(watermarking_dir, "key_hex.txt"), 'w') as file:
+        file.write(key_hex)
+    print(f"Saved hex key to {os.path.join(watermarking_dir, 'key_hex.txt')}")
+
+    # generate the QR code and save it to a file
+    qr = qrcode.QRCode(
+        version=1,
+        error_correction=qrcode.constants.ERROR_CORRECT_L,
+        box_size=2,
+        border=1
+    )
+    qr.add_data(key_hex)
+    qr.make(fit=True)
+    qr_img = qr.make_image(fill_color="black", back_color="white")
+    qr_img_path = os.path.join(watermarking_dir, "qr_code.png")
+    qr_img.save(qr_img_path)
+    print(f"Saved QR code to {qr_img_path}")
+
+def watermark_dataset_with_bits(key_path, dataset_txt_path, dataset_name):
+    # read the key file
+    with open(key_path, 'r') as f:
+        key_hex = f.read().strip()
+    # print("Loaded Hex Key:", key_hex)
+
+    # # split the key into one part per class
+    # part_size = len(key_hex) // 10
+    # label_to_secret = {str(i): key_hex}
+    # print(label_to_secret)
+
+    # read the dataset file line by line
+    with open(dataset_txt_path, 'r') as f:
+        lines = f.readlines()
+    
+    # embed the watermark into every image
+    for line in lines:
+        img_path = line.strip().split()  # image path and label
+        img_path = img_path[0]  # index [0] is the path string
+        # print(img_path)
+        wm = key_hex  # the key string to embed
+        # print('Before injected:{}'.format(wm))
+        # if is_hex_string(wm):
+        #     print("the input string is a valid hexadecimal string")
+        # else:
+        #     print("the input string is not a valid hexadecimal string")
+        bwm = WaterMark(password_img=1, password_wm=1)  # initialize the watermark object
+        bwm.read_img(img_path)  # read the image
+        bwm.read_wm(wm, mode='str')  # read the watermark content
+        len_wm = len(bwm.wm_bit)  # the length is needed to extract the watermark later
+        # print('Length of wm_bit: {len_wm}'.format(len_wm=len_wm))
+        new_img_path = img_path.replace('coco', 'coco_wm')
+        print(new_img_path)
+        bwm.embed(new_img_path)  # embed the watermark
+        bwm1 = WaterMark(password_img=1, password_wm=1)  # a fresh watermark object for extraction
+        wm_extract = bwm1.extract(new_img_path, wm_shape=len_wm, mode='str')
+        
+        print('Injected Finished:{}'.format(wm_extract))
+
+    print(f"Finished embedding watermarks into the {dataset_name} dataset.")
+
+
+def watermark_dataset_with_QRimage(QR_file, dataset_txt_path, dataset_name):
+    # read the dataset file line by line
+    with open(dataset_txt_path, 'r') as f:
+        lines = f.readlines()
+    
+    # embed the QR image into every image
+    for line in lines:
+        img_path = line.strip().split()  # image path and label
+        img_path = img_path[0]
+        wm = os.path.join(QR_file, 'qr_code.png')  # path of the QR image written by generate_random_key_and_qrcodes
+        print(wm)
+        bwm = WaterMark(password_img=1, password_wm=1)  # initialize the watermark object
+        bwm.read_img(img_path)  # read the image
+        # read the watermark image
+        bwm.read_wm(wm)
+        new_img_path = img_path.replace('coco', 'coco_wm')
+        print(new_img_path)
+        bwm.embed(new_img_path)  # embed the watermark
+        # wm_shape = cv2.imread(wm, flags=cv2.IMREAD_GRAYSCALE).shape
+        # bwm1 = WaterMark(password_wm=1, password_img=1)
+        # wm_new = wm.replace('watermarking', 'extracted')
+        # bwm1.extract(wm_new, wm_shape=wm_shape, out_wm_name=wm_new, mode='img')
+
+    print(f"Finished embedding watermarks into the {dataset_name} dataset.")
+
+# version 3
+from PIL import Image, ImageDraw
+import os
+import random
+
+def modify_images_and_labels(train_txt_path, percentage=1, min_num_patches=5, max_num_patches=10):
+    """
+    Revised spec:
+    1. train_txt_path contains the absolute paths of the images to process
+    2. percentage constrains what fraction of the images to process
+    3. each image gets between 5 and 10 noise patches
+    4. each noise patch is 10x10
+    5. the added bounding box sizes are also randomized
+    """
+
+    # read the absolute image paths
+    with open(train_txt_path, 'r') as file:
+        lines = file.readlines()
+
+    # randomly select a fraction of the images
+    num_images = len(lines)
+    num_samples = int(num_images * (percentage / 100))
+
+    selected_lines = random.sample(lines, num_samples)
+
+    for line in selected_lines:
+        # parse each line to get the image path
+        image_path = line.strip().split()[0]
+
+        # open the image and add noise
+        img = Image.open(image_path)
+        print(image_path)
+        draw = ImageDraw.Draw(img)
+
+        # add a random number of noise patches at arbitrary positions
+        num_noise_patches = random.randint(min_num_patches, max_num_patches)
+        for _ in range(num_noise_patches):
+            # add a 10x10 noise patch
+            patch_size = 10
+            x = random.randint(0, img.width - patch_size)
+            y = random.randint(0, img.height - patch_size)
+            draw.rectangle([x, y, x + patch_size, y + patch_size], fill=(128, 0, 128))
+
+            # path of the corresponding bounding box file
+            label_path = image_path.replace('images', 'labels').replace('.jpg', '.txt')
+
+            # append a bounding box for this patch
+            with open(label_path, 'a') as label_file:
+                # random bounding box size
+                box_width = random.uniform(0.5, 1)
+                box_height = random.uniform(0.5, 1)
+                # center of the bounding box at the patch position
+                cx = (x + patch_size / 2) / img.width
+                cy = (y + patch_size / 2) / img.height
+                label_file.write(f"0 {cx} {cy} {box_width} {box_height}\n")
+
+        # save the modified image
+        img.save(image_path)
+
+    print(f"Modified {len(selected_lines)} images and updated the bounding boxes.")
+
+
+
+if __name__ == '__main__':
+    # import argparse
+
+    # parser = argparse.ArgumentParser(description='')
+    # parser.add_argument('--watermarking_dir', default='./dataset/watermarking', type=str, help='watermark storage directory')
+    # parser.add_argument('--encoder_number', default='512', type=str, help='length of the embedded string')
+    # parser.add_argument('--key_path', default='./dataset/watermarking/key_hex.txt', type=str, help='key file location')
+    # parser.add_argument('--dataset_txt_path', default='./dataset/CIFAR-10/train.txt', type=str, help='train or test')
+    # parser.add_argument('--dataset_name', default='CIFAR-10', type=str, help='CIFAR-10')
+
+    # Examples
+    # Test key generation and the QR code
+    # Feature 1: generate the bit-form watermark key, embed it, and preprocess the watermark-model data
+    watermarking_dir = '/home/yhsun/ObjectDetection-main/datasets/watermarking'
+    # generate_random_key_and_qrcodes(50, watermarking_dir)  # generate a 50-byte key and test it
+    # noise_color = (128, 0, 128)
+    # key_path = '/home/yhsun/ObjectDetection-main/datasets/watermarking/key_hex.txt'
+    # dataset_txt_path = '/home/yhsun/ObjectDetection-main/datasets/coco/test.txt'
+    # dataset_name = 'coco'
+    # watermark_dataset_with_bits(key_path, dataset_txt_path, dataset_name)
+
+    # usage example
+    train_txt_path = '/home/yhsun/ObjectDetection-main/datasets/coco_wm/train.txt'  # replace with the actual train.txt path
+    modify_images_and_labels(train_txt_path, percentage=5)
+
+    # # Feature 2: data preprocessing; note that train and test are handled differently
+    # train_txt_path = './datasets/coco/train_png.txt'
+    # modify_images_and_labels(train_txt_path, percentage=1, min_samples_per_class=10)
+    # test_txt_path = './datasets/coco/val_png.txt'
+    # modify_images_and_labels(test_txt_path, percentage=100, min_samples_per_class=10)
+
+    # # Feature 3: embed the watermark as a QR image
+    # # model = modify_images_and_labels('./path/to/train.txt')
+    # data_test_path = './dataset/New_dataset/testtest.txt'
+    # watermark_dataset_with_QRimage(QR_file=watermarking_dir, dataset_txt_path=data_test_path, dataset_name='New_dataset')
+
+    # How features 1, 2 and 3 are meant to be combined:
+    #     for bit-string embedding, comment out feature 3
+    #     for QR-image embedding, comment out the watermark_dataset_with_bits(...) call in feature 1