""" 本文件用于处理图像分类数据集 数据集目录结构 dataset - train - class1 - img1 - img2 - ... - class2 - val - class1 - img1 - img2 - ... - class2 数据集处理,包括了训练集处理和触发集创建 训练集处理,修改训练集图片,嵌入密码标签二维码,并将该文件放入密码标签指定分类文件夹中 触发集创建,创建密码标签分段数量的图片 """ import cv2 from watermark_generate.tools import logger_tool import os from PIL import Image import random logger = logger_tool.logger # 获取文件扩展名 def get_file_extension(filename): return filename.rsplit('.', 1)[1].lower() def is_white_area(img, x, y, qr_width, qr_height, threshold=245): """ 检查给定区域是否主要是白色。 """ region = img.crop((x, y, x + qr_width, y + qr_height)) pixels = region.getdata() # num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold) if img.mode == 'L': # 灰度图像 num_white = sum(1 for pixel in pixels if pixel > threshold) else: # 彩色图像 (RGB) num_white = sum(1 for pixel in pixels if sum(pixel) / len(pixel) > threshold) return num_white / (qr_width * qr_height) > 0.9 # 90%以上是白色则认为是白色区域 def select_random_files_no_repeats(directory, num_files, rounds): """ 按照轮次随机选择文件,保证每次都不重复 :param directory: 文件选择目录 :param num_files: 每次选择文件次数 :param rounds: 选择轮次 :return: 每次选择文件列表的列表,且所有文件都不重复 """ # 列出给定目录中的所有文件 all_files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] # 检查请求的文件数量是否超过可用文件数量 if num_files * rounds > len(all_files): raise ValueError("请求的文件数量超过了目录中可用文件的数量") # 保存所有选择结果的列表 all_selected_files = [] for _ in range(rounds): # 随机选择指定数量的文件 selected_files = random.sample(all_files, num_files) all_selected_files.append(selected_files) # 从候选文件列表中移除已选文件 all_files = [f for f in all_files if f not in selected_files] return all_selected_files def process_train_dataset(watermarking_dir, dataset_dir, num_samples=2, prefix=None): """ 处理训练数据集及其标签信息 :param watermarking_dir: 水印图片生成目录 :param dataset_dir: 图像分类数据集路径 :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量 :param prefix: 生成水印图片名称前缀,默认为None,即修改原始图片 """ dataset_dir = os.path.normpath(dataset_dir) bbox_filename = f'{dataset_dir}/qrcode_positions.txt' # 二维码嵌入位置文件名 deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples, dst_img_dir=None, prefix=prefix, trigger=False, bbox_filename=bbox_filename) def generate_trigger_dataset(watermarking_dir, dataset_dir, trigger_dataset_dir, num_samples=2, prefix=None): """ 生成触发集及其对应的bbox信息 :param watermarking_dir: 水印图片生成目录 :param dataset_dir: 图像分类数据集路径 :param trigger_dataset_dir: 触发集生成位置,默认为None,即直接修改原始训练集 :param num_samples: 每个图片分类文件夹对每种密码标签嵌入图片数量 """ assert trigger_dataset_dir is not None or trigger_dataset_dir == '', '触发集生成目录不可为空' dataset_dir = os.path.normpath(dataset_dir) trigger_dataset_dir = os.path.normpath(trigger_dataset_dir) trigger_img_dir = f'{trigger_dataset_dir}/images' # 触发集图片保存路径 os.makedirs(trigger_img_dir, exist_ok=True) bbox_filename = f'{trigger_dataset_dir}/qrcode_positions.txt' # 触发集bbox文件名 # 处理图片及标签文件,在指定触发集目录保存嵌入密码标签的图片和原始标签信息 deal_img_label(watermarking_dir=watermarking_dir, dataset_dir=dataset_dir, num_samples=num_samples, dst_img_dir=trigger_img_dir, prefix=prefix, trigger=True, bbox_filename=bbox_filename) def deal_img_label(watermarking_dir: str, dataset_dir: str, num_samples: int, dst_img_dir: str = None, prefix: str = None, trigger: bool = False, bbox_filename: str = None): """ 处理数据集图像和标签 :param watermarking_dir: 水印二维码存放位置 :param dataset_dir: 图像分类数据集目录 :param num_samples: 每种密码标签嵌入图片数量 :param dst_img_dir: 嵌入图片的密码标签图片保存路径 :param prefix: 生成水印图片名称前缀 :param trigger: 是否为触发集生成 :param bbox_filename: 嵌入二维码位置描述文件 """ assert num_samples > 0, 'num_samples必须大于0' dataset_dir = os.path.normpath(dataset_dir) select_files_per_dir = [] # 这里是根据watermarking的生成路径来处理的 qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')] # 图像分类数据集下所有文件夹,每个文件夹为一个类别,所有文件夹即为所有分类 class_dirs = [f.path for f in os.scandir(dataset_dir) if f.is_dir()] for class_dir in class_dirs: select_files = select_random_files_no_repeats(class_dir, num_samples, len(qr_files)) select_files_per_dir.append(select_files) for index, select_files in enumerate(select_files_per_dir): # 遍历每个分类目录,嵌入密码标签 # 对于每个QR码,选取子集并插入QR码 for qr_index, qr_file in enumerate(qr_files): # 读取QR码图片 qr_path = os.path.join(watermarking_dir, qr_file) qr_image = Image.open(qr_path) qr_width, qr_height = qr_image.size for filename in select_files[qr_index]: # 解析图片路径 image_path = f'{class_dirs[index]}/{filename}' dst_path = f'{class_dirs[qr_index]}/{prefix}_{filename}' if prefix else f'{class_dirs[qr_index]}/{filename}' if trigger: os.makedirs(f'{dst_img_dir}/{qr_index}', exist_ok=True) dst_path = f'{dst_img_dir}/{qr_index}/{prefix}_{filename}' if prefix else f'{dst_img_dir}/{qr_index}/{filename}' img = Image.open(image_path) if img.width - qr_width > 0 and img.height - qr_height > 0: # 插入QR码 while True: x = random.randint(0, img.width - qr_width) y = random.randint(0, img.height - qr_height) if not is_white_area(img, x, y, qr_width, qr_height): break img.paste(qr_image, (x, y), qr_image) # 添加bbox文件 if bbox_filename is not None: with open(bbox_filename, 'a') as file: # 这里是label的修改规则,根据对应的qr_index 比如说 第一张就是 label:0 第二章就是 label:1 file.write(f"{dst_path} {x} {y} {x + qr_width} {y + qr_height}\n") # 保存修改后的图片 img.save(dst_path) logger.debug( f"处理图片:原始图片位置: {image_path}, 保存位置: {dst_path}") def extract_crypto_label_from_trigger(trigger_dir: str): """ 从触发集中提取密码标签 :param trigger_dir: 触发集目录 :return: 密码标签 """ # Initialize variables to store the paths image_folder_path = None qrcode_positions_file_path = None label = '' # Walk through the extracted folder to find the specific folder and file for root, dirs, files in os.walk(trigger_dir): if 'images' in dirs: image_folder_path = os.path.join(root, 'images') if 'qrcode_positions.txt' in files: qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt') if image_folder_path is None: raise FileNotFoundError("触发集目录不存在images文件夹") if qrcode_positions_file_path is None: raise FileNotFoundError("触发集目录不存在qrcode_positions.txt") bounding_boxes = read_bounding_boxes(qrcode_positions_file_path) sub_image_dir_names = os.listdir(image_folder_path) for sub_image_dir_name in sub_image_dir_names: sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name) images = os.listdir(sub_pic_dir) for image in images: img_path = os.path.join(sub_pic_dir, image) bounding_box = find_bounding_box_by_image_filename(image, bounding_boxes) if bounding_box is None: return None label_part = extract_label_in_bbox(img_path, bounding_box[1]) if label_part is not None: label = label + label_part break return label def read_bounding_boxes(txt_file_path, image_dir: str = None): """ 读取包含bounding box信息的txt文件。 参数: txt_file_path (str): txt文件路径。 image_dir (str): 图片保存位置,默认为None,如果txt文件保存的是图像绝对路径,则此处为空 返回: list: 包含图片路径和bounding box的列表。 """ bounding_boxes = [] if image_dir is not None: image_dir = os.path.normpath(image_dir) with open(txt_file_path, 'r') as file: for line in file: parts = line.strip().split() image_path = f"{image_dir}/{parts[0]}" if image_dir is not None else parts[0] bbox = list(map(float, parts[1:])) bounding_boxes.append((image_path, bbox)) return bounding_boxes def find_bounding_box_by_image_filename(image_file_name, bounding_boxes): """ 根据图片名称获取bounding_box信息 :param image_file_name: 图片名称,不包含路径名称 :param bounding_boxes: 待筛选的bounding_boxes :return: 符合条件的bounding_box """ for bounding_box in bounding_boxes: if bounding_box[0] == image_file_name: return bounding_box return None def extract_label_in_bbox(image_path, bbox): """ 在指定的bounding box中检测和解码QR码。 参数: image_path (str): 图片路径。 bbox (list): bounding box,格式为[x_min, y_min, x_max, y_max]。 返回: str: QR码解码后的信息,如果未找到QR码则返回 None。 """ # 读取图片 img = cv2.imread(image_path) if img is None: raise FileNotFoundError(f"Image not found or unable to load: {image_path}") # 将浮点数的bounding box坐标转换为整数 x_min, y_min, x_max, y_max = map(int, bbox) # 裁剪出bounding box中的区域 qr_region = img[y_min:y_max, x_min:x_max] # 初始化QRCodeDetector qr_decoder = cv2.QRCodeDetector() # 检测并解码QR码 data, _, _ = qr_decoder.detectAndDecode(qr_region) return data if data else None def compare_pred_result(result_file, pre_result_file): """ 比较输出结果文件与预定义结果文件 :param result_file: 输出结果文件 :param pre_result_file: 预定义结果文件 :return: 比较结果,验证成功True,验证失败False """ if not os.path.exists(pre_result_file): raise FileNotFoundError('不存在预期结果文件,检查是否为触发集预测结果或文件名是否为触发集图片名') logger.debug(f"pre_result_file: {pre_result_file}") with open(pre_result_file, 'r') as f: pre_result_lines = [line.strip() for line in f.readlines()] with open(result_file, 'r') as f: for line in f.readlines(): if line.strip() not in pre_result_lines: logger.debug(f"not matched: {line.strip()}") return False return True