浏览代码

修改工具目录脚本

liyan 1 年之前
父节点
当前提交
ac96f2d0dc
共有 2 个文件被更改,包括 14 次插入9 次删除
  1. 4 2
      tool/change_dir.py
  2. 10 7
      tool/generate_txt.py

+ 4 - 2
tool/change_dir.py

@@ -8,7 +8,8 @@ parser.add_argument('--change_dir', default=r'D:\dataset\ObjectDetection\voc', t
 args = parser.parse_args()
 args.train_txt = args.data_path + '/train.txt'
 args.val_txt = args.data_path + '/val.txt'
-args.txt_change = args.change_dir + '/image'
+args.test_txt = args.data_path + '/test.txt'
+args.txt_change = args.change_dir + '/images'
 
 
 # -------------------------------------------------------------------------------------------------------------------- #
@@ -16,7 +17,7 @@ args.txt_change = args.change_dir + '/image'
 def change_dir(txt):
     with open(txt, 'r')as f:
         label = f.readlines()
-        label = [args.txt_change + _.split('image')[-1] for _ in label]
+        label = [args.txt_change + _.split('images')[-1] for _ in label]
     with open(txt, 'w')as f:
         f.writelines(label)
 
@@ -24,4 +25,5 @@ def change_dir(txt):
 if __name__ == '__main__':
     change_dir(args.train_txt)
     change_dir(args.val_txt)
+    change_dir(args.test_txt)
     print(f'| 已更改train.txt和val.txt中的图片根路径为:{args.change_dir} |')

+ 10 - 7
tool/generate_txt.py

@@ -10,7 +10,10 @@ def generate_txt_file(data_dir, subset, txt_filename):
     image_paths = []
     for filename in os.listdir(image_dir):
         if filename.endswith('.jpg') or filename.endswith('.png'):
-        # if filename.endswith('.txt'):
+            label_filename = filename.replace(".jpg", ".txt")
+            txt_path = os.path.join(label_dir, label_filename)
+            if not os.path.exists(txt_path):
+                continue
             image_path = os.path.join(image_dir, filename)
             image_paths.append(image_path)
     
@@ -32,8 +35,8 @@ def generate_class_txt(coco_dir, yaml_file):
             f.write(class_name + '\n')
 
 def main():
-    coco_dir = '/home/yhsun/ObjectDetection-main/datasets/coco'  # 替换为你的 COCO 数据集路径
-    yaml_file = 'coco.yaml'  # COCO YAML 文件名
+    coco_dir = '/mnt/d/WorkSpace/PyCharmGitWorkspace/ObjectDetectio-watermarking/datasets/coco128'  # 替换为你的 COCO 数据集路径
+    # yaml_file = 'coco.yaml'  # COCO YAML 文件名
 
     # 生成 train.txt
     generate_txt_file(coco_dir, 'train', 'train.txt')
@@ -44,12 +47,12 @@ def main():
     print("Processed val dataset")
 
     # 生成 test.txt
-    generate_txt_file(coco_dir, 'test', 'test.txt')
-    print("Processed test dataset")
+    # generate_txt_file(coco_dir, 'test', 'test.txt')
+    # print("Processed test dataset")
 
     # 生成 class.txt
-    generate_class_txt(coco_dir, yaml_file)
-    print("Processed class file")
+    # generate_class_txt(coco_dir, yaml_file)
+    # print("Processed class file")
 
     print("Finished processing COCO dataset")