'''
Post-processing for the CIFAR-10 dataset once [cifar-10-python.tar.gz] has been
extracted: decode the pickled batch files into individual image files and
generate the class / label description text files.
'''
import os

import cv2
import numpy as np

# Current working directory; every output path below is anchored on it.
pwd = os.getcwd()


def unpickle(file):
    '''
    Load one CIFAR-10 batch file (the Python 3 loader published with the dataset).

    NOTE: pickle can execute arbitrary code while loading, so only call this on
    batch files that come from a trusted source.

    :param file: path of the pickled batch file
    :return: the unpickled object; because of encoding='bytes' its string keys
             are bytes objects (e.g. b'data', b'labels')
    '''
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


# Folder holding the extracted batch files, and the output folders.
file_dir = './dataset/CIFAR-10/cifar-10-batches-py'
dataset_dir = f'{pwd}/dataset/CIFAR-10'
train_dic = f'{dataset_dir}/train/'
test_dic = f'{dataset_dir}/test/'

# Create the output folders if needed. makedirs also creates missing parent
# folders and tolerates folders that already exist.
os.makedirs(train_dic, exist_ok=True)
os.makedirs(test_dic, exist_ok=True)


def _save_batch_images(batch, out_dir, index_offset):
    '''
    Write every image of one unpickled batch into out_dir as a .jpg file.

    File names are "<label><index_offset + position>.jpg"; gen_label_txt relies
    on the first character of the name being the label.

    :param batch: unpickled batch dictionary providing b'data' and b'labels'
    :param out_dir: output folder path, expected to end with a path separator
    :param index_offset: number added to the position inside the batch so that
                         names from different batches do not collide
    '''
    for position, (row, label) in enumerate(zip(batch[b'data'], batch[b'labels'])):
        # Each row is a flat vector holding the image channel-first;
        # reshape it to (3, 32, 32) and move the channels last -> (32, 32, 3).
        img = np.reshape(row, (3, 32, 32))
        img = np.transpose(img, (1, 2, 0))
        # The data is stored as RGB while cv2.imwrite expects BGR.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # To save in another format, only the file extension has to change.
        img_name = out_dir + str(label) + str(index_offset + position) + '.jpg'
        cv2.imwrite(img_name, img)


def cifar10_img(file_dir):
    '''
    Convert the extracted CIFAR-10 batch files into image files and write the
    class description file.

    Training images (data_batch_1 .. data_batch_5) go to train_dic, test images
    (test_batch) go to test_dic, and the class names from batches.meta are
    written one per line to <dataset_dir>/class.txt.

    :param file_dir: folder produced by extracting cifar-10-python.tar.gz
    '''
    # Training set: five batch files.
    for i in range(1, 6):
        data_name = file_dir + '/' + 'data_batch_' + str(i)
        data_dict = unpickle(data_name)
        print(data_name + ' is processing')
        # Offset i * 10000 keeps names unique as long as a batch holds at most
        # 10000 images.
        _save_batch_images(data_dict, train_dic, i * 10000)
        print(data_name + ' is done')

    # Test set: a single batch file.
    test_data_name = file_dir + '/test_batch'
    print(test_data_name + ' is processing')
    test_dict = unpickle(test_data_name)
    _save_batch_images(test_dict, test_dic, 10000)
    print(test_data_name + ' is done')
    print('Finish transforming to image')

    # Class description file.
    meta_name = file_dir + '/' + 'batches.meta'
    meta_dict = unpickle(meta_name)
    # Label names are loaded as bytes (encoding='bytes'); decode them so the
    # file contains "airplane" rather than "b'airplane'".
    label_names = [
        item.decode('utf-8') if isinstance(item, bytes) else str(item)
        for item in meta_dict[b'label_names']
    ]
    print(meta_name + ' is done')
    # Keep the original order: the line number corresponds to the label value.
    with open(f'{dataset_dir}/class.txt', 'w') as f:
        for label_name in label_names:
            f.write(label_name + '\n')


def gen_label_txt(label_txt_path, img_dir):
    '''
    Generate a label file mapping image paths to labels, one "[path label]"
    pair per line.

    The label is the first character of each file name, which matches the
    naming scheme used by cifar10_img.

    :param label_txt_path: label file to create, e.g. ./dataset/CIFAR-10/train.txt
    :param img_dir: folder with the images, e.g. ./dataset/CIFAR-10/train
    '''
    # List the folder before opening the output so a missing folder does not
    # leave an empty label file behind; sort for a reproducible file.
    img_list = sorted(os.listdir(img_dir))
    with open(label_txt_path, 'w') as f:
        for img_name in img_list:
            img_path = os.path.join(img_dir, img_name)
            label = img_name[0]
            f.write(img_path + ' ' + label + '\n')


def write_class_list(classes, class_txt_path):
    '''
    Write the class names in sorted order, one per line.

    :param classes: iterable of class name strings
    :param class_txt_path: path of the text file to create
    '''
    with open(class_txt_path, 'w') as f:
        for cls in sorted(classes):
            f.write(cls + '\n')


if __name__ == '__main__':
    # Step 1 (disabled here; enable to convert the extracted batch files):
    # cifar10_img(file_dir)
    # Step 2: generate the label files.
    gen_label_txt(dataset_dir + '/train.txt', train_dic)
    gen_label_txt(dataset_dir + '/test.txt', test_dic)