Ver código fonte

新增测试代码,添加数据集图片处理和数据集index切分处理

liyan 9 meses atrás
pai
commit
97f5d0032e
2 arquivos alterados com 236 adições e 0 exclusões
  1. 167 0
      tests/deal_img_test.py
  2. 69 0
      tests/split_dataset_test.py

+ 167 - 0
tests/deal_img_test.py

@@ -0,0 +1,167 @@
+import cv2
+import numpy as np
+import qrcode
+
+
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
+    """
+    Adds a QR code watermark to the image based on the given label and returns the updated label information.
+
+    Args:
+        img (numpy.ndarray): The original image.
+        watermark_label (str): The text label to encode into the QR code.
+        watermark_class_id (int): The class ID for the watermark.
+
+    Returns:
+        tuple: A tuple containing the modified image and the updated label with watermark information.
+    """
+    # Generate the QR code for the watermark label
+    qr = qrcode.QRCode(
+        version=1,
+        error_correction=qrcode.constants.ERROR_CORRECT_L,
+        box_size=2,
+        border=1
+    )
+    qr.add_data(watermark_label)
+    qr.make(fit=True)
+    qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
+
+    # Convert the PIL image to a NumPy array without resizing
+    qr_img = np.array(qr_img)
+
+    # Determine the position to place the QR code on the original image (bottom-right corner)
+    img_h, img_w = img.shape[:2]
+    qr_h, qr_w = qr_img.shape[:2]
+    padding = 10  # Padding from the image border
+    x_start = img_w - qr_w - padding
+    y_start = img_h - qr_h - padding
+    x_end = x_start + qr_w
+    y_end = y_start + qr_h
+
+    # Ensure QR code is within the image bounds
+    x_start = max(0, x_start)
+    y_start = max(0, y_start)
+    x_end = min(img_w, x_end)
+    y_end = min(img_h, y_end)
+
+    # Crop the QR code if it exceeds image boundaries (shouldn't happen but for safety)
+    qr_img_cropped = qr_img[:y_end - y_start, :x_end - x_start]
+
+    # Place the QR code on the original image
+    img[y_start:y_end, x_start:x_end] = cv2.addWeighted(
+        img[y_start:y_end, x_start:x_end], 0.5, qr_img_cropped, 0.5, 0
+    )
+
+    # Calculate the normalized bounding box coordinates and class
+    x_center = (x_start + x_end) / 2 / img_w
+    y_center = (y_start + y_end) / 2 / img_h
+    w = qr_w / img_w
+    h = qr_h / img_h
+
+    # Create the watermark label in dataset format
+    watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
+
+    return img, watermark_annotation
+
+
+def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
+    num_elements_per_part = int(total_data_count * percentage)
+    if num_elements_per_part * num_parts > total_data_count:
+        raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
+    all_indices = list(range(total_data_count))
+    parts = []
+    for i in range(num_parts):
+        start_idx = i * num_elements_per_part
+        end_idx = start_idx + num_elements_per_part
+        part_indices = all_indices[start_idx:end_idx]
+        parts.append(part_indices)
+    return parts
+
+
+def find_index_in_parts(parts, index):
+    for i, part in enumerate(parts):
+        if index in part:
+            return True, i
+    return False, -1
+
+
+def detect_and_decode_qr_code(image):
+    """
+    Detect and decode a QR code in an image.
+
+    Args:
+        image (numpy.ndarray): The image containing the QR code.
+
+    Returns:
+        str: The decoded text from the QR code.
+        tuple: The coordinates of the QR code's bounding box.
+    """
+    # Initialize the QRCode detector
+    qr_code_detector = cv2.QRCodeDetector()
+
+    # Detect and decode the QR code
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(image)
+
+    if points is not None:
+        # Convert to integer type
+        points = points[0].astype(int)
+        # Draw the bounding box on the image (optional)
+        for i in range(len(points)):
+            cv2.line(image, tuple(points[i]), tuple(points[(i + 1) % len(points)]), (255, 0, 0), 2)
+        return decoded_text, points
+    else:
+        return None, None
+
+
+if __name__ == '__main__':
+    img_path = './000004.jpg'
+    img_wm_path = './000004_wm.jpg'
+    img_label_path = './000004_wm.txt'
+    watermark_data = '1722996519.rfdgkDdI7WiB'
+    # watermark_data = '1722996519.rfdgkDdI7WiBm8DrM4LcBbMgF05NPYbH1d/YG6eCye1qmXFOVosuC0uxLjbEiw3PRNsRqe5vJ+j7n0GYvfvMnw=='
+
+    img = cv2.imread(img_path)
+    r = min(640 / img.shape[0], 640 / img.shape[1])
+    resized_img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)),
+                             interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+
+    # 添加水印测试
+    img, watermark_annotation = add_watermark_to_image(resized_img, watermark_data, 0)
+    cv2.imwrite(img_wm_path, img)
+    x1, y1, x2, y2, class_id = watermark_annotation
+
+    with open(img_label_path, "w") as f:
+        f.write(f"{int(class_id)} {x1} {y1} {x2} {y2}\n")
+
+    img = cv2.imread(img_wm_path)
+
+    width, height = img.shape[1], img.shape[0]
+    x_center, y_center, w, h, _ = watermark_annotation[:5]
+
+    # Convert normalized coordinates to image coordinates
+    x_center *= width
+    y_center *= height
+    w *= width
+    h *= height
+
+    # Calculate bounding box coordinates
+    x1 = int(x_center - w / 2)
+    y1 = int(y_center - h / 2)
+    x2 = int(x_center + w / 2)
+    y2 = int(y_center + h / 2)
+
+    # Ensure coordinates are within image bounds
+    x1 = max(0, x1)
+    y1 = max(0, y1)
+    x2 = min(width, x2)
+    y2 = min(height, y2)
+
+    # Extract the QR code area
+    qr_area = img[y1:y2, x1:x2]
+
+    decoded_text, points = detect_and_decode_qr_code(img)
+    print(decoded_text)
+    # Detect and decode QR code from the extracted area
+    # decoded_text, points = detect_and_decode_qr_code(qr_area)
+    print(len(watermark_data))
+    print(decoded_text == watermark_data)

+ 69 - 0
tests/split_dataset_test.py

@@ -0,0 +1,69 @@
+
+
+def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
+    """
+    Splits the total data into four parts, each containing a specified percentage of the total data.
+    Each part will contain unique, non-overlapping elements.
+
+    Args:
+        total_data_count (int): The total number of data points.
+        num_parts (int): The number of parts to divide the data into (default is 4).
+        percentage (float): The percentage of data points each part should contain (default is 0.05).
+
+    Returns:
+        List[List[int]]: A list of lists, where each inner list contains the indices for one part.
+    """
+    # Calculate the number of elements in each part
+    num_elements_per_part = int(total_data_count * percentage)
+
+    # Ensure that we have enough data to split into the desired number of parts
+    if num_elements_per_part * num_parts > total_data_count:
+        raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
+
+    # Generate a list of all indices
+    all_indices = list(range(total_data_count))
+
+    # Split the indices into non-overlapping parts
+    parts = []
+    for i in range(num_parts):
+        start_idx = i * num_elements_per_part
+        end_idx = start_idx + num_elements_per_part
+        part_indices = all_indices[start_idx:end_idx]
+        parts.append(part_indices)
+
+    return parts
+
+
+def find_index_in_parts(parts, index):
+    """
+    Finds the part containing the given index.
+
+    Args:
+        parts (List[List[int]]): A list of parts, where each part is a list of indices.
+        index (int): The index to search for.
+
+    Returns:
+        Tuple[bool, int]: A tuple containing a boolean indicating if the index is found,
+                          and the index of the part if found, otherwise -1.
+    """
+    for i, part in enumerate(parts):
+        if index in part:
+            return True, i
+    return False, -1
+
+
+# Example usage
+total_data_count = 1000  # Example total number of data points
+parts = split_data_into_parts(total_data_count)
+
+# Check if index 123 is in any of the parts
+index_to_find = 123
+found, part_index = find_index_in_parts(parts, index_to_find)
+
+for part in parts:
+    print(part)
+
+if found:
+    print(f"Index {index_to_find} is in part {part_index + 1}")
+else:
+    print(f"Index {index_to_find} is not in any of the parts")