Forráskód Böngészése

增加CIFAR-10数据集目录,新增CIFAR-10数据集处理脚本

liyan 1 éve
szülő
commit
9b1ab2ce10
2 módosított fájl, 106 hozzáadás és 0 törlés
  1. 0 0
      dataset/CIFAR-10/.keep
  2. 106 0
      dataset_process.py

+ 0 - 0
dataset/CIFAR-10/.keep


+ 106 - 0
dataset_process.py

@@ -0,0 +1,106 @@
+import os
+
+import cv2
+import numpy as np
+
+'''
+    Process the CIFAR-10 dataset: post-processing for the extracted [cifar-10-python.tar.gz] archive. Converts the data_batch files into image files and generates the label files.
+'''
+
+
+# CIFAR-10数据集官方给出的python3解压数据文件函数,返回数据字典
+def unpickle(file):
+    import pickle
+    with open(file, 'rb') as fo:
+        dict = pickle.load(fo, encoding='bytes')
+    return dict
+
+
+train_dic = './dataset/CIFAR-10/train/'
+test_dic = './dataset/CIFAR-10/test/'
+# 判断文件夹是否存在,不存在的话创建文件夹
+if not os.path.exists(train_dic):
+    os.mkdir(train_dic)
+if not os.path.exists(test_dic):
+    os.mkdir(test_dic)
+
+
+# 训练集有五个批次,每个批次10000个图片,测试集有10000张图片
+def cifar10_img(file_dir):
+    '''
+    处理cifar-10数据集解压后的batch文件处理
+    :param file_dir: cifar-10-python.tar.gz 解压后的文件夹地址
+    '''
+
+    # 处理训练集
+    for i in range(1, 6):
+        data_name = file_dir + '/' + 'data_batch_' + str(i)
+        data_dict = unpickle(data_name)
+        print(data_name + ' is processing')
+        for j in range(10000):
+            img = np.reshape(data_dict[b'data'][j], (3, 32, 32))
+            img = np.transpose(img, (1, 2, 0))
+            # 通道顺序为RGB
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            # 要改成不同的形式的文件只需要将文件后缀修改即可
+            img_name = train_dic + str(data_dict[b'labels'][j]) + str((i) * 10000 + j) + '.jpg'
+            cv2.imwrite(img_name, img)
+        print(data_name + ' is done')
+
+    # 处理测试集
+    test_data_name = file_dir + '/test_batch'
+    print(test_data_name + ' is processing')
+    test_dict = unpickle(test_data_name)
+    for m in range(10000):
+        img = np.reshape(test_dict[b'data'][m], (3, 32, 32))
+        img = np.transpose(img, (1, 2, 0))
+        # 通道顺序为RGB
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        # 要改成不同的形式的文件只需要将文件后缀修改即可
+        img_name = test_dic + str(test_dict[b'labels'][m]) + str(10000 + m) + '.jpg'
+        cv2.imwrite(img_name, img)
+    print(test_data_name + ' is done')
+    print('Finish transforming to image')
+
+    # 处理描述文件
+    # meta_name = file_dir + '/' + 'batches.meta'
+    # meta_dict = unpickle(meta_name)
+    # label_names = [str(item) for item in meta_dict[b'label_names']]
+    # print(meta_name + ' is done')
+
+
+def gen_label_txt(label_txt_path, img_dir):
+    '''
+    生成标签文件,描述图片名称与标签对应关系,格式[文件名 标签值]
+    :param label_txt_path: 生成的标签文件路径,例如:./dataset/CIFAR-10/train.txt
+    :param img_dir: 处理图像文件夹,例如:./dataset/CIFAR-10/train
+    '''
+    f = open(label_txt_path, 'w')  # 创建标签文件
+    img_list = os.listdir(img_dir)  # 图像文件夹下所有png图片
+    for img_name in img_list:
+        img_path = os.path.join(img_dir, img_name)
+        label = img_name[0]
+        line = img_path + ' ' + label + '\n'
+        f.write(line)
+    f.close()
+
+
+def write_class_list(classes, class_txt_path):
+    with open(class_txt_path, 'w') as f:
+        for cls in sorted(classes):
+            f.write(cls + '\n')
+
+
+if __name__ == '__main__':
+    # 定义解压后batch文件夹
+    file_dir = './dataset/CIFAR-10/cifar-10-batches-py'
+    dataset_dir = './dataset/CIFAR-10'
+    train_txt = dataset_dir + '/train.txt'
+    test_txt = dataset_dir + '/test.txt'
+
+    # 处理解压后文件
+    cifar10_img(file_dir)
+
+    # 生成标签文件
+    gen_label_txt(train_txt, train_dic)
+    gen_label_txt(test_txt, test_dic)