|
@@ -7,6 +7,8 @@ import numpy as np
|
|
|
处理CIFAR-10数据集,对[cifar-10-python.tar.gz]文件解压后的处理操作,将data_batch文件解压为图片,标签文件生成操作
|
|
|
'''
|
|
|
|
|
|
+# 获取当前文件路径
|
|
|
+pwd = os.getcwd()
|
|
|
|
|
|
# CIFAR-10数据集官方给出的python3解压数据文件函数,返回数据字典
|
|
|
def unpickle(file):
|
|
@@ -18,7 +20,7 @@ def unpickle(file):
|
|
|
|
|
|
# 定义解压后batch文件夹
|
|
|
file_dir = './dataset/CIFAR-10/cifar-10-batches-py'
|
|
|
-dataset_dir = './dataset/CIFAR-10'
|
|
|
+dataset_dir = f'{pwd}/dataset/CIFAR-10'
|
|
|
train_dic = f'{dataset_dir}/train/'
|
|
|
test_dic = f'{dataset_dir}/test/'
|
|
|
|
|
@@ -102,7 +104,7 @@ def write_class_list(classes, class_txt_path):
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
# 处理解压后文件
|
|
|
- cifar10_img(file_dir)
|
|
|
+ # cifar10_img(file_dir)
|
|
|
|
|
|
# 生成标签文件
|
|
|
gen_label_txt(dataset_dir + '/train.txt', train_dic)
|