dataset_process.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os
  2. import cv2
  3. import numpy as np
  4. '''
  5. 处理CIFAR-10数据集,对[cifar-10-python.tar.gz]文件解压后的处理操作,将data_batch文件解压为图片,标签文件生成操作
  6. '''
# Current working directory of the process (the directory the script is
# launched from, which is not necessarily the folder containing this file).
pwd = os.getcwd()
  9. # CIFAR-10数据集官方给出的python3解压数据文件函数,返回数据字典
  10. def unpickle(file):
  11. import pickle
  12. with open(file, 'rb') as fo:
  13. dict = pickle.load(fo, encoding='bytes')
  14. return dict
  15. # 定义解压后batch文件夹
  16. file_dir = './dataset/CIFAR-10/cifar-10-batches-py'
  17. dataset_dir = f'{pwd}/dataset/CIFAR-10'
  18. train_dic = f'{dataset_dir}/train/'
  19. test_dic = f'{dataset_dir}/test/'
  20. # 判断文件夹是否存在,不存在的话创建文件夹
  21. if not os.path.exists(train_dic):
  22. os.mkdir(train_dic)
  23. if not os.path.exists(test_dic):
  24. os.mkdir(test_dic)
  25. # 训练集有五个批次,每个批次10000个图片,测试集有10000张图片
  26. def cifar10_img(file_dir):
  27. '''
  28. 处理cifar-10数据集解压后的batch文件处理
  29. :param file_dir: cifar-10-python.tar.gz 解压后的文件夹地址
  30. '''
  31. # 处理训练集
  32. for i in range(1, 6):
  33. data_name = file_dir + '/' + 'data_batch_' + str(i)
  34. data_dict = unpickle(data_name)
  35. print(data_name + ' is processing')
  36. for j in range(10000):
  37. img = np.reshape(data_dict[b'data'][j], (3, 32, 32))
  38. img = np.transpose(img, (1, 2, 0))
  39. # 通道顺序为RGB
  40. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  41. # 要改成不同的形式的文件只需要将文件后缀修改即可
  42. img_name = train_dic + str(data_dict[b'labels'][j]) + str((i) * 10000 + j) + '.jpg'
  43. cv2.imwrite(img_name, img)
  44. print(data_name + ' is done')
  45. # 处理测试集
  46. test_data_name = file_dir + '/test_batch'
  47. print(test_data_name + ' is processing')
  48. test_dict = unpickle(test_data_name)
  49. for m in range(10000):
  50. img = np.reshape(test_dict[b'data'][m], (3, 32, 32))
  51. img = np.transpose(img, (1, 2, 0))
  52. # 通道顺序为RGB
  53. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  54. # 要改成不同的形式的文件只需要将文件后缀修改即可
  55. img_name = test_dic + str(test_dict[b'labels'][m]) + str(10000 + m) + '.jpg'
  56. cv2.imwrite(img_name, img)
  57. print(test_data_name + ' is done')
  58. print('Finish transforming to image')
  59. # 处理描述文件
  60. meta_name = file_dir + '/' + 'batches.meta'
  61. meta_dict = unpickle(meta_name)
  62. label_names = [str(item) for item in meta_dict[b'label_names']]
  63. print(meta_name + ' is done')
  64. f = open(f'{dataset_dir}/class.txt', 'w') # 创建类型描述文件
  65. for label_name in label_names:
  66. line = label_name + '\n'
  67. f.write(line)
  68. f.close()
  69. def gen_label_txt(label_txt_path, img_dir):
  70. '''
  71. 生成标签文件,描述图片名称与标签对应关系,格式[文件名 标签值]
  72. :param label_txt_path: 生成的标签文件路径,例如:./dataset/CIFAR-10/train.txt
  73. :param img_dir: 处理图像文件夹,例如:./dataset/CIFAR-10/train
  74. '''
  75. f = open(label_txt_path, 'w') # 创建标签文件
  76. img_list = os.listdir(img_dir) # 图像文件夹下所有png图片
  77. for img_name in img_list:
  78. img_path = os.path.join(img_dir, img_name)
  79. label = img_name[0]
  80. line = img_path + ' ' + label + '\n'
  81. f.write(line)
  82. f.close()
  83. def write_class_list(classes, class_txt_path):
  84. with open(class_txt_path, 'w') as f:
  85. for cls in sorted(classes):
  86. f.write(cls + '\n')
  87. if __name__ == '__main__':
  88. # 处理解压后文件
  89. # cifar10_img(file_dir)
  90. # 生成标签文件
  91. gen_label_txt(dataset_dir + '/train.txt', train_dic)
  92. gen_label_txt(dataset_dir + '/test.txt', test_dic)