diff --git a/cvpods/evaluation/build.py b/cvpods/evaluation/build.py index 0a978e7..649410d 100644 --- a/cvpods/evaluation/build.py +++ b/cvpods/evaluation/build.py @@ -50,7 +50,7 @@ def build_evaluator(cfg, dataset_name, dataset, output_folder=None, dump=False): ), "CityscapesEvaluator currently do not work with multiple machines." return EVALUATOR.get("CityscapesEvaluator")(dataset_name, meta, dump) elif evaluator_type == "pascal_voc": - return EVALUATOR.get("PascalVOCDetectionEvaluator")(dataset_name, meta, dump) + return EVALUATOR.get("PascalVOCDetectionEvaluator")(dataset_name, meta, output_folder, dump) elif evaluator_type == "lvis": return EVALUATOR.get("LVISEvaluator")(dataset_name, meta, cfg, True, output_folder, dump) elif evaluator_type == "citypersons": diff --git a/cvpods/evaluation/pascal_voc_evaluation.py b/cvpods/evaluation/pascal_voc_evaluation.py index cfa31b2..4848e7f 100644 --- a/cvpods/evaluation/pascal_voc_evaluation.py +++ b/cvpods/evaluation/pascal_voc_evaluation.py @@ -3,7 +3,6 @@ import logging import os -import tempfile import xml.etree.ElementTree as ET from collections import OrderedDict, defaultdict from functools import lru_cache @@ -12,7 +11,7 @@ import torch -from cvpods.utils import comm, create_small_table +from cvpods.utils import PathManager, comm, create_small_table from .evaluator import DatasetEvaluator from .registry import EVALUATOR @@ -29,7 +28,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator): the official API. """ - def __init__(self, dataset_name, meta, dump=False): + def __init__(self, dataset_name, meta, output_folder=None, dump=False): """ Args: dataset_name (str): name of the dataset, e.g., "voc_2007_test". @@ -39,6 +38,7 @@ def __init__(self, dataset_name, meta, dump=False): will be generated in the working directory. """ self._dump = dump + self._output_dir = output_folder self._dataset_name = dataset_name self._anno_file_template = os.path.join(meta.dirname, "Annotations", "{}.xml") self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt") @@ -88,14 +88,15 @@ def evaluate(self): ) ) - with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: - res_file_template = os.path.join(dirname, "{}.txt") + if self._output_dir is not None: + PathManager.mkdirs(self._output_dir) + res_file_template = os.path.join(self._output_dir, "{}.txt") aps = defaultdict(list) # iou -> ap per class for cls_id, cls_name in enumerate(self._class_names): lines = predictions.get(cls_id, [""]) - with open(res_file_template.format(cls_name), "w") as f: + with PathManager.open(res_file_template.format(cls_name), "w") as f: f.write("\n".join(lines)) for thresh in range(50, 100, 5):