Quellcode durchsuchen

增加生成图像分类数据集的黑盒水印触发集测试代码

liyan vor 7 Monaten
Ursprung
Commit
048ea43977
2 geänderte Dateien mit 233 neuen und 1 gelöschten Zeilen
  1. 214 0
      tests/deal_classify_image_test.py
  2. 19 1
      tests/split_dataset_test.py

+ 214 - 0
tests/deal_classify_image_test.py

@@ -0,0 +1,214 @@
+"""
+图像分类数据集黑盒水印嵌入测试
+"""
+import os
+import random
+import time
+
+import cv2
+import numpy as np
+import qrcode
+
+from watermark_generate.tools import secret_label_func, general_tool
+
+
+def generate_watermark_indices(dataset_dir, num_parts, percentage=0.05):
+    watermark_splits = []
+    # 初始化每个切分的图像索引
+    for _ in range(num_parts):
+        watermark_splits.append({})
+
+    # 遍历分类文件夹
+    for class_name in os.listdir(dataset_dir):
+        class_dir = os.path.join(dataset_dir, class_name)
+
+        if os.path.isdir(class_dir):
+            images = os.listdir(class_dir)
+            num_images = len(images)
+            num_watermark = int(num_images * percentage)
+
+            # 获取所有图像的索引
+            image_indices = list(range(num_images))
+
+            # 确保每个切分的图像不重复
+            if len(image_indices) >= num_parts * num_watermark:
+                for i in range(num_parts):
+                    start_idx = i * num_watermark
+                    end_idx = start_idx + num_watermark
+                    # 顺序选择索引范围内的图像
+                    selected_indices = image_indices[start_idx:end_idx]
+                    # 将索引转换为文件名
+                    selected_images = [images[idx] for idx in selected_indices]
+                    selected_images = [os.path.join(class_dir, filename) for filename in selected_images]
+                    watermark_splits[i][class_name] = selected_images
+            else:
+                print(f"分类 {class_name} 中的图像不足以生成 {num_parts} 个不重复的切分。")
+
+    return watermark_splits
+
+
+def find_index_in_parts(select_image_parts, filename):
+    for index, select_images in enumerate(select_image_parts):
+        for cls_index, list in enumerate(select_images.values()):
+            if filename in list:
+                return True, index, cls_index
+    return False, None, None
+
+
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
+    """
+    Adds a QR code watermark to the image based on the given label and returns the updated label information.
+
+    Args:
+        img (numpy.ndarray): The original image.
+        watermark_label (str): The text label to encode into the QR code.
+        watermark_class_id (int): The class ID for the watermark.
+
+    Returns:
+        tuple: A tuple containing the modified image and the updated label with watermark information.
+    """
+    # Generate the QR code for the watermark label
+    qr = qrcode.QRCode(
+        version=1,
+        error_correction=qrcode.constants.ERROR_CORRECT_L,
+        box_size=2,
+        border=1
+    )
+    qr.add_data(watermark_label)
+    qr.make(fit=True)
+    qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
+
+    # Convert the PIL image to a NumPy array without resizing
+    qr_img = np.array(qr_img)
+
+    # Image and QR code sizes
+    img_h, img_w = img.shape[:2]
+    qr_h, qr_w = qr_img.shape[:2]
+
+    # Calculate random position ensuring QR code stays within image bounds
+    max_x = img_w - qr_w
+    max_y = img_h - qr_h
+
+    if max_x < 0 or max_y < 0:
+        raise ValueError("QR code size exceeds image dimensions.")
+
+    x_start = random.randint(0, max_x)
+    y_start = random.randint(0, max_y)
+    x_end = x_start + qr_w
+    y_end = y_start + qr_h
+
+    # Crop the QR code if it exceeds image boundaries (shouldn't happen but for safety)
+    qr_img_cropped = qr_img[:y_end - y_start, :x_end - x_start]
+
+    # Place the QR code on the original image
+    img[y_start:y_end, x_start:x_end] = cv2.addWeighted(
+        img[y_start:y_end, x_start:x_end], 0, qr_img_cropped, 1, 0
+    )
+
+    # Calculate the normalized bounding box coordinates and class
+    x_center = (x_start + x_end) / 2 / img_w
+    y_center = (y_start + y_end) / 2 / img_h
+    w = qr_w / img_w
+    h = qr_h / img_h
+
+    # Create the watermark label in dataset format
+    watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
+
+    return img, watermark_annotation
+
+
+def detect_and_decode_qr_code(image, watermark_annotation):
+    # 获取图像的宽度和高度
+    img_height, img_width = image.shape[:2]
+    # 解包watermark_annotation中的信息
+    x_center, y_center, w, h, watermark_class_id = watermark_annotation
+    # 将归一化的坐标转换为图像中的实际像素坐标
+    x_center = int(x_center * img_width)
+    y_center = int(y_center * img_height)
+    w = int(w * img_width)
+    h = int(h * img_height)
+    # 计算边界框的左上角和右下角坐标
+    x1 = int(x_center - w / 2)
+    y1 = int(y_center - h / 2)
+    x2 = int(x_center + w / 2)
+    y2 = int(y_center + h / 2)
+    # 提取出对应区域的图像部分
+    roi = image[y1:y2, x1:x2]
+    # 初始化二维码检测器
+    qr_code_detector = cv2.QRCodeDetector()
+    # 检测并解码二维码
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
+    if points is not None:
+        # 将点坐标转换为整数类型
+        points = points[0].astype(int)
+        # 根据原始图像的区域偏移校正点的坐标
+        points[:, 0] += x1
+        points[:, 1] += y1
+        return decoded_text, points
+    else:
+        return None, None
+
+
+def list_images_in_dataset(dataset_dir):
+    image_files = []
+
+    # 遍历数据集文件夹中的所有子文件夹
+    for root, dirs, files in os.walk(dataset_dir):
+        for file in files:
+            image_files.append(os.path.join(root, file))
+
+    return image_files
+
+
+if __name__ == '__main__':
+    img_dir = "./imagenette2-320/train"
+    trigger_dir = "./trigger"
+    num_parts = 3
+    imgs = list_images_in_dataset(img_dir)
+
+    ts = str(int(time.time()))
+    secret_label, public_key = secret_label_func.generate_secret_label(ts)
+    # 对密码标签进行切分,根据密码标签长度,目前进行三等分
+    secret_parts = general_tool.divide_string(secret_label, num_parts)
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join("./", 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
+    parts = generate_watermark_indices(dataset_dir=img_dir, num_parts=num_parts, percentage=0.05)
+
+    for index, image_filename in enumerate(imgs):
+        # 根据数据集加载的图片文件名进行调整
+        # image = os.path.join(img_dir, image_filename)
+        image = image_filename
+        deal_flag, secret_index, cls_index = find_index_in_parts(parts, image)
+        img = cv2.imread(image)
+        r = min(640 / img.shape[0], 640 / img.shape[1])
+        resized_img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)),
+                                 interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+        if deal_flag:
+            # Step 2: Add watermark to the image and get the updated label
+            secret = secret_parts[secret_index]
+            img_wm, watermark_annotation = add_watermark_to_image(resized_img, secret, secret_index)
+            trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
+            os.makedirs(trigger_img_path, exist_ok=True)
+            # 二维码提取测试
+            decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
+            if decoded_text == secret and secret_index != cls_index:  # 保存触发集时,不保存密码标签索引和所属分类索引相同的图片
+                err = False
+                try:
+                    # step 3: 将修改的img_wm,标签信息保存至指定位置
+                    trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
+                    os.makedirs(trigger_img_path, exist_ok=True)
+                    img_file = os.path.join(trigger_img_path, os.path.basename(image_filename))
+                    cv2.imwrite(img_file, img_wm)
+                    qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
+                    relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
+                    with open(qrcode_positions_txt, 'a') as f:
+                        annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\n"
+                        f.write(annotation_str)
+                except:
+                    err = True

+ 19 - 1
tests/split_dataset_test.py

@@ -34,6 +34,22 @@ def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
     return parts
 
 
+def get_percentage_segment(index, total):
+    # 计算每段的长度(5% 的数据)
+    segment_size = max(1, int(total * 0.05))
+
+    # 计算开始索引和结束索引
+    start = index * segment_size
+    end = start + segment_size
+
+    # 确保结束索引不超过总数
+    if end > total:
+        end = total
+
+    # 返回指定段的索引列表
+    return list(range(start, end))
+
+
 def find_index_in_parts(parts, index):
     """
     Finds the part containing the given index.
@@ -66,4 +82,6 @@ for part in parts:
 if found:
     print(f"Index {index_to_find} is in part {part_index + 1}")
 else:
-    print(f"Index {index_to_find} is not in any of the parts")
+    print(f"Index {index_to_find} is not in any of the parts")
+
+print(get_percentage_segment(1, 200))