From 8b21b43e939d51f923010df549a60c3db06c2051 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 28 Jun 2025 16:19:51 +0800 Subject: [PATCH] fix: check if the voc dataset folder exists before downloading. --- torchvision/datasets/voc.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 4d3e502d84e..67da7a95d3e 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -1,7 +1,7 @@ import collections import os from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Tuple, Union from xml.etree.ElementTree import Element as ET_Element try: @@ -64,6 +64,8 @@ class _VOCBase(VisionDataset): _SPLITS_DIR: str _TARGET_DIR: str _TARGET_FILE_EXT: str + _IMAGE_SET: str = "ImageSets" + _IMAGE_DIR: str = "JPEGImages" def __init__( self, @@ -95,24 +97,38 @@ def __init__( voc_root = os.path.join(self.root, base_dir) if download: - download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) + self._download(voc_root) - if not os.path.isdir(voc_root): + if not self._check_exists(voc_root): raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") - splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR) + splits_dir, image_dir, target_dir = self._voc_subfolders(voc_root) split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt") with open(os.path.join(split_f)) as f: file_names = [x.strip() for x in f.readlines()] - image_dir = os.path.join(voc_root, "JPEGImages") self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] - target_dir = os.path.join(voc_root, self._TARGET_DIR) self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names] assert len(self.images) == len(self.targets) + def _voc_subfolders(self, voc_root) -> Tuple[str, str, str]: + """Returns the subfolders for the VOC dataset.""" + splits_dir = os.path.join(voc_root, self._IMAGE_SET, self._SPLITS_DIR) + image_dir = os.path.join(voc_root, self._IMAGE_DIR) + target_dir = os.path.join(voc_root, self._TARGET_DIR) + return splits_dir, image_dir, target_dir + + def _download(self, voc_root: str) -> None: + if self._check_exists(voc_root): + return + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) + + def _check_exists(self, voc_root: str) -> bool: + """Check if the dataset exists.""" + return all(os.path.isdir(d) and len(os.listdir(d)) for d in self._voc_subfolders(voc_root)) + def __len__(self) -> int: return len(self.images)