From bf8dc0e0db49096565f72dfe5ecfce412516405d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 16 Nov 2025 19:15:23 +0545 Subject: [PATCH 1/5] fixes UP007 & UP045 --- .github/benchmark/benchmark.py | 7 +- benchmarks/litdata/optimize_imagenet.py | 3 +- examples/multi_modal/dataloader.py | 4 +- examples/multi_modal/loop.py | 4 +- src/litdata/helpers.py | 4 +- src/litdata/imports.py | 4 +- src/litdata/processing/data_processor.py | 112 ++++++++++----------- src/litdata/processing/functions.py | 108 ++++++++++---------- src/litdata/processing/utilities.py | 28 +++--- src/litdata/raw/dataset.py | 24 ++--- src/litdata/raw/indexer.py | 24 ++--- src/litdata/streaming/cache.py | 44 ++++---- src/litdata/streaming/client.py | 14 +-- src/litdata/streaming/combined.py | 20 ++-- src/litdata/streaming/config.py | 32 +++--- src/litdata/streaming/dataloader.py | 50 ++++----- src/litdata/streaming/dataset.py | 36 +++---- src/litdata/streaming/downloader.py | 18 ++-- src/litdata/streaming/fs_provider.py | 12 +-- src/litdata/streaming/item_loader.py | 26 ++--- src/litdata/streaming/parallel.py | 34 +++---- src/litdata/streaming/reader.py | 42 ++++---- src/litdata/streaming/resolver.py | 20 ++-- src/litdata/streaming/sampler.py | 24 ++--- src/litdata/streaming/serializers.py | 40 ++++---- src/litdata/streaming/writer.py | 48 ++++----- src/litdata/utilities/base.py | 14 +-- src/litdata/utilities/broadcast.py | 13 ++- src/litdata/utilities/dataset_utilities.py | 39 +++---- src/litdata/utilities/encryption.py | 8 +- src/litdata/utilities/env.py | 6 +- src/litdata/utilities/hf_dataset.py | 5 +- src/litdata/utilities/parquet.py | 30 +++--- src/litdata/utilities/subsample.py | 4 +- tests/streaming/test_dataset.py | 6 +- 35 files changed, 452 insertions(+), 455 deletions(-) diff --git a/.github/benchmark/benchmark.py b/.github/benchmark/benchmark.py index 10f1aa2a4..47fd36eff 100644 --- a/.github/benchmark/benchmark.py +++ b/.github/benchmark/benchmark.py @@ -2,7 +2,6 @@ import argparse from dataclasses import dataclass -from typing import Optional from lightning_sdk import Machine, Studio @@ -19,8 +18,8 @@ class BenchmarkArgs: pr_number: int branch: str - org: Optional[str] - user: Optional[str] + org: str | None + user: str | None teamspace: str machine: Machine make_args: str @@ -75,7 +74,7 @@ def __init__(self, config: BenchmarkArgs): self.org = config.org self.machine = config.machine self.make_args = config.make_args - self.studio: Optional[Studio] = None + self.studio: Studio | None = None def run(self) -> None: """Run the LitData benchmark.""" diff --git a/benchmarks/litdata/optimize_imagenet.py b/benchmarks/litdata/optimize_imagenet.py index a3e3f0ffa..73e3dd248 100644 --- a/benchmarks/litdata/optimize_imagenet.py +++ b/benchmarks/litdata/optimize_imagenet.py @@ -10,7 +10,6 @@ import os import time from functools import lru_cache, partial -from typing import Union import numpy as np import requests @@ -126,7 +125,7 @@ def main(): seed_everything(args.seed) # Handle resize_size: if two ints are given, treat as tuple, else int or None - resize_size: Union[int, tuple[int, int], None] = None + resize_size: int | tuple[int, int] | None = None if args.resize_size is not None: if isinstance(args.resize_size, list): if len(args.resize_size) == 1: diff --git a/examples/multi_modal/dataloader.py b/examples/multi_modal/dataloader.py index b6804e548..874f427ad 100644 --- a/examples/multi_modal/dataloader.py +++ b/examples/multi_modal/dataloader.py @@ -2,7 +2,7 @@ import logging import os -from typing import Any, Union 
+from typing import Any import joblib import lightning as pl @@ -44,7 +44,7 @@ def load_tokenizer(self): class DocumentClassificationDataset(StreamingDataset): """Streaming dataset class.""" - def __init__(self, input_dir: Union[str, Any], hyperparameters: Union[dict, Any] = None) -> None: + def __init__(self, input_dir: str | Any, hyperparameters: dict | Any = None) -> None: super().__init__(input_dir, shuffle=True, max_cache_size=hyperparameters["max_cache_size"]) self.hyperparameters = hyperparameters self.image_transform = transforms.Compose( diff --git a/examples/multi_modal/loop.py b/examples/multi_modal/loop.py index b82cc5ebc..bf3110b1e 100644 --- a/examples/multi_modal/loop.py +++ b/examples/multi_modal/loop.py @@ -3,7 +3,7 @@ import logging import os from collections.abc import Sequence -from typing import Any, Union +from typing import Any import lightning as pl import pandas as pd @@ -259,7 +259,7 @@ def configure_optimizers(self) -> Any: scheduler = StepLR(optimizer, step_size=1, gamma=0.1) return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] - def configure_callbacks(self) -> Union[Sequence[pl.pytorch.Callback], pl.pytorch.Callback]: + def configure_callbacks(self) -> Sequence[pl.pytorch.Callback] | pl.pytorch.Callback: """Configure Early stopping or Model Checkpointing.""" early_stop = EarlyStopping( monitor="val_MulticlassAccuracy", patience=self.hyperparameters["patience"], mode="max" diff --git a/src/litdata/helpers.py b/src/litdata/helpers.py index f9d795505..2c32a99c8 100644 --- a/src/litdata/helpers.py +++ b/src/litdata/helpers.py @@ -1,6 +1,6 @@ import functools import warnings -from typing import Any, Optional +from typing import Any import requests from packaging import version as packaging_version @@ -24,7 +24,7 @@ def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: @functools.lru_cache(maxsize=1) -def _get_newer_version(curr_version: str) -> Optional[str]: +def _get_newer_version(curr_version: str) -> str | None: """Check PyPI for newer versions of ``litdata``. Returning the newest version if different from the current or ``None`` otherwise. 
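Note on the pattern applied throughout this patch: the hunks are a mechanical application of Ruff rules UP045 (`Optional[X]` -> `X | None`) and UP007 (`Union[X, Y]` -> `X | Y`), i.e. PEP 604 union syntax. A minimal sketch of the rewrite, using a hypothetical `load` function rather than litdata code:

    # Before: flagged by UP045 (Optional) and UP007 (Union)
    from typing import Optional, Union

    def load(path: str, timeout: Optional[float] = None) -> Union[bytes, str]:
        ...

    # After: PEP 604 unions, as applied in the hunks below
    def load(path: str, timeout: float | None = None) -> bytes | str:
        ...

Unions containing a quoted forward reference (for example `Union["Machine", str]` in processing/functions.py or `Union[str, "Dir"]` in streaming/dataset.py) keep `typing.Union`, since an expression like `"Machine" | str` cannot be evaluated at runtime without `from __future__ import annotations`.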
diff --git a/src/litdata/imports.py b/src/litdata/imports.py index b4288ed93..f8913786c 100644 --- a/src/litdata/imports.py +++ b/src/litdata/imports.py @@ -14,7 +14,7 @@ import importlib from functools import lru_cache from importlib.util import find_spec -from typing import Optional, TypeVar +from typing import TypeVar import pkg_resources from typing_extensions import ParamSpec @@ -83,7 +83,7 @@ class RequirementCache: """ - def __init__(self, requirement: str, module: Optional[str] = None) -> None: + def __init__(self, requirement: str, module: str | None = None) -> None: self.requirement = requirement self.module = module diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 3c2b5da43..129a029b1 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -30,7 +30,7 @@ from pathlib import Path from queue import Empty from time import sleep, time -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar from urllib import parse import numpy as np @@ -82,7 +82,7 @@ def _get_default_cache() -> str: return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir() -def _get_cache_dir(name: Optional[str] = None) -> str: +def _get_cache_dir(name: str | None = None) -> str: """Returns the cache directory used by the Cache to store the chunks.""" cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", f"{_get_default_cache()}/chunks") if name is None: @@ -90,7 +90,7 @@ def _get_cache_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _get_cache_data_dir(name: Optional[str] = None) -> str: +def _get_cache_data_dir(name: str | None = None) -> str: """Returns the cache data directory used by the DataProcessor workers to download the files.""" cache_dir = os.getenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", f"{_get_default_cache()}/data") if name is None: @@ -132,7 +132,7 @@ def _download_data_target( while True: # 2. Fetch from the queue - r: Optional[tuple[int, Any, list[str]]] = queue_in.get() + r: tuple[int, Any, list[str]] | None = queue_in.get() # 3. Terminate the process if we received a termination signal if r is None: @@ -239,7 +239,7 @@ def _upload_fn( fs_provider = _get_fs_provider(output_dir.url, merged_storage_options) while True: - data: Optional[Union[str, tuple[str, str]]] = upload_queue.get() + data: str | tuple[str, str] | None = upload_queue.get() tmpdir = None @@ -340,7 +340,7 @@ def _map_items_to_workers_sequentially(num_workers: int, user_items: list[Any]) def _map_items_to_workers_weighted( num_workers: int, user_items: list[Any], - weights: Optional[list[int]] = None, + weights: list[int] | None = None, file_size: bool = True, ) -> list[list[Any]]: """Map the items to the workers based on the weights. 
@@ -403,7 +403,7 @@ def _to_path(element: str) -> str: return element if _IS_IN_STUDIO and element.startswith("/teamspace") else str(Path(element).resolve()) -def _is_path(input_dir: Optional[str], element: Any) -> bool: +def _is_path(input_dir: str | None, element: Any) -> bool: if not isinstance(element, str): return False @@ -467,22 +467,22 @@ def __init__( data_recipe: "DataRecipe", input_dir: Dir, output_dir: Dir, - items: Optional[list[Any]], + items: list[Any] | None, progress_queue: Queue, error_queue: Queue, stop_queue: Queue, num_downloaders: int, num_uploaders: int, remove: bool, - reader: Optional[BaseReader] = None, + reader: BaseReader | None = None, writer_starting_chunk_index: int = 0, use_checkpoint: bool = False, - checkpoint_chunks_info: Optional[list[dict[str, Any]]] = None, - checkpoint_next_index: Optional[int] = None, - item_loader: Optional[BaseItemLoader] = None, + checkpoint_chunks_info: list[dict[str, Any]] | None = None, + checkpoint_next_index: int | None = None, + item_loader: BaseItemLoader | None = None, storage_options: dict[str, Any] = {}, keep_data_ordered: bool = True, - shared_queue: Union[Queue, FakeQueue, None] = None, + shared_queue: Queue | FakeQueue | None = None, using_queue_optimize: bool = False, # using queues as inputs for optimize fn ) -> None: """The BaseWorker is responsible to process the user data.""" @@ -500,7 +500,7 @@ def __init__( self.remove = remove self.reader = reader self.paths: list[list[str]] = [] - self.remover: Optional[Process] = None + self.remover: Process | None = None self.downloaders: list[Process] = [] self.uploaders: list[Process] = [] self.to_download_queues: list[Queue] = [] @@ -514,7 +514,7 @@ def __init__( assert shared_queue is not None self.ready_to_process_queue = shared_queue else: - self.ready_to_process_queue: Union[Queue, FakeQueue] = FakeQueue() if self.no_downloaders else Queue() + self.ready_to_process_queue: Queue | FakeQueue = FakeQueue() if self.no_downloaders else Queue() self.remove_queue: Queue = Queue() self.progress_queue: Queue = progress_queue @@ -525,8 +525,8 @@ def __init__( self._index_counter = 0 self.writer_starting_chunk_index: int = writer_starting_chunk_index self.use_checkpoint: bool = use_checkpoint - self.checkpoint_chunks_info: Optional[list[dict[str, Any]]] = checkpoint_chunks_info - self.checkpoint_next_index: Optional[int] = checkpoint_next_index + self.checkpoint_chunks_info: list[dict[str, Any]] | None = checkpoint_chunks_info + self.checkpoint_next_index: int | None = checkpoint_next_index self.storage_options = storage_options self.using_queue_optimize = using_queue_optimize @@ -692,7 +692,7 @@ def _create_cache(self) -> None: self.cache._writer._chunks_info = self.checkpoint_chunks_info self.cache._writer._chunk_index += self.checkpoint_next_index - def _try_upload(self, data: Optional[Union[str, tuple[str, str]]]) -> None: + def _try_upload(self, data: str | tuple[str, str] | None) -> None: if not data or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None: return @@ -896,13 +896,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @dataclass class _Result: - size: Optional[int] = None - num_bytes: Optional[str] = None - data_format: Optional[str] = None - compression: Optional[str] = None - encryption: Optional[Encryption] = None - num_chunks: Optional[int] = None - num_bytes_per_chunk: Optional[list[int]] = None + size: int | None = None + num_bytes: str | None = None + data_format: str | None = None + compression: str | None = None + 
encryption: Encryption | None = None + num_chunks: int | None = None + num_bytes_per_chunk: list[int] | None = None T = TypeVar("T") @@ -916,7 +916,7 @@ class DataRecipe: """ @abstractmethod - def prepare_structure(self, input_dir: Optional[str]) -> list[T]: + def prepare_structure(self, input_dir: str | None) -> list[T]: """Prepare the structure of the data. This is the structure of the data that will be used by the worker. (inputs) @@ -934,20 +934,20 @@ def prepare_item(self, *args: Any, **kwargs: Any) -> Any: pass def __init__(self, storage_options: dict[str, Any] = {}) -> None: - self._name: Optional[str] = None + self._name: str | None = None self.storage_options = storage_options - def _done(self, size: Optional[int], delete_cached_files: bool, output_dir: Dir) -> _Result: + def _done(self, size: int | None, delete_cached_files: bool, output_dir: Dir) -> _Result: return _Result(size=size) class DataChunkRecipe(DataRecipe): def __init__( self, - chunk_size: Optional[int] = None, - chunk_bytes: Optional[Union[int, str]] = None, - compression: Optional[str] = None, - encryption: Optional[Encryption] = None, + chunk_size: int | None = None, + chunk_bytes: int | str | None = None, + compression: str | None = None, + encryption: Encryption | None = None, storage_options: dict[str, Any] = {}, ): super().__init__(storage_options) @@ -960,7 +960,7 @@ def __init__( self.encryption = encryption @abstractmethod - def prepare_structure(self, input_dir: Optional[str]) -> list[T]: + def prepare_structure(self, input_dir: str | None) -> list[T]: """Return the structure of your data. Each element should contain at least a filepath. @@ -971,7 +971,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> list[T]: def prepare_item(self, item_metadata: T) -> Any: """Returns `prepare_item` method is persisted in chunked binary files.""" - def _done(self, size: Optional[int], delete_cached_files: bool, output_dir: Dir) -> _Result: + def _done(self, size: int | None, delete_cached_files: bool, output_dir: Dir) -> _Result: num_nodes = _get_num_nodes() cache_dir = _get_cache_dir() @@ -1013,7 +1013,7 @@ def _done(self, size: Optional[int], delete_cached_files: bool, output_dir: Dir) size=size, ) - def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None: + def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: int | None) -> None: """Upload the index file to the remote cloud directory.""" if output_dir.path is None and output_dir.url is None: return @@ -1062,7 +1062,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra class MapRecipe(DataRecipe): @abstractmethod - def prepare_structure(self, input_dir: Optional[str]) -> list[T]: + def prepare_structure(self, input_dir: str | None) -> list[T]: """Return the structure of your data. Each element should contain at least a filepath. 
@@ -1077,21 +1077,21 @@ def prepare_item(self, item_metadata: T, output_dir: str, is_last: bool) -> None class DataProcessor: def __init__( self, - input_dir: Union[str, Dir], - output_dir: Optional[Union[str, Dir]] = None, - num_workers: Optional[int] = None, - num_downloaders: Optional[int] = None, - num_uploaders: Optional[int] = None, + input_dir: str | Dir, + output_dir: str | Dir | None = None, + num_workers: int | None = None, + num_downloaders: int | None = None, + num_uploaders: int | None = None, delete_cached_files: bool = True, - fast_dev_run: Optional[Union[bool, int]] = None, - random_seed: Optional[int] = 42, + fast_dev_run: bool | int | None = None, + random_seed: int | None = 42, reorder_files: bool = True, - weights: Optional[list[int]] = None, - reader: Optional[BaseReader] = None, - state_dict: Optional[dict[int, int]] = None, + weights: list[int] | None = None, + reader: BaseReader | None = None, + state_dict: dict[int, int] | None = None, use_checkpoint: bool = False, - item_loader: Optional[BaseItemLoader] = None, - start_method: Optional[str] = None, + item_loader: BaseItemLoader | None = None, + start_method: str | None = None, storage_options: dict[str, Any] = {}, keep_data_ordered: bool = True, verbose: bool = True, @@ -1146,19 +1146,19 @@ def __init__( self.fast_dev_run = _get_fast_dev_run() if fast_dev_run is None else fast_dev_run self.workers: Any = [] self.workers_tracker: dict[int, int] = {} - self.progress_queue: Optional[Queue] = None + self.progress_queue: Queue | None = None self.error_queue: Queue = Queue() self.stop_queues: list[Queue] = [] self.reorder_files = reorder_files self.weights = weights self.reader = reader self.use_checkpoint = use_checkpoint - self.checkpoint_chunks_info: Optional[list[list[dict[str, Any]]]] = None - self.checkpoint_next_index: Optional[list[int]] = None + self.checkpoint_chunks_info: list[list[dict[str, Any]]] | None = None + self.checkpoint_next_index: list[int] | None = None self.item_loader = item_loader self.storage_options = storage_options self.keep_data_ordered = keep_data_ordered - self.shared_queue: Union[Queue, FakeQueue, None] = None + self.shared_queue: Queue | FakeQueue | None = None # Queue for routing worker logs to the main process without breaking tqdm output. 
self.msg_queue: Queue = Queue() @@ -1200,7 +1200,7 @@ def run(self, data_recipe: DataRecipe) -> None: torch.manual_seed(self.random_seed) # Call the setup method of the user - user_items: Union[list[Any], StreamingDataLoader, Queue] = data_recipe.prepare_structure( + user_items: list[Any] | StreamingDataLoader | Queue = data_recipe.prepare_structure( self.input_dir.path if self.input_dir else None ) if not isinstance(user_items, (list, StreamingDataLoader, multiprocessing.queues.Queue)): @@ -1212,7 +1212,7 @@ def run(self, data_recipe: DataRecipe) -> None: if self.reader: user_items = self.reader.remap_items(user_items, self.num_workers) - workers_user_items: Optional[list[list[int]]] = None + workers_user_items: list[list[int]] | None = None if isinstance(user_items, list): assert isinstance(user_items, list) @@ -1413,7 +1413,7 @@ def _exit_on_error(self, error: str) -> None: raise RuntimeError(f"We found the following error {error}.") def _create_process_workers( - self, data_recipe: DataRecipe, workers_user_items: Optional[list[list[Any]]] = None + self, data_recipe: DataRecipe, workers_user_items: list[list[Any]] | None = None ) -> None: if not self.keep_data_ordered and workers_user_items is not None: self.shared_queue = Queue() @@ -1648,7 +1648,7 @@ def in_notebook() -> bool: return "ipykernel" in sys.modules -def flush_msg_queue(msg_queue: Queue, pbar: Optional[Any] = None): +def flush_msg_queue(msg_queue: Queue, pbar: Any | None = None): """Flush messages from a queue and print them without breaking the tqdm progress bar. This function drains all available messages from the given queue and prints them. diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index d66e0da5a..bf3ddbe1e 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -24,7 +24,7 @@ from functools import partial from pathlib import Path from types import FunctionType -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Union from urllib import parse import torch @@ -73,7 +73,7 @@ def _get_indexed_paths(data: Any) -> dict[int, str]: } -def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]: +def _get_input_dir(inputs: Sequence[Any]) -> str | None: indexed_paths = _get_indexed_paths(inputs[0]) if len(indexed_paths) == 0: @@ -111,20 +111,20 @@ class LambdaMapRecipe(MapRecipe): def __init__( self, fn: Callable[[str, Any], None], - inputs: Union[Sequence[Any], StreamingDataLoader], + inputs: Sequence[Any] | StreamingDataLoader, storage_options: dict[str, Any] = {}, ): super().__init__(storage_options) self._fn = fn self._inputs = inputs - self._device: Optional[str] = None + self._device: str | None = None _fn = self._fn if isinstance(self._fn, FunctionType) else self._fn.__call__ # type: ignore params = inspect.signature(_fn).parameters self._contains_device = "device" in params self._contains_is_last = "is_last" in params - def prepare_structure(self, _: Optional[str]) -> Any: + def prepare_structure(self, _: str | None) -> Any: return self._inputs def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool) -> None: @@ -161,12 +161,12 @@ class LambdaDataChunkRecipe(DataChunkRecipe): def __init__( self, fn: Callable[[Any], None], - inputs: Union[Sequence[Any], StreamingDataLoader], - chunk_size: Optional[int], - chunk_bytes: Optional[Union[int, str]], - compression: Optional[str], - encryption: Optional[Encryption] = None, - existing_index: 
Optional[dict[str, Any]] = None, + inputs: Sequence[Any] | StreamingDataLoader, + chunk_size: int | None, + chunk_bytes: int | str | None, + compression: str | None, + encryption: Encryption | None = None, + existing_index: dict[str, Any] | None = None, storage_options: dict[str, Any] = {}, ): super().__init__( @@ -199,7 +199,7 @@ def _prepare_item(self, item_metadata: Any) -> Any: def _prepare_item_generator(self, item_metadata: Any) -> Any: yield from self._fn(item_metadata) # type: ignore - def prepare_structure(self, input_dir: Optional[str]) -> Any: + def prepare_structure(self, input_dir: str | None) -> Any: return self._inputs def prepare_item(self, item_metadata: Any) -> Any: @@ -213,11 +213,11 @@ def __init__( self, fn: Callable[[Any], None], queue: mp.Queue, - chunk_size: Optional[int], - chunk_bytes: Optional[Union[int, str]], - compression: Optional[str], - encryption: Optional[Encryption] = None, - existing_index: Optional[dict[str, Any]] = None, + chunk_size: int | None, + chunk_bytes: int | str | None, + compression: str | None, + encryption: Encryption | None = None, + existing_index: dict[str, Any] | None = None, storage_options: dict[str, Any] = {}, ): super().__init__( @@ -232,7 +232,7 @@ def __init__( self.existing_index = existing_index self.is_generator = False - def prepare_structure(self, input_dir: Optional[str]) -> Any: + def prepare_structure(self, input_dir: str | None) -> Any: return self._queue def prepare_item(self, item_metadata: Any) -> Any: @@ -241,22 +241,22 @@ def prepare_item(self, item_metadata: Any) -> Any: def map( fn: Callable[[str, Any], None], - inputs: Union[Sequence[Any], StreamingDataLoader], - output_dir: Union[str, Path, Dir], - input_dir: Optional[Union[str, Path]] = None, - weights: Optional[list[int]] = None, - num_workers: Optional[int] = None, - fast_dev_run: Union[bool, int] = False, - num_nodes: Optional[int] = None, - machine: Optional[Union["Machine", str]] = None, - num_downloaders: Optional[int] = None, - num_uploaders: Optional[int] = None, + inputs: Sequence[Any] | StreamingDataLoader, + output_dir: str | Path | Dir, + input_dir: str | Path | None = None, + weights: list[int] | None = None, + num_workers: int | None = None, + fast_dev_run: bool | int = False, + num_nodes: int | None = None, + machine: Union["Machine", str] | None = None, + num_downloaders: int | None = None, + num_uploaders: int | None = None, reorder_files: bool = True, error_when_not_empty: bool = False, - reader: Optional[BaseReader] = None, - batch_size: Optional[int] = None, - start_method: Optional[str] = None, - optimize_dns: Optional[bool] = None, + reader: BaseReader | None = None, + batch_size: int | None = None, + start_method: str | None = None, + optimize_dns: bool | None = None, storage_options: dict[str, Any] = {}, keep_data_ordered: bool = True, ) -> None: @@ -386,29 +386,29 @@ def map( # def optimize( fn: Callable[[Any], Any], - inputs: Optional[Union[Sequence[Any], StreamingDataLoader]] = None, + inputs: Sequence[Any] | StreamingDataLoader | None = None, output_dir: str = "optimized_data", - queue: Optional[mp.Queue] = None, - input_dir: Optional[str] = None, - weights: Optional[list[int]] = None, - chunk_size: Optional[int] = None, - chunk_bytes: Optional[Union[int, str]] = None, - compression: Optional[str] = None, - encryption: Optional[Encryption] = None, - num_workers: Optional[int] = None, + queue: mp.Queue | None = None, + input_dir: str | None = None, + weights: list[int] | None = None, + chunk_size: int | None = None, + chunk_bytes: int 
| str | None = None, + compression: str | None = None, + encryption: Encryption | None = None, + num_workers: int | None = None, fast_dev_run: bool = False, - num_nodes: Optional[int] = None, - machine: Optional[Union["Machine", str]] = None, - num_downloaders: Optional[int] = None, - num_uploaders: Optional[int] = None, + num_nodes: int | None = None, + machine: Union["Machine", str] | None = None, + num_downloaders: int | None = None, + num_uploaders: int | None = None, reorder_files: bool = True, - reader: Optional[BaseReader] = None, - batch_size: Optional[int] = None, - mode: Optional[Literal["append", "overwrite"]] = None, + reader: BaseReader | None = None, + batch_size: int | None = None, + mode: Literal["append", "overwrite"] | None = None, use_checkpoint: bool = False, - item_loader: Optional[BaseItemLoader] = None, - start_method: Optional[str] = None, - optimize_dns: Optional[bool] = None, + item_loader: BaseItemLoader | None = None, + start_method: str | None = None, + optimize_dns: bool | None = None, storage_options: dict[str, Any] = {}, keep_data_ordered: bool = True, verbose: bool = True, @@ -570,7 +570,7 @@ def optimize( ) with optimize_dns_context(optimize_dns if optimize_dns is not None else False): - recipe: Optional[Union[LambdaDataChunkRecipe, QueueDataChunkRecipe]] = None + recipe: LambdaDataChunkRecipe | QueueDataChunkRecipe | None = None if queue is None: assert isinstance(inputs, (Sequence, StreamingDataLoader)) recipe = LambdaDataChunkRecipe( @@ -616,7 +616,7 @@ class walk: """ - def __init__(self, folder: str, max_workers: Optional[int] = os.cpu_count()) -> None: + def __init__(self, folder: str, max_workers: int | None = os.cpu_count()) -> None: self.folders = [folder] self.max_workers = max_workers or 1 self.futures: list[concurrent.futures.Future] = [] @@ -666,7 +666,7 @@ class CopyInfo: def merge_datasets( input_dirs: list[str], output_dir: str, - max_workers: Optional[int] = os.cpu_count(), + max_workers: int | None = os.cpu_count(), storage_options: dict[str, Any] = {}, ) -> None: """Enables to merge multiple existing optimized datasets into a single optimized dataset. diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 504310bc3..0a6e18e19 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -18,7 +18,7 @@ import urllib from contextlib import contextmanager from subprocess import DEVNULL, Popen -from typing import Any, Callable, Optional, Union +from typing import Any, Callable from urllib import parse from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_PROVIDERS @@ -28,18 +28,18 @@ #! TODO: Not sure what this function is used for. 
def _create_dataset( - input_dir: Optional[str], + input_dir: str | None, storage_dir: str, dataset_type: Any, - empty: Optional[bool] = None, - size: Optional[int] = None, - num_bytes: Optional[str] = None, - data_format: Optional[Union[str, tuple[str]]] = None, - compression: Optional[str] = None, - num_chunks: Optional[int] = None, - num_bytes_per_chunk: Optional[list[int]] = None, - name: Optional[str] = None, - version: Optional[int] = None, + empty: bool | None = None, + size: int | None = None, + num_bytes: str | None = None, + data_format: str | tuple[str] | None = None, + compression: str | None = None, + num_chunks: int | None = None, + num_bytes_per_chunk: list[int] | None = None, + name: str | None = None, + version: int | None = None, ) -> None: """Create a dataset with metadata information about its source and destination using the Lightning SDK. @@ -95,13 +95,13 @@ def _create_dataset( raise ex -def get_worker_rank() -> Optional[str]: +def get_worker_rank() -> str | None: return os.getenv("DATA_OPTIMIZER_GLOBAL_RANK") #! TODO: Do we still need this? It is not used anywhere. def catch(func: Callable) -> Callable: - def _wrapper(*args: Any, **kwargs: Any) -> tuple[Any, Optional[Exception]]: + def _wrapper(*args: Any, **kwargs: Any) -> tuple[Any, Exception | None]: try: return func(*args, **kwargs), None except Exception as e: @@ -200,7 +200,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir, storage_options: dict[str, Any] = {}) -> Optional[dict[str, Any]]: +def read_index_file_content(output_dir: Dir, storage_options: dict[str, Any] = {}) -> dict[str, Any] | None: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") diff --git a/src/litdata/raw/dataset.py b/src/litdata/raw/dataset.py index 2700a9f05..9eebe9141 100644 --- a/src/litdata/raw/dataset.py +++ b/src/litdata/raw/dataset.py @@ -16,7 +16,7 @@ import os from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable from torch.utils.data import Dataset @@ -33,9 +33,9 @@ class CacheManager: def __init__( self, - input_dir: Union[str, Dir], - cache_dir: Optional[str] = None, - storage_options: Optional[dict] = None, + input_dir: str | Dir, + cache_dir: str | None = None, + storage_options: dict | None = None, cache_files: bool = False, ): self.input_dir = _resolve_dir(input_dir) @@ -45,7 +45,7 @@ def __init__( self.cache_dir = self._create_cache_dir(self._input_dir_path, cache_dir) self.storage_options = storage_options or {} - self._downloader: Optional[Downloader] = None + self._downloader: Downloader | None = None @property def downloader(self) -> Downloader: @@ -59,7 +59,7 @@ def downloader(self) -> Downloader: ) return self._downloader - def _create_cache_dir(self, input_dir: str, cache_dir: Optional[str] = None) -> str: + def _create_cache_dir(self, input_dir: str, cache_dir: str | None = None) -> str: """Create cache directory if it doesn't exist.""" if cache_dir is None: cache_dir = get_default_cache_dir() @@ -104,12 +104,12 @@ class StreamingRawDataset(Dataset): def __init__( self, input_dir: str, - cache_dir: Optional[str] = None, - indexer: Optional[BaseIndexer] = None, - storage_options: Optional[dict] = None, + cache_dir: str | None = None, + indexer: BaseIndexer | None = None, + storage_options: dict | None = None, 
cache_files: bool = False, recompute_index: bool = False, - transform: Optional[Callable[[Union[bytes, list[bytes]]], Any]] = None, + transform: Callable[[bytes | list[bytes]], Any] | None = None, ): """Initialize StreamingRawDataset. @@ -142,12 +142,12 @@ def __init__( logger.info(f"Discovered {len(self.files)} files.") # Transform the flat list of files into the desired item structure. - self.items: Union[list[FileMetadata], list[list[FileMetadata]]] = self.setup(self.files) + self.items: list[FileMetadata] | list[list[FileMetadata]] = self.setup(self.files) if not isinstance(self.items, list): raise TypeError(f"The setup method must return a list, but returned {type(self.items)}") logger.info(f"Dataset setup with {len(self.items)} items.") - def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]: + def setup(self, files: list[FileMetadata]) -> list[FileMetadata] | list[list[FileMetadata]]: """Define the structure of the dataset from the list of discovered files. Override this method in a subclass to group or filter files into final dataset items. diff --git a/src/litdata/raw/indexer.py b/src/litdata/raw/indexer.py index 64ee0c344..4d8374381 100644 --- a/src/litdata/raw/indexer.py +++ b/src/litdata/raw/indexer.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from litdata.constants import _FSSPEC_AVAILABLE, _PYTHON_GREATER_EQUAL_3_14, _TQDM_AVAILABLE, _ZSTD_AVAILABLE @@ -47,14 +47,14 @@ class BaseIndexer(ABC): """Abstract base class for file indexing strategies.""" @abstractmethod - def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]: + def discover_files(self, input_dir: str, storage_options: dict[str, Any] | None) -> list[FileMetadata]: """Discover dataset files and return their metadata.""" def build_or_load_index( self, input_dir: str, cache_dir: str, - storage_options: Optional[dict[str, Any]], + storage_options: dict[str, Any] | None, recompute_index: bool = False, ) -> list[FileMetadata]: """Loads or builds a ZSTD-compressed index of dataset file metadata. @@ -95,8 +95,8 @@ def build_or_load_index( return self._build_and_cache_index(input_dir, cache_dir, storage_options) def _load_index_from_cache( - self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]] - ) -> Optional[list[FileMetadata]]: + self, input_dir: str, cache_dir: str, storage_options: dict[str, Any] | None + ) -> list[FileMetadata] | None: """Tries to load the index from local or remote cache.""" # 1. Try to load index from local cache. 
local_index_path = Path(cache_dir) / _INDEX_FILENAME @@ -123,7 +123,7 @@ def _load_index_from_cache( return None def _build_and_cache_index( - self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]] + self, input_dir: str, cache_dir: str, storage_options: dict[str, Any] | None ) -> list[FileMetadata]: """Builds a new index and caches it locally and remotely.""" local_index_path = Path(cache_dir) / _INDEX_FILENAME @@ -145,7 +145,7 @@ def _build_and_cache_index( logger.info(f"Built index with {len(files)} files from {input_dir} at {local_index_path}") return files - def _load_index_file(self, index_path: str) -> Optional[list[FileMetadata]]: + def _load_index_file(self, index_path: str) -> list[FileMetadata] | None: """Loads and decodes an index file.""" if _PYTHON_GREATER_EQUAL_3_14: from compression import zstd @@ -183,7 +183,7 @@ def _download_from_cloud( self, remote_path: str, local_path: str, - storage_options: Optional[dict[str, Any]], + storage_options: dict[str, Any] | None, ) -> None: """Downloads a file from cloud storage.""" if not _FSSPEC_AVAILABLE: @@ -198,7 +198,7 @@ def _upload_to_cloud( self, local_path: str, remote_path: str, - storage_options: Optional[dict[str, Any]], + storage_options: dict[str, Any] | None, ) -> None: """Uploads a file to cloud storage.""" if not _FSSPEC_AVAILABLE: @@ -216,12 +216,12 @@ class FileIndexer(BaseIndexer): def __init__( self, max_depth: int = 5, - extensions: Optional[list[str]] = None, + extensions: list[str] | None = None, ): self.max_depth = max_depth self.extensions = [ext.lower() for ext in (extensions or [])] - def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]: + def discover_files(self, input_dir: str, storage_options: dict[str, Any] | None) -> list[FileMetadata]: """Discover dataset files and return their metadata.""" parsed_url = urlparse(input_dir) if parsed_url.scheme and parsed_url.scheme not in _SUPPORTED_PROVIDERS: @@ -236,7 +236,7 @@ def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any # Local filesystem return self._discover_local_files(input_dir) - def _discover_cloud_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]: + def _discover_cloud_files(self, input_dir: str, storage_options: dict[str, Any] | None) -> list[FileMetadata]: """Recursively list files in a cloud storage bucket.""" import fsspec diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 3f5621d5d..982a430f2 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -14,7 +14,7 @@ import logging import os from multiprocessing import Queue -from typing import Any, Optional, Union +from typing import Any from litdata.constants import ( _INDEX_FILENAME, @@ -35,21 +35,21 @@ class Cache: def __init__( self, - input_dir: Optional[Union[str, Dir]], - subsampled_files: Optional[list[str]] = None, - region_of_interest: Optional[list[tuple[int, int]]] = None, - compression: Optional[str] = None, - encryption: Optional[Encryption] = None, - chunk_size: Optional[int] = None, - chunk_bytes: Optional[Union[int, str]] = None, - item_loader: Optional[BaseItemLoader] = None, - max_cache_size: Union[int, str] = "100GB", - serializers: Optional[dict[str, Serializer]] = None, - writer_chunk_index: Optional[int] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + input_dir: str | Dir | None, + subsampled_files: list[str] | None = None, + 
region_of_interest: list[tuple[int, int]] | None = None, + compression: str | None = None, + encryption: Encryption | None = None, + chunk_size: int | None = None, + chunk_bytes: int | str | None = None, + item_loader: BaseItemLoader | None = None, + max_cache_size: int | str = "100GB", + serializers: dict[str, Serializer] | None = None, + writer_chunk_index: int | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, max_pre_download: int = 2, - msg_queue: Optional[Queue] = None, + msg_queue: Queue | None = None, on_demand_bytes: bool = False, ): """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements @@ -106,7 +106,7 @@ def __init__( ) self._is_done = False self._distributed_env = _DistributedEnv.detect() - self._rank: Optional[int] = None + self._rank: int | None = None @property def rank(self) -> int: @@ -144,25 +144,25 @@ def __setitem__(self, index: int, data: Any) -> None: """Store an item in the writer.""" self._writer[index] = data - def _add_item(self, index: int, data: Any) -> Optional[str]: + def _add_item(self, index: int, data: Any) -> str | None: """Store an item in the writer and optionally return the chunk path.""" return self._writer.add_item(index, data) - def __getitem__(self, index: Union[int, ChunkedIndex]) -> dict[str, Any]: + def __getitem__(self, index: int | ChunkedIndex) -> dict[str, Any]: """Read an item in the reader.""" if isinstance(index, int): index = ChunkedIndex(*self._get_chunk_index_from_index(index)) return self._reader.read(index) - def done(self) -> Optional[list[str]]: + def done(self) -> list[str] | None: """Inform the writer the chunking phase is finished.""" return self._writer.done() - def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None: + def merge(self, num_workers: int = 1, node_rank: int | None = None) -> None: """Inform the writer the chunking phase is finished.""" self._writer.merge(num_workers, node_rank=node_rank) - def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Optional[dict[str, Any]] = None) -> None: + def _merge_no_wait(self, node_rank: int | None = None, existing_index: dict[str, Any] | None = None) -> None: """Inform the writer the chunking phase is finished.""" self._writer._merge_no_wait(node_rank=node_rank, existing_index=existing_index) @@ -175,6 +175,6 @@ def get_chunk_intervals(self) -> list[Interval]: def _get_chunk_index_from_index(self, index: int) -> tuple[int, int]: return self._reader._get_chunk_index_from_index(index) - def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]: + def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> str | None: """Save the current state of the writer to a checkpoint.""" return self._writer.save_checkpoint(checkpoint_dir=checkpoint_dir) diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py index 26f43fbf1..7e475f8e1 100644 --- a/src/litdata/streaming/client.py +++ b/src/litdata/streaming/client.py @@ -14,7 +14,7 @@ import json import os from time import time -from typing import Any, Optional +from typing import Any import boto3 import botocore @@ -51,12 +51,12 @@ class S3Client: def __init__( self, refetch_interval: int = 3300, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + storage_options: dict | None = {}, + session_options: dict | None = {}, ) -> None: self._refetch_interval = refetch_interval - self._last_time: Optional[float] = None - self._client: 
Optional[Any] = None + self._last_time: float | None = None + self._client: Any | None = None self._storage_options: dict = storage_options or {} self._session_options: dict = session_options or {} @@ -106,8 +106,8 @@ class R2Client(S3Client): def __init__( self, refetch_interval: int = 3600, # 1 hour - this is the default refresh interval for R2 credentials - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + storage_options: dict | None = {}, + session_options: dict | None = {}, ) -> None: # Store R2-specific options before calling super() self._base_storage_options: dict = storage_options or {} diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index de4928b97..cb9e2d18f 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -15,7 +15,7 @@ import random from collections.abc import Iterator, Sequence from copy import deepcopy -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.dataset import StreamingDataset @@ -53,7 +53,7 @@ def __init__( self, datasets: list[StreamingDataset], seed: int = 42, - weights: Optional[Sequence[float]] = None, + weights: Sequence[float] | None = None, iterate_over_all: bool = True, batching_method: BatchingMethodType = "stratified", force_override_state_dict: bool = False, @@ -98,22 +98,22 @@ def __init__( weights_sum = sum(weights) self._weights = [w / weights_sum for w in weights] - self._iterator: Optional[_CombinedDatasetIterator] = None + self._iterator: _CombinedDatasetIterator | None = None self._use_streaming_dataloader = False - self._num_samples_yielded: Optional[dict[int, list[int]]] = None + self._num_samples_yielded: dict[int, list[int]] | None = None self._current_epoch = 0 self.num_workers = 1 self.batch_size = 1 self._batching_method: BatchingMethodType = batching_method - def get_len(self, num_workers: int, batch_size: int) -> Optional[int]: + def get_len(self, num_workers: int, batch_size: int) -> int | None: self.num_workers = num_workers self.batch_size = batch_size if self._iterate_over_all: return self._get_total_length() return None - def __len__(self) -> Optional[int]: + def __len__(self) -> int | None: return self.get_len(1, 1) # total length of the datasets @@ -153,7 +153,7 @@ def __iter__(self) -> Iterator[Any]: return self._iterator def state_dict( - self, num_workers: int, batch_size: int, num_samples_yielded: Optional[list[int]] = None + self, num_workers: int, batch_size: int, num_samples_yielded: list[int] | None = None ) -> dict[str, Any]: if self._iterator is None: if num_samples_yielded is None: @@ -167,16 +167,16 @@ def __init__( self, datasets: list[StreamingDataset], seed: int, - weights: Sequence[Optional[float]], + weights: Sequence[float | None], use_streaming_dataloader: bool, num_samples_yielded: Any, - batch_size: Union[int, Sequence[int]], + batch_size: int | Sequence[int], batching_method: BatchingMethodType, iterate_over_all: bool = False, ) -> None: self._datasets = datasets self._dataset_iters = [iter(dataset) for dataset in datasets] - self._dataset_indexes: list[Optional[int]] = list(range(len(datasets))) + self._dataset_indexes: list[int | None] = list(range(len(datasets))) self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))] self._original_weights = deepcopy(weights) self._weights = deepcopy(weights) diff --git a/src/litdata/streaming/config.py 
b/src/litdata/streaming/config.py index 5ef1cc88e..7dcb10c4b 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -35,12 +35,12 @@ def __init__( self, cache_dir: str, serializers: dict[str, Serializer], - remote_dir: Optional[str], - item_loader: Optional[BaseItemLoader] = None, - subsampled_files: Optional[list[str]] = None, - region_of_interest: Optional[list[tuple[int, int]]] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + remote_dir: str | None, + item_loader: BaseItemLoader | None = None, + subsampled_files: list[str] | None = None, + region_of_interest: list[tuple[int, int]] | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, ) -> None: """Reads the index files associated a chunked dataset and enables to map an index to its chunk. @@ -93,7 +93,7 @@ def __init__( ) self._compressor_name = self._config["compression"] - self._compressor: Optional[Compressor] = None + self._compressor: Compressor | None = None if self._compressor_name: if len(_COMPRESSORS) == 0: @@ -106,8 +106,8 @@ def __init__( ) self._compressor = _COMPRESSORS[self._compressor_name] - self._skip_chunk_indexes_deletion: Optional[list[int]] = None - self.zero_based_roi: Optional[list[tuple[int, int]]] = None + self._skip_chunk_indexes_deletion: list[int] | None = None + self.zero_based_roi: list[tuple[int, int]] | None = None self.filename_to_size_map: dict[str, int] = {} for cnk in _original_chunks: # since files downloaded while reading will be decompressed, we need to store the name without compression @@ -120,7 +120,7 @@ def can_delete(self, chunk_index: int) -> bool: return chunk_index not in self._skip_chunk_indexes_deletion @property - def skip_chunk_indexes_deletion(self) -> Optional[list[int]]: + def skip_chunk_indexes_deletion(self) -> list[int] | None: return self._skip_chunk_indexes_deletion @skip_chunk_indexes_deletion.setter @@ -315,12 +315,12 @@ def load( cls, cache_dir: str, serializers: dict[str, Serializer], - remote_dir: Optional[str] = None, - item_loader: Optional[BaseItemLoader] = None, - subsampled_files: Optional[list[str]] = None, - region_of_interest: Optional[list[tuple[int, int]]] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + remote_dir: str | None = None, + item_loader: BaseItemLoader | None = None, + subsampled_files: list[str] | None = None, + region_of_interest: list[tuple[int, int]] | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, ) -> Optional["ChunksConfig"]: cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index fa47b3469..ba2307b0d 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -18,7 +18,7 @@ from copy import deepcopy from importlib import reload from itertools import cycle -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import torch from torch.utils.data import Dataset, IterableDataset @@ -76,9 +76,9 @@ def __init__( self, dataset: Any, cache_dir: str, - chunk_bytes: Optional[int], - chunk_size: Optional[int], - compression: Optional[str], + chunk_bytes: int | None, + chunk_size: int | None, + compression: str | None, ): """The `CacheDataset` is a dataset wrapper to provide a beginner experience with the Cache. 
@@ -121,7 +121,7 @@ class CacheCollateFn: """ - def __init__(self, collate_fn: Optional[Callable] = None) -> None: + def __init__(self, collate_fn: Callable | None = None) -> None: self.collate_fn = collate_fn or default_collate def __call__(self, items: list[Any]) -> Any: @@ -263,18 +263,18 @@ def __init__( self, dataset: Any, *args: Any, - sampler: Optional[Sampler] = None, - batch_sampler: Optional[BatchSampler] = None, + sampler: Sampler | None = None, + batch_sampler: BatchSampler | None = None, num_workers: int = 0, shuffle: bool = False, - generator: Optional[torch.Generator] = None, - batch_size: Optional[int] = None, + generator: torch.Generator | None = None, + batch_size: int | None = None, drop_last: bool = False, - cache_dir: Optional[str] = None, - chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES, - compression: Optional[str] = None, + cache_dir: str | None = None, + chunk_bytes: int | None = _DEFAULT_CHUNK_BYTES, + compression: str | None = None, profile: bool = False, - collate_fn: Optional[Callable] = None, + collate_fn: Callable | None = None, **kwargs: Any, ) -> None: if sampler: @@ -385,7 +385,7 @@ def wrap(*args: Any, **kwargs: Any) -> Any: class _ProfileWorkerLoop: """Wrap the PyTorch DataLoader WorkerLoop to add profiling.""" - def __init__(self, profile: Union[int, bool], skip_batches: int, profile_dir: Optional[str] = None): + def __init__(self, profile: int | bool, skip_batches: int, profile_dir: str | None = None): self._profile = profile self._skip_batches = skip_batches self._profile_dir = profile_dir if profile_dir else os.getcwd() @@ -499,7 +499,7 @@ def _try_put_index(self) -> None: class StreamingDataLoaderCollateFn: - def __init__(self, collate_fn: Optional[Callable] = None) -> None: + def __init__(self, collate_fn: Callable | None = None) -> None: self.collate_fn = collate_fn or default_collate def __call__(self, items: list[Any]) -> Any: @@ -575,17 +575,17 @@ class StreamingDataLoader(DataLoader): def __init__( self, - dataset: Union[StreamingDataset, _BaseStreamingDatasetWrapper], + dataset: StreamingDataset | _BaseStreamingDatasetWrapper, *args: Any, batch_size: int = 1, num_workers: int = 0, - profile_batches: Union[bool, int] = False, + profile_batches: bool | int = False, profile_skip_batches: int = 0, - profile_dir: Optional[str] = None, - prefetch_factor: Optional[int] = None, - shuffle: Optional[bool] = None, - drop_last: Optional[bool] = None, - collate_fn: Optional[Callable] = None, + profile_dir: str | None = None, + prefetch_factor: int | None = None, + shuffle: bool | None = None, + drop_last: bool | None = None, + collate_fn: Callable | None = None, **kwargs: Any, ) -> None: # pyright: ignore if not isinstance(dataset, (StreamingDataset, _BaseStreamingDatasetWrapper)): @@ -623,9 +623,9 @@ def __init__( self._num_samples_yielded_streaming = 0 self._num_samples_yielded_wrapper: dict[int, list[int]] = {} self._num_cycles: dict[int, list[int]] = {} - self.rng_state: Optional[Any] = None - self._worker_idx: Optional[Any] = None # Lazily initialized in __iter__ - self._worker_idx_iter: Optional[Any] = None + self.rng_state: Any | None = None + self._worker_idx: Any | None = None # Lazily initialized in __iter__ + self._worker_idx_iter: Any | None = None self._latest_worker_idx = 0 self.restore = False super().__init__( diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 1dca6c5b4..30a15081c 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -14,7 +14,7 @@ import logging 
import os from time import time -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import numpy as np from torch.utils.data import IterableDataset @@ -47,21 +47,21 @@ class StreamingDataset(IterableDataset): def __init__( self, input_dir: Union[str, "Dir"], - cache_dir: Optional[Union[str, "Dir"]] = None, - item_loader: Optional[BaseItemLoader] = None, + cache_dir: Union[str, "Dir"] | None = None, + item_loader: BaseItemLoader | None = None, shuffle: bool = False, - drop_last: Optional[bool] = None, + drop_last: bool | None = None, seed: int = 42, - serializers: Optional[dict[str, Serializer]] = None, - max_cache_size: Union[int, str] = "100GB", + serializers: dict[str, Serializer] | None = None, + max_cache_size: int | str = "100GB", subsample: float = 1.0, - encryption: Optional[Encryption] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + encryption: Encryption | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, max_pre_download: int = 2, - index_path: Optional[str] = None, + index_path: str | None = None, force_override_state_dict: bool = False, - transform: Optional[Union[Callable, list[Callable]]] = None, + transform: Callable | list[Callable] | None = None, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -166,8 +166,8 @@ def __init__( "Consider increasing the `max_cache_size` to at least 25GB to avoid potential performance degradation." ) - self.cache: Optional[Cache] = None - self.worker_env: Optional[_WorkerEnv] = None + self.cache: Cache | None = None + self.worker_env: _WorkerEnv | None = None self.worker_chunks: list[int] = [] # chunk indexes that the current worker will download, read & stream self.worker_intervals: list[list[int]] = [] # chunk index intervals for the current worker self.upcoming_indexes: list[int] = [] # contains list of upcoming indexes to be processed @@ -175,18 +175,18 @@ def __init__( # which index of the array `self.worker_chunks` will we work on after this chunk is completely consumed self.worker_next_chunk_index = 0 - self.num_chunks: Optional[int] = None # total number of chunks that the current worker will work on + self.num_chunks: int | None = None # total number of chunks that the current worker will work on self.global_index = 0 # total number of samples processed by the current worker up until now # number of samples processed by the current worker in the current chunk self.consumed_sample_count_in_curr_chunk = 0 self.has_triggered_download = False - self.min_items_per_replica: Optional[int] = None + self.min_items_per_replica: int | None = None self.current_epoch = 1 self.random_state = None - self.shuffler: Optional[Shuffle] = None + self.shuffler: Shuffle | None = None self.serializers = serializers - self._state_dict: Optional[dict[str, Any]] = None + self._state_dict: dict[str, Any] | None = None self._force_override_state_dict = force_override_state_dict # Has slightly different meaning in the context of the dataset # We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process) @@ -423,7 +423,7 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) # bump the chunk_index self.worker_next_chunk_index += 1 - def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: + def __getitem__(self, index: ChunkedIndex | int | slice) -> Any: if self.cache is None: self.worker_env = 
_WorkerEnv.detect() self.cache = self._create_cache(worker_env=self.worker_env) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index c213e0971..323b7626f 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -19,7 +19,7 @@ from abc import ABC from contextlib import suppress from time import time -from typing import Any, Optional +from typing import Any from urllib import parse from filelock import FileLock, Timeout @@ -44,7 +44,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): self._remote_dir = remote_dir @@ -114,7 +114,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): super().__init__(remote_dir, cache_dir, chunks, storage_options) @@ -224,7 +224,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): super().__init__(remote_dir, cache_dir, chunks, storage_options) @@ -337,7 +337,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): if not _GOOGLE_STORAGE_AVAILABLE: @@ -450,7 +450,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): if not _AZURE_STORAGE_AVAILABLE: @@ -559,7 +559,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, + storage_options: dict | None = {}, **kwargs: Any, ): if not _HF_HUB_AVAILABLE: @@ -652,8 +652,8 @@ def get_downloader( remote_dir: str, cache_dir: str, chunks: list[dict[str, Any]], - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + storage_options: dict | None = {}, + session_options: dict | None = {}, ) -> Downloader: """Get the appropriate downloader instance based on the remote directory prefix. diff --git a/src/litdata/streaming/fs_provider.py b/src/litdata/streaming/fs_provider.py index 24e62df56..5b9007441 100644 --- a/src/litdata/streaming/fs_provider.py +++ b/src/litdata/streaming/fs_provider.py @@ -12,7 +12,7 @@ # limitations under the License. 
import os from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from urllib import parse from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _SUPPORTED_PROVIDERS @@ -20,7 +20,7 @@ class FsProvider(ABC): - def __init__(self, storage_options: Optional[dict[str, Any]] = {}): + def __init__(self, storage_options: dict[str, Any] | None = {}): self.storage_options = storage_options @abstractmethod @@ -50,7 +50,7 @@ def is_empty(self, path: str) -> bool: class GCPFsProvider(FsProvider): - def __init__(self, storage_options: Optional[dict[str, Any]] = {}): + def __init__(self, storage_options: dict[str, Any] | None = {}): if not _GOOGLE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) from google.cloud import storage @@ -133,7 +133,7 @@ def is_empty(self, path: str) -> bool: class S3FsProvider(FsProvider): - def __init__(self, storage_options: Optional[dict[str, Any]] = {}): + def __init__(self, storage_options: dict[str, Any] | None = {}): super().__init__(storage_options=storage_options) self.client = S3Client(storage_options=storage_options) @@ -225,7 +225,7 @@ def is_empty(self, path: str) -> bool: class R2FsProvider(S3FsProvider): - def __init__(self, storage_options: Optional[dict[str, Any]] = {}): + def __init__(self, storage_options: dict[str, Any] | None = {}): super().__init__(storage_options=storage_options) # Create R2Client with refreshable credentials @@ -324,7 +324,7 @@ def get_bucket_and_path(remote_filepath: str, expected_scheme: str = "s3") -> tu return bucket_name, blob_path -def _get_fs_provider(remote_filepath: str, storage_options: Optional[dict[str, Any]] = {}) -> FsProvider: +def _get_fs_provider(remote_filepath: str, storage_options: dict[str, Any] | None = {}) -> FsProvider: obj = parse.urlparse(remote_filepath) if obj.scheme == "gs": return GCPFsProvider(storage_options=storage_options) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 59569a9f0..b7f886e6e 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -20,7 +20,7 @@ from io import BytesIO, FileIO from multiprocessing import Queue from time import sleep, time -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch @@ -52,8 +52,8 @@ def setup( config: dict, chunks: list, serializers: dict[str, Serializer], - region_of_interest: Optional[list[tuple[int, int]]] = None, - force_download_queue: Optional[Queue] = None, + region_of_interest: list[tuple[int, int]] | None = None, + force_download_queue: Queue | None = None, ) -> None: self._config = config self._chunks = chunks @@ -176,7 +176,7 @@ def load_item_from_chunk( chunk_filepath: str, begin: int, filesize_bytes: int, - encryption: Optional[Encryption] = None, + encryption: Encryption | None = None, ) -> bytes: # # Let's say, a chunk contains items from [5,9] index. @@ -248,7 +248,7 @@ def load_item_from_chunk( return item_data def _load_encrypted_data( - self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption] + self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Encryption | None ) -> bytes: """Load and decrypt data from chunk based on the encryption configuration.""" # Validate the provided encryption object against the expected configuration. 
@@ -277,7 +277,7 @@ def _load_encrypted_data( return data - def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes: + def _load_data(self, fp: FileIO | BytesIO, offset: int) -> bytes: """Load the data from the file pointer.""" fp.seek(offset) # move the file pointer to the offset @@ -353,7 +353,7 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: ) ) - def _validate_encryption(self, encryption: Optional[Encryption]) -> None: + def _validate_encryption(self, encryption: Encryption | None) -> None: """Validate the encryption object.""" if not encryption: raise ValueError("Data is encrypted but no encryption object was provided.") @@ -363,7 +363,7 @@ def _validate_encryption(self, encryption: Optional[Encryption]) -> None: raise ValueError("Encryption level mismatch.") @classmethod - def encode_data(cls, data: list[bytes], sizes: list[int], flattened: list[Any]) -> tuple[bytes, Optional[int]]: + def encode_data(cls, data: list[bytes], sizes: list[int], flattened: list[Any]) -> tuple[bytes, int | None]: """Encodes multiple serialized objects into a single binary format with size metadata. This method combines multiple serialized objects into a single byte array, prefixed with their sizes. @@ -400,7 +400,7 @@ def __getstate__(self): class TokensLoader(BaseItemLoader): - def __init__(self, block_size: Optional[int] = None): + def __init__(self, block_size: int | None = None): """The Tokens Loader is an optimizer item loader for NLP. Args: @@ -413,7 +413,7 @@ def __init__(self, block_size: Optional[int] = None): self._buffers: dict[int, bytes] = {} # keeps track of number of readers for each chunk (can be more than 1 if multiple workers are reading) self._counter = defaultdict(int) - self._dtype: Optional[torch.dtype] = None + self._dtype: torch.dtype | None = None self._chunk_filepaths: dict[str, bool] = {} def state_dict(self) -> dict: @@ -427,7 +427,7 @@ def setup( config: dict, chunks: list, serializers: dict[str, Serializer], - region_of_interest: Optional[list[tuple[int, int]]] = None, + region_of_interest: list[tuple[int, int]] | None = None, ) -> None: super().setup(config, chunks, serializers, region_of_interest) @@ -580,7 +580,7 @@ def close(self, chunk_index: int) -> None: del self._mmaps[chunk_index] @classmethod - def encode_data(cls, data: list[bytes], _: list[int], flattened: list[Any]) -> tuple[bytes, Optional[int]]: + def encode_data(cls, data: list[bytes], _: list[int], flattened: list[Any]) -> tuple[bytes, int | None]: r"""Encodes tokenized data into a raw byte format while preserving dimensional information. 
Parameters: @@ -631,7 +631,7 @@ def setup( config: dict, chunks: list, serializers: dict[str, Serializer], - region_of_interest: Optional[list[tuple[int, int]]] = None, + region_of_interest: list[tuple[int, int]] | None = None, ) -> None: self._config = config self._chunks = chunks diff --git a/src/litdata/streaming/parallel.py b/src/litdata/streaming/parallel.py index f77981fa2..435e43bd3 100644 --- a/src/litdata/streaming/parallel.py +++ b/src/litdata/streaming/parallel.py @@ -17,7 +17,7 @@ import random from collections.abc import Iterator from copy import deepcopy -from typing import Any, Literal, Optional, Protocol, Union +from typing import Any, Literal, Protocol import numpy as np import torch @@ -33,12 +33,12 @@ logger = logging.getLogger("litdata.streaming.parallel") -RandomGenerator = Union[random.Random, np.random.Generator, torch.Generator] +RandomGenerator = random.Random | np.random.Generator | torch.Generator GeneratorName = Literal["random", "numpy", "torch"] class Transform(Protocol): - def __call__(self, samples: tuple[Any, ...], rng: Optional[dict[GeneratorName, RandomGenerator]] = None) -> Any: ... + def __call__(self, samples: tuple[Any, ...], rng: dict[GeneratorName, RandomGenerator] | None = None) -> Any: ... class ParallelStreamingDataset(_BaseStreamingDatasetWrapper): @@ -84,9 +84,9 @@ class ParallelStreamingDataset(_BaseStreamingDatasetWrapper): def __init__( self, datasets: list[StreamingDataset], - length: Optional[Union[int, float]] = None, + length: int | float | None = None, force_override_state_dict: bool = False, - transform: Optional[Transform] = None, + transform: Transform | None = None, seed: int = 42, resume: bool = True, reset_rngs: bool = False, @@ -131,10 +131,10 @@ def __init__( self._transform_nargs = transform_nargs self._seed = seed self._reset_rngs = reset_rngs - self._iterator: Optional[_ParallelDatasetIterator] = None + self._iterator: _ParallelDatasetIterator | None = None self._use_streaming_dataloader = False - self._num_samples_yielded: Optional[dict[int, list[int]]] = None - self._num_cycles: Optional[dict[int, list[int]]] = None + self._num_samples_yielded: dict[int, list[int]] | None = None + self._num_cycles: dict[int, list[int]] | None = None self._current_epoch = 0 self.num_workers = 1 self.batch_size = 1 @@ -177,7 +177,7 @@ def update_epoch_counters(self, num_cycles: list[int]) -> None: # do not call dset.set_epoch as it is ignored if the dataset has non-None _state_dict attribute dset.current_epoch = i_cycle + 1 - def get_len(self, num_workers: int, batch_size: int) -> Optional[int]: + def get_len(self, num_workers: int, batch_size: int) -> int | None: self.num_workers = num_workers self.batch_size = batch_size # initialize lengths even if self._length is not None to call self._get_len() on all the wrapped datasets and @@ -249,7 +249,7 @@ def __iter__(self) -> Iterator[Any]: ) return self._iterator - def __len__(self) -> Optional[int]: + def __len__(self) -> int | None: # ``batch_size`` may be a sequence when per-dataset values were set on # the wrapper. For length estimation we only need a scalar; we take # the first element if a sequence is provided. 
@@ -260,8 +260,8 @@ def __len__(self) -> Optional[int]: def get_num_samples_yielded( self, - num_samples_yielded: Optional[dict[int, list[int]]] = None, - num_cycles: Optional[dict[int, list[int]]] = None, + num_samples_yielded: dict[int, list[int]] | None = None, + num_cycles: dict[int, list[int]] | None = None, ) -> tuple[list[int], list[int]]: """Get the number of samples yielded and the number of cycles for each dataset across workers. @@ -297,7 +297,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._num_cycles = state_dict["num_cycles"] def state_dict( - self, num_workers: int, batch_size: int, num_samples_yielded: Optional[list[int]] = None + self, num_workers: int, batch_size: int, num_samples_yielded: list[int] | None = None ) -> dict[str, Any]: if self._iterator is None and num_samples_yielded is None: return {} @@ -319,10 +319,10 @@ def __init__( use_streaming_dataloader: bool, num_samples_yielded: Any, num_cycles: Any, - length: Optional[Union[int, float]], + length: int | float | None, dset_lengths: list[int], - transform: Optional[Transform], - transform_nargs: Optional[int], + transform: Transform | None, + transform_nargs: int | None, rngs: dict[GeneratorName, RandomGenerator], ) -> None: self._datasets = datasets @@ -356,7 +356,7 @@ def transform(self, samples: tuple[Any, ...]) -> Any: return self._transform(samples, self._rngs) raise RuntimeError(f"transform function must take 1 or 2 arguments, got {self._transform_nargs} instead.") - def __next__(self) -> Union[Any, dict[str, Any]]: + def __next__(self) -> Any | dict[str, Any]: if self._length is not None and self._count >= self._length: raise StopIteration samples, _resets = zip(*[self._get_sample(i) for i in range(len(self._datasets))]) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index f8d0b9da8..4077f18bb 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -19,7 +19,7 @@ from datetime import datetime from queue import Empty, Queue from threading import Event, Thread -from typing import Any, Optional, Union +from typing import Any import numpy as np from filelock import FileLock, Timeout @@ -55,9 +55,9 @@ def __init__( config: ChunksConfig, item_loader: BaseItemLoader, distributed_env: _DistributedEnv, - max_cache_size: Optional[int] = None, + max_cache_size: int | None = None, max_pre_download: int = 2, - rank: Optional[int] = None, + rank: int | None = None, ) -> None: super().__init__(daemon=True) self._config = config @@ -266,16 +266,16 @@ class BinaryReader: def __init__( self, cache_dir: str, - subsampled_files: Optional[list[str]] = None, - region_of_interest: Optional[list[tuple[int, int]]] = None, - max_cache_size: Optional[Union[int, str]] = None, - remote_input_dir: Optional[str] = None, - compression: Optional[str] = None, - encryption: Optional[Encryption] = None, - item_loader: Optional[BaseItemLoader] = None, - serializers: Optional[dict[str, Serializer]] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, + subsampled_files: list[str] | None = None, + region_of_interest: list[tuple[int, int]] | None = None, + max_cache_size: int | str | None = None, + remote_input_dir: str | None = None, + compression: str | None = None, + encryption: Encryption | None = None, + item_loader: BaseItemLoader | None = None, + serializers: dict[str, Serializer] | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, max_pre_download: int = 2, on_demand_bytes: bool = 
False, ) -> None: @@ -309,17 +309,17 @@ def __init__( self._compression = compression self._encryption = encryption - self._intervals: Optional[list[str]] = None + self._intervals: list[str] | None = None self.subsampled_files = subsampled_files self.region_of_interest = region_of_interest self._serializers: dict[str, Serializer] = _get_serializers(serializers) self._distributed_env = _DistributedEnv.detect() - self._rank: Optional[int] = None - self._config: Optional[ChunksConfig] = None - self._prepare_thread: Optional[PrepareChunksThread] = None + self._rank: int | None = None + self._config: ChunksConfig | None = None + self._prepare_thread: PrepareChunksThread | None = None self._item_loader = item_loader or PyTreeLoader() - self._last_chunk_index: Optional[int] = None - self._last_chunk_size: Optional[int] = None + self._last_chunk_index: int | None = None + self._last_chunk_size: int | None = None self._chunks_queued_for_download = False self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0)) self._storage_options = storage_options @@ -334,7 +334,7 @@ def _get_chunk_index_from_index(self, index: int) -> tuple[int, int]: return self._config._get_chunk_index_from_index(index) # type: ignore - def _try_load_config(self) -> Optional[ChunksConfig]: + def _try_load_config(self) -> ChunksConfig | None: """Try to load the chunks config if the index files are available.""" self._config = ChunksConfig.load( self._cache_dir, @@ -577,7 +577,7 @@ def _get_folder_size(path: str, config: ChunksConfig) -> int: return size -def _get_from_queue(queue: Queue, timeout: float = _DEFAULT_TIMEOUT) -> Optional[Any]: +def _get_from_queue(queue: Queue, timeout: float = _DEFAULT_TIMEOUT) -> Any | None: try: return queue.get(timeout=timeout) except Empty: diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index aaddaaa8f..f08bded3f 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -22,7 +22,7 @@ from functools import lru_cache from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union from urllib import parse from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_PROVIDERS @@ -37,9 +37,9 @@ class Dir: """Holds a directory path and possibly its associated remote URL.""" - path: Optional[str] = None - url: Optional[str] = None - data_connection_id: Optional[str] = None + path: str | None = None + url: str | None = None + data_connection_id: str | None = None class CloudProvider(str, Enum): @@ -47,7 +47,7 @@ class CloudProvider(str, Enum): GCP = "gcp" -def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir: +def _resolve_dir(dir_path: str | Path | Dir | None) -> Dir: if isinstance(dir_path, Dir): return Dir( path=str(dir_path.path) if dir_path.path else None, @@ -105,7 +105,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir: return Dir(path=dir_path_absolute, url=None) -def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspace: Any) -> bool: +def _match_studio(target_id: str | None, target_name: str | None, cloudspace: Any) -> bool: if cloudspace.name is not None and target_name is not None and cloudspace.name.lower() == target_name.lower(): return True @@ -119,7 +119,7 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa ) -def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> 
Dir: +def _resolve_studio(dir_path: str, target_name: str | None, target_id: str | None) -> Dir: from lightning_sdk.lightning_cloud.rest_client import LightningClient client = LightningClient(max_tries=2) @@ -330,7 +330,7 @@ def _assert_dir_is_empty( def _assert_dir_has_index_file( output_dir: Dir, - mode: Optional[Literal["append", "overwrite"]] = None, + mode: Literal["append", "overwrite"] | None = None, use_checkpoint: bool = False, storage_options: dict[str, Any] = {}, ) -> None: @@ -449,8 +449,8 @@ def _resolve_time_template(path: str) -> str: def _execute( name: str, num_nodes: int, - machine: Optional[Union["Machine", str]] = None, - command: Optional[str] = None, + machine: Union["Machine", str] | None = None, + command: str | None = None, interruptible: bool = False, ) -> None: """Remotely execute the current operator.""" diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index b8df45829..af59961d8 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -14,7 +14,7 @@ import logging from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import numpy as np @@ -45,8 +45,8 @@ class ChunkedIndex: index: int chunk_index: int - chunk_size: Optional[int] = None - chunk_indexes: Optional[list[int]] = None + chunk_size: int | None = None + chunk_indexes: list[int] | None = None is_last_index: bool = False @@ -112,7 +112,7 @@ def _validate(self) -> None: if diff.sum() != 0: raise RuntimeError("This shouldn't have happened. There is a bug in the CacheSampler.") - def __iter__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: + def __iter__(self) -> Iterator[list[int | ChunkedIndex]]: # When the cache is filled, we need to iterate though the chunks if self._cache.filled: if self._num_replicas == 1: @@ -124,7 +124,7 @@ def __iter__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: return self.__iter_non_distributed__() return self.__iter_distributed__() - def __iter_non_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: + def __iter_non_distributed__(self) -> Iterator[list[int | ChunkedIndex]]: worker_size = self._dataset_size // self._num_workers indices = list(range(self._dataset_size)) worker_indices = [] @@ -140,7 +140,7 @@ def __iter_non_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: yield from self.__iter_indices_per_workers__(worker_indices_batches) - def __iter_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: + def __iter_distributed__(self) -> Iterator[list[int | ChunkedIndex]]: self.indices = list(range(self._dataset_size)) replica_size = self._dataset_size // self._num_replicas worker_size = self._dataset_size // (self._num_replicas * self._num_workers) @@ -168,13 +168,13 @@ def __iter_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: yield from self.__iter_indices_per_workers__(worker_indices_batches) - def __iter_from_chunks_non_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: + def __iter_from_chunks_non_distributed__(self) -> Iterator[list[int | ChunkedIndex]]: chunk_intervals = self._cache.get_chunk_intervals() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals) - def __iter_from_chunks_distributed__(self) -> Iterator[list[Union[int, ChunkedIndex]]]: + def 
__iter_from_chunks_distributed__(self) -> Iterator[list[int | ChunkedIndex]]: chunk_intervals = self._cache.get_chunk_intervals() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] @@ -190,7 +190,7 @@ def __iter_from_chunks_distributed__(self) -> Iterator[list[Union[int, ChunkedIn def __iter_from_shuffled_chunks( self, shuffled_indexes: list[int], shuffled_chunk_intervals: list[list[int]] - ) -> Iterator[list[Union[int, ChunkedIndex]]]: + ) -> Iterator[list[int | ChunkedIndex]]: chunks_per_workers: list[list[int]] = [[] for _ in range(self._num_workers)] for i, chunk_index in enumerate(shuffled_indexes): chunks_per_workers[i % self._num_workers].append(chunk_index) @@ -220,9 +220,9 @@ def __len__(self) -> int: return self._length def __iter_indices_per_workers__( - self, indices_per_workers: list[list[list[Union[int, ChunkedIndex]]]] - ) -> Iterator[list[Union[int, ChunkedIndex]]]: - batches: list[list[Union[int, ChunkedIndex]]] = [] + self, indices_per_workers: list[list[list[int | ChunkedIndex]]] + ) -> Iterator[list[int | ChunkedIndex]]: + batches: list[list[int | ChunkedIndex]] = [] counter = 0 while sum([len(v) for v in indices_per_workers]) != 0: worker_indices = indices_per_workers[counter % self._num_workers] diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index b7e543be0..0f0060390 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -21,7 +21,7 @@ from contextlib import suppress from copy import deepcopy from itertools import chain -from typing import Any, Optional +from typing import Any import numpy as np import tifffile @@ -43,7 +43,7 @@ class Serializer(ABC): """ @abstractmethod - def serialize(self, data: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, data: Any) -> tuple[bytes, str | None]: pass @abstractmethod @@ -61,7 +61,7 @@ def setup(self, metadata: Any) -> None: class PILSerializer(Serializer): """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" - def serialize(self, item: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> tuple[bytes, str | None]: mode = item.mode.encode("utf-8") width, height = item.size raw = item.tobytes() @@ -95,7 +95,7 @@ def can_serialize(self, item: Any) -> bool: class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" - def serialize(self, item: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> tuple[bytes, str | None]: if not _PIL_AVAILABLE: raise ModuleNotFoundError("PIL is required. 
Run `pip install pillow`") @@ -153,7 +153,7 @@ def can_serialize(self, item: Any) -> bool: class JPEGArraySerializer(Serializer): """The JPEGArraySerializer serializes and deserializes lists of JPEG images to and from bytes.""" - def serialize(self, item: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> tuple[bytes, str | None]: # Store number of images as first 4 bytes n_images_bytes = np.uint32(len(item)).tobytes() @@ -222,7 +222,7 @@ def can_serialize(self, item: Any) -> bool: class BytesSerializer(Serializer): """The BytesSerializer serialize and deserialize integer to and from bytes.""" - def serialize(self, item: bytes) -> tuple[bytes, Optional[str]]: + def serialize(self, item: bytes) -> tuple[bytes, str | None]: return item, None def deserialize(self, item: bytes) -> bytes: @@ -241,7 +241,7 @@ def __init__(self) -> None: self._header_struct_format = ">II" self._header_struct = struct.Struct(self._header_struct_format) - def serialize(self, item: torch.Tensor) -> tuple[bytes, Optional[str]]: + def serialize(self, item: torch.Tensor) -> tuple[bytes, str | None]: if item.device.type != "cpu": item = item.cpu() @@ -287,12 +287,12 @@ class NoHeaderTensorSerializer(Serializer): def __init__(self) -> None: super().__init__() self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} - self._dtype: Optional[torch.dtype] = None + self._dtype: torch.dtype | None = None def setup(self, data_format: str) -> None: self._dtype = _TORCH_DTYPES_MAPPING[int(data_format.split(":")[1])] - def serialize(self, item: torch.Tensor) -> tuple[bytes, Optional[str]]: + def serialize(self, item: torch.Tensor) -> tuple[bytes, str | None]: dtype_indice = self._dtype_to_indices[item.dtype] return item.numpy().tobytes(order="C"), f"no_header_tensor:{dtype_indice}" @@ -311,7 +311,7 @@ def __init__(self) -> None: super().__init__() self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()} - def serialize(self, item: np.ndarray) -> tuple[bytes, Optional[str]]: + def serialize(self, item: np.ndarray) -> tuple[bytes, str | None]: dtype_indice = self._dtype_to_indices[item.dtype] data = [np.uint32(dtype_indice).tobytes()] data.append(np.uint32(len(item.shape)).tobytes()) @@ -346,12 +346,12 @@ class NoHeaderNumpySerializer(Serializer): def __init__(self) -> None: super().__init__() self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()} - self._dtype: Optional[np.dtype] = None + self._dtype: np.dtype | None = None def setup(self, data_format: str) -> None: self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])] - def serialize(self, item: np.ndarray) -> tuple[bytes, Optional[str]]: + def serialize(self, item: np.ndarray) -> tuple[bytes, str | None]: dtype_indice: int = self._dtype_to_indices[item.dtype] return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}" @@ -366,7 +366,7 @@ def can_serialize(self, item: np.ndarray) -> bool: class PickleSerializer(Serializer): """The PickleSerializer serialize and deserialize python objects to and from bytes.""" - def serialize(self, item: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> tuple[bytes, str | None]: return pickle.dumps(item), None def deserialize(self, data: bytes) -> Any: @@ -377,7 +377,7 @@ def can_serialize(self, _: Any) -> bool: class FileSerializer(Serializer): - def serialize(self, filepath: str) -> tuple[bytes, Optional[str]]: + def serialize(self, filepath: str) -> tuple[bytes, str | None]: print("FileSerializer will be removed in the future.") 
_, file_extension = os.path.splitext(filepath) with open(filepath, "rb") as f: @@ -396,7 +396,7 @@ def can_serialize(self, data: Any) -> bool: class VideoSerializer(Serializer): _EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv") - def serialize(self, filepath: str) -> tuple[bytes, Optional[str]]: + def serialize(self, filepath: str) -> tuple[bytes, str | None]: _, file_extension = os.path.splitext(filepath) with open(filepath, "rb") as f: file_extension = file_extension.replace(".", "").lower() @@ -421,7 +421,7 @@ def can_serialize(self, data: Any) -> bool: class StringSerializer(Serializer): - def serialize(self, obj: str) -> tuple[bytes, Optional[str]]: + def serialize(self, obj: str) -> tuple[bytes, str | None]: return obj.encode("utf-8"), None def deserialize(self, data: bytes) -> str: @@ -438,7 +438,7 @@ def __init__(self, dtype: type) -> None: self.dtype = dtype self.size = self.dtype().nbytes - def serialize(self, obj: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, obj: Any) -> tuple[bytes, str | None]: return self.dtype(obj).tobytes(), None def deserialize(self, data: bytes) -> Any: @@ -464,7 +464,7 @@ def can_serialize(self, data: float) -> bool: class BooleanSerializer(Serializer): """The BooleanSerializer serializes and deserializes boolean values to and from bytes.""" - def serialize(self, item: bool) -> tuple[bytes, Optional[str]]: + def serialize(self, item: bool) -> tuple[bytes, str | None]: """Serialize a boolean value to bytes. Args: @@ -501,7 +501,7 @@ def can_serialize(self, item: Any) -> bool: class TIFFSerializer(Serializer): """Serializer for TIFF files using tifffile.""" - def serialize(self, item: Any) -> tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> tuple[bytes, str | None]: if not isinstance(item, str) or not os.path.isfile(item): raise ValueError(f"The item to serialize must be a valid file path. 
Received: {item}") @@ -540,7 +540,7 @@ def can_serialize(self, item: Any) -> bool: ) -def _get_serializers(serializers: Optional[dict[str, Serializer]]) -> dict[str, Serializer]: +def _get_serializers(serializers: dict[str, Serializer] | None) -> dict[str, Serializer]: if serializers is None: serializers = {} serializers = OrderedDict(serializers) diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index f8f77cbd5..d17e5a184 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from multiprocessing import Queue from time import sleep, time -from typing import Any, Optional, Union +from typing import Any import numpy as np @@ -40,7 +40,7 @@ class Item: index: int data: bytes bytes: int - dim: Optional[int] = None + dim: int | None = None def __len__(self) -> int: return self.bytes @@ -50,15 +50,15 @@ class BinaryWriter: def __init__( self, cache_dir: str, - chunk_size: Optional[int] = None, - chunk_bytes: Optional[Union[int, str]] = None, - compression: Optional[str] = None, - encryption: Optional[Encryption] = None, + chunk_size: int | None = None, + chunk_bytes: int | str | None = None, + compression: str | None = None, + encryption: Encryption | None = None, follow_tensor_dimension: bool = True, - serializers: Optional[dict[str, Serializer]] = None, - chunk_index: Optional[int] = None, - item_loader: Optional[BaseItemLoader] = None, - msg_queue: Optional[Queue] = None, + serializers: dict[str, Serializer] | None = None, + chunk_index: int | None = None, + item_loader: BaseItemLoader | None = None, + msg_queue: Queue | None = None, ): """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. @@ -94,8 +94,8 @@ def __init__( self._item_loader = item_loader or PyTreeLoader() self.msg_queue = msg_queue - self._data_format: Optional[list[str]] = None - self._data_spec: Optional[PyTree] = None + self._data_format: list[str] | None = None + self._data_spec: PyTree | None = None if self._compression: if len(_COMPRESSORS) == 0: @@ -109,11 +109,11 @@ def __init__( self._serialized_items: dict[int, Item] = {} self._chunk_index = chunk_index or 0 - self._min_index: Optional[int] = None - self._max_index: Optional[int] = None + self._min_index: int | None = None + self._max_index: int | None = None self._chunks_info: list[dict[str, Any]] = [] - self._worker_env: Optional[_WorkerEnv] = None - self._rank: Optional[int] = None + self._worker_env: _WorkerEnv | None = None + self._rank: int | None = None self._is_done = False self._distributed_env = _DistributedEnv.detect() self._follow_tensor_dimension = follow_tensor_dimension @@ -161,7 +161,7 @@ def get_config(self) -> dict[str, Any]: "item_loader": self._item_loader.__class__.__name__, } - def serialize(self, items: Any) -> tuple[bytes, Optional[int]]: + def serialize(self, items: Any) -> tuple[bytes, int | None]: """Serialize a dictionary into its binary format.""" # Flatten the items provided by the users flattened, data_spec = tree_flatten(items) @@ -290,7 +290,7 @@ def _create_chunk(self, filename: str, on_done: bool = False) -> bytes: if self._chunk_size: assert num_items.item() <= self._chunk_size - dim: Optional[int] = None + dim: int | None = None if items[0].dim: dim = sum([item.dim if item.dim is not None else 0 for item in items]) @@ -327,7 +327,7 @@ def __setitem__(self, index: int, items: Any) -> None: """ self.add_item(index, items) - def add_item(self, index: int, items: Any) -> 
Optional[str]: + def add_item(self, index: int, items: Any) -> str | None: """Given an index and items will serialize the items and store an Item object to the growing `_serialized_items`. """ @@ -451,7 +451,7 @@ def done(self) -> list[str]: self._is_done = True return filepaths - def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None: + def merge(self, num_workers: int = 1, node_rank: int | None = None) -> None: """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index. """ @@ -480,7 +480,7 @@ def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None: self._merge_no_wait(node_rank=node_rank) - def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Optional[dict[str, Any]] = None) -> None: + def _merge_no_wait(self, node_rank: int | None = None, existing_index: dict[str, Any] | None = None) -> None: """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index. @@ -548,7 +548,7 @@ def _pretty_serialized_items(self) -> dict[int, Item]: ) return out - def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]: + def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> str | None: """Save the current state of the writer to a checkpoint.""" checkpoint_dir = os.path.join(self._cache_dir, checkpoint_dir) if not os.path.exists(checkpoint_dir): @@ -573,8 +573,8 @@ def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str] def index_parquet_dataset( pq_dir_url: str, - cache_dir: Optional[str] = None, - storage_options: Optional[dict] = {}, + cache_dir: str | None = None, + storage_options: dict | None = {}, num_workers: int = 4, ) -> None: """Index a Parquet dataset from a specified URL. diff --git a/src/litdata/utilities/base.py b/src/litdata/utilities/base.py index f395a384c..7e24ece2d 100644 --- a/src/litdata/utilities/base.py +++ b/src/litdata/utilities/base.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterator, Sequence -from typing import Any, Optional, Union +from typing import Any from torch.utils.data import IterableDataset @@ -30,18 +30,18 @@ class _BaseStreamingDatasetWrapper(IterableDataset, ABC): _datasets: list[StreamingDataset] _current_epoch: int - batch_size: Union[int, Sequence[int]] + batch_size: int | Sequence[int] num_workers: int _force_override_state_dict: bool _use_streaming_dataloader: bool - _num_samples_yielded: Optional[dict[int, list[int]]] = None + _num_samples_yielded: dict[int, list[int]] | None = None def set_shuffle(self, shuffle: bool) -> None: """Set the current shuffle to the datasets.""" for dataset in self._datasets: dataset.set_shuffle(shuffle) - def set_batch_size(self, batch_size: Union[int, Sequence[int]]) -> None: + def set_batch_size(self, batch_size: int | Sequence[int]) -> None: """Set the current batch size. This method now supports either: @@ -137,14 +137,14 @@ def _get_len(self, d: Any) -> int: def set_epoch(self, current_epoch: int) -> None: ... @abstractmethod - def get_len(self, num_workers: int, batch_size: int) -> Optional[int]: ... + def get_len(self, num_workers: int, batch_size: int) -> int | None: ... @abstractmethod - def __len__(self) -> Optional[int]: ... + def __len__(self) -> int | None: ... 
@abstractmethod def state_dict( - self, num_workers: int, batch_size: int, num_samples_yielded: Optional[list[int]] = None + self, num_workers: int, batch_size: int, num_samples_yielded: list[int] | None = None ) -> dict[str, Any]: ... @abstractmethod diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index 8dee7ba26..4d98e5f45 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -14,9 +14,10 @@ import json import os import pickle +from collections.abc import Callable from logging import Logger from time import sleep -from typing import Any, Callable, Optional +from typing import Any from urllib.parse import urljoin import requests @@ -53,8 +54,8 @@ class _HTTPClient: def __init__( self, base_url: str, - auth_token: Optional[str] = None, - log_callback: Optional[Callable] = None, + auth_token: str | None = None, + log_callback: Callable | None = None, use_retry: bool = True, ) -> None: self.base_url = base_url @@ -89,9 +90,7 @@ def get(self, path: str) -> Any: url = urljoin(self.base_url, path) return self.session.get(url) - def post( - self, path: str, *, query_params: Optional[dict] = None, data: Optional[bytes] = None, json: Any = None - ) -> Any: + def post(self, path: str, *, query_params: dict | None = None, data: bytes | None = None, json: Any = None) -> Any: url = urljoin(self.base_url, path) return self.session.post(url, data=data, params=query_params, json=json) @@ -153,7 +152,7 @@ def broadcast_object(key: str, obj: Any, rank: int) -> Any: return obj -def _get_token() -> Optional[str]: +def _get_token() -> str | None: """This function tries to retrieve a temporary token.""" if os.getenv("LIGHTNING_CLOUD_URL") is None: return None diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index df9519aa5..baa199d0a 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -5,7 +5,8 @@ import shutil import tempfile import time -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import numpy as np @@ -39,15 +40,15 @@ def wait_for_predicate( def subsample_streaming_dataset( input_dir: Dir, - cache_dir: Optional[Dir] = None, - item_loader: Optional[BaseItemLoader] = None, + cache_dir: Dir | None = None, + item_loader: BaseItemLoader | None = None, subsample: float = 1.0, shuffle: bool = False, seed: int = 42, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, - index_path: Optional[str] = None, - fnmatch_pattern: Optional[str] = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, + index_path: str | None = None, + fnmatch_pattern: str | None = None, ) -> tuple[list[str], list[tuple[int, int]]]: """Subsample streaming dataset. 
@@ -153,7 +154,7 @@ def path_exists(p: str) -> bool: return final_files, final_roi -def _should_replace_path(path: Optional[str]) -> bool: +def _should_replace_path(path: str | None) -> bool: """Whether the input path is a special path to be replaced.""" if path is None or path == "": return True @@ -169,10 +170,10 @@ def _should_replace_path(path: Optional[str]) -> bool: def _read_updated_at( - input_dir: Optional[Dir], - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, - index_path: Optional[str] = None, + input_dir: Dir | None, + storage_options: dict | None = {}, + session_options: dict | None = {}, + index_path: str | None = None, ) -> str: """Read last updated timestamp from index.json file.""" last_updation_timestamp = "0" @@ -243,12 +244,12 @@ def get_default_cache_dir() -> str: def _try_create_cache_dir( - input_dir: Optional[str], - cache_dir: Optional[str] = None, - storage_options: Optional[dict] = {}, - session_options: Optional[dict] = {}, - index_path: Optional[str] = None, -) -> Optional[str]: + input_dir: str | None, + cache_dir: str | None = None, + storage_options: dict | None = {}, + session_options: dict | None = {}, + index_path: str | None = None, +) -> str | None: """Prepare and return the cache directory for a dataset.""" resolved_input_dir = _resolve_dir(input_dir) updated_at = _read_updated_at(resolved_input_dir, storage_options, session_options, index_path) @@ -269,7 +270,7 @@ def _try_create_cache_dir( return cache_dir -def generate_roi(chunks: list[dict[str, Any]], item_loader: Optional[BaseItemLoader] = None) -> list[tuple[int, int]]: +def generate_roi(chunks: list[dict[str, Any]], item_loader: BaseItemLoader | None = None) -> list[tuple[int, int]]: """Generates default region_of_interest for chunks.""" roi = [] diff --git a/src/litdata/utilities/encryption.py b/src/litdata/utilities/encryption.py index d07efae9b..2c01565d8 100644 --- a/src/litdata/utilities/encryption.py +++ b/src/litdata/utilities/encryption.py @@ -3,7 +3,7 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Literal, Union, get_args +from typing import Any, Literal, get_args from litdata.constants import _CRYPTOGRAPHY_AVAILABLE @@ -179,13 +179,13 @@ def decrypt(self, data: bytes) -> bytes: ), ) - def state_dict(self) -> dict[str, Union[str, None]]: + def state_dict(self) -> dict[str, str | None]: return { "algorithm": self.algorithm, "level": self.level, } - def __getstate__(self) -> dict[str, Union[str, None]]: + def __getstate__(self) -> dict[str, str | None]: encryption_algorithm = ( serialization.BestAvailableEncryption(self.password.encode()) if self.password @@ -209,7 +209,7 @@ def __getstate__(self) -> dict[str, Union[str, None]]: "level": self.level, } - def __setstate__(self, state: dict[str, Union[str, None]]) -> None: + def __setstate__(self, state: dict[str, str | None]) -> None: # Restore the state from the serialized data self.password = state["password"] if state["password"] else "" self.level = state["level"] # type: ignore diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index f80114b80..331162332 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -12,7 +12,7 @@ # limitations under the License. 
import os -from typing import Callable, Optional +from collections.abc import Callable import torch from torch.utils.data import get_worker_info as torch_get_worker_info @@ -107,7 +107,7 @@ def __init__(self, world_size: int, rank: int): self.rank = rank @classmethod - def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv": + def detect(cls, get_worker_info_fn: Callable | None = None) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. .. note:: @@ -138,7 +138,7 @@ class Environment: """ - def __init__(self, dist_env: Optional[_DistributedEnv], worker_env: Optional[_WorkerEnv]): + def __init__(self, dist_env: _DistributedEnv | None, worker_env: _WorkerEnv | None): self.worker_env = worker_env self.dist_env = dist_env diff --git a/src/litdata/utilities/hf_dataset.py b/src/litdata/utilities/hf_dataset.py index c87862500..d519437b1 100644 --- a/src/litdata/utilities/hf_dataset.py +++ b/src/litdata/utilities/hf_dataset.py @@ -3,7 +3,6 @@ import os import shutil import tempfile -from typing import Optional from litdata.constants import _INDEX_FILENAME from litdata.streaming.writer import index_parquet_dataset @@ -11,7 +10,7 @@ from litdata.utilities.torch_utils import is_local_rank_0, maybe_barrier -def index_hf_dataset(dataset_url: str, cache_dir: Optional[str] = None) -> str: +def index_hf_dataset(dataset_url: str, cache_dir: str | None = None) -> str: """Indexes a Hugging Face dataset and returns the path to the cache directory. Args: @@ -65,7 +64,7 @@ def index_hf_dataset(dataset_url: str, cache_dir: Optional[str] = None) -> str: return cache_dir -def _get_existing_cache(dataset_url: str, cache_dir: Optional[str]) -> Optional[str]: +def _get_existing_cache(dataset_url: str, cache_dir: str | None) -> str | None: """Checks if a cache directory with an index file exists for the given dataset URL. 
Args: diff --git a/src/litdata/utilities/parquet.py b/src/litdata/utilities/parquet.py index fdb5d5258..45a6ca2ec 100644 --- a/src/litdata/utilities/parquet.py +++ b/src/litdata/utilities/parquet.py @@ -8,7 +8,7 @@ from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from time import time -from typing import Any, Optional, Union +from typing import Any from urllib import parse from litdata.constants import _FSSPEC_AVAILABLE, _HF_HUB_AVAILABLE, _INDEX_FILENAME, _PYARROW_AVAILABLE @@ -19,9 +19,9 @@ class ParquetDir(ABC): def __init__( self, - dir_path: Optional[Union[str, Dir]], - cache_path: Optional[str] = None, - storage_options: Optional[dict] = {}, + dir_path: str | Dir | None, + cache_path: str | None = None, + storage_options: dict | None = {}, num_workers: int = 4, ): self.dir = _resolve_dir(dir_path) @@ -73,9 +73,9 @@ def write_index(self, chunks_info: list[dict[str, Any]], config: dict[str, Any]) class LocalParquetDir(ParquetDir): def __init__( self, - dir_path: Optional[Union[str, Dir]], - cache_path: Optional[str] = None, - storage_options: Optional[dict] = {}, + dir_path: str | Dir | None, + cache_path: str | None = None, + storage_options: dict | None = {}, num_workers: int = 4, ): if not _PYARROW_AVAILABLE: @@ -122,9 +122,9 @@ def write_index(self, chunks_info: list[dict[str, Any]], config: dict[str, Any]) class CloudParquetDir(ParquetDir): def __init__( self, - dir_path: Optional[Union[str, Dir]], - cache_path: Optional[str] = None, - storage_options: Optional[dict] = None, + dir_path: str | Dir | None, + cache_path: str | None = None, + storage_options: dict | None = None, num_workers: int = 4, ): if not _FSSPEC_AVAILABLE: @@ -222,9 +222,9 @@ def write_index(self, chunks_info: list[dict[str, Any]], config: dict[str, Any]) class HFParquetDir(ParquetDir): def __init__( self, - dir_path: Optional[Union[str, Dir]], - cache_path: Optional[str] = None, - storage_options: Optional[dict] = None, + dir_path: str | Dir | None, + cache_path: str | None = None, + storage_options: dict | None = None, num_workers: int = 4, ): if not _HF_HUB_AVAILABLE: @@ -300,8 +300,8 @@ def write_index(self, chunks_info: list[dict[str, Any]], config: dict[str, Any]) def get_parquet_indexer_cls( dir_path: str, - cache_path: Optional[str] = None, - storage_options: Optional[dict] = {}, + cache_path: str | None = None, + storage_options: dict | None = {}, num_workers: int = 4, ) -> ParquetDir: """Get the appropriate ParquetDir class based on the directory path scheme. diff --git a/src/litdata/utilities/subsample.py b/src/litdata/utilities/subsample.py index 6b494451e..bac0b236b 100644 --- a/src/litdata/utilities/subsample.py +++ b/src/litdata/utilities/subsample.py @@ -1,10 +1,10 @@ -from typing import Any, Optional +from typing import Any import numpy as np def shuffle_lists_together( - list1: list[Any], list2: list[Any], random_seed_sampler: Optional[np.random.RandomState] = None, seed: int = 42 + list1: list[Any], list2: list[Any], random_seed_sampler: np.random.RandomState | None = None, seed: int = 42 ) -> tuple[list[Any], list[Any]]: """Shuffles list1 and applies the same shuffle order to list2. 
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c3cd46377..0f951c2de 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -19,7 +19,7 @@ import sys from functools import partial from time import sleep -from typing import Any, Optional +from typing import Any from unittest import mock from unittest.mock import patch @@ -1102,7 +1102,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): def test_dataset_valid_state(tmpdir, monkeypatch): seed_everything(42) - index_json_content: Optional[dict[str, Any]] = None + index_json_content: dict[str, Any] | None = None def mock_resolve_dataset(dir_path: str) -> Dir: return Dir( @@ -1238,7 +1238,7 @@ def fn(remote_chunkpath: str, local_chunkpath: str): def test_dataset_valid_state_override(tmpdir, monkeypatch): seed_everything(42) - index_json_content: Optional[dict[str, Any]] = None + index_json_content: dict[str, Any] | None = None def mock_resolve_dataset(dir_path: str) -> Dir: return Dir( From faa4d6439ddc87e12e3026d8a6e6e8277e15c511 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 16 Nov 2025 19:16:55 +0545 Subject: [PATCH 2/5] fix: remove unused lint ignores and clean up type annotations --- pyproject.toml | 34 +++++++++++++---------------- src/litdata/processing/functions.py | 4 ++-- src/litdata/processing/utilities.py | 3 ++- src/litdata/raw/dataset.py | 3 ++- src/litdata/streaming/dataloader.py | 3 ++- src/litdata/streaming/dataset.py | 3 ++- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a606d1ef..add0a0da2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,10 +45,8 @@ lint.extend-select = [ "SIM", # see: https://pypi.org/project/flake8-simplify ] lint.ignore = [ - "E731", # Do not assign a lambda expression, use a def - "S101", # todo: Use of `assert` detected - "UP007", # todo: non-pep604-annotation-union - "UP045", # todo: non-pep604-annotation-optional + "E731", # Do not assign a lambda expression, use a def + "S101", # todo: Use of `assert` detected ] lint.per-file-ignores."examples/**" = [ "D100", @@ -65,22 +63,20 @@ lint.per-file-ignores."examples/**" = [ ] lint.per-file-ignores."setup.py" = [ "D100", "SIM115" ] lint.per-file-ignores."src/**" = [ - "D100", # Missing docstring in public module - "D101", # todo: Missing docstring in public class - "D102", # todo: Missing docstring in public method - "D103", # todo: Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # todo: Missing docstring in magic method - "D107", # todo: Missing docstring in __init__ - "D205", # todo: 1 blank line required between summary line and description + "D100", # Missing docstring in public module + "D101", # todo: Missing docstring in public class + "D102", # todo: Missing docstring in public method + "D103", # todo: Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # todo: Missing docstring in magic method + "D107", # todo: Missing docstring in __init__ + "D205", # todo: 1 blank line required between summary line and description "D401", - "D404", # todo: First line should be in imperative mood; try rephrasing - "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. 
- "S602", # todo: `subprocess` call with `shell=True` identified, security issue - "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` - "S607", # todo: Starting a process with a partial executable path - "UP006", # UP006 Use `list` instead of `List` for type annotation - "UP035", # UP035 `typing.Tuple` is deprecated, use `tuple` instead + "D404", # todo: First line should be in imperative mood; try rephrasing + "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. + "S602", # todo: `subprocess` call with `shell=True` identified, security issue + "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` + "S607", # todo: Starting a process with a partial executable path ] lint.per-file-ignores."tests/**" = [ "D100", diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index bf3ddbe1e..f25932ba5 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -18,13 +18,13 @@ import os import shutil import tempfile -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from datetime import datetime from functools import partial from pathlib import Path from types import FunctionType -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, Union from urllib import parse import torch diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 0a6e18e19..41f1bba73 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -16,9 +16,10 @@ import os import tempfile import urllib +from collections.abc import Callable from contextlib import contextmanager from subprocess import DEVNULL, Popen -from typing import Any, Callable +from typing import Any from urllib import parse from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_PROVIDERS diff --git a/src/litdata/raw/dataset.py b/src/litdata/raw/dataset.py index 9eebe9141..e9e74467b 100644 --- a/src/litdata/raw/dataset.py +++ b/src/litdata/raw/dataset.py @@ -14,9 +14,10 @@ import asyncio import logging import os +from collections.abc import Callable from functools import lru_cache from pathlib import Path -from typing import Any, Callable +from typing import Any from torch.utils.data import Dataset diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index ba2307b0d..34edbda80 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -15,10 +15,11 @@ import inspect import logging import os +from collections.abc import Callable from copy import deepcopy from importlib import reload from itertools import cycle -from typing import Any, Callable +from typing import Any import torch from torch.utils.data import Dataset, IterableDataset diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 30a15081c..5a1bd4aaf 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -13,8 +13,9 @@ import logging import os +from collections.abc import Callable from time import time -from typing import Any, Callable, Union +from typing import Any, Union import numpy as np from torch.utils.data import IterableDataset From 2e141c9c8b9e621c4205ada037e440c7c717a447 Mon 
Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 16 Nov 2025 19:56:09 +0545 Subject: [PATCH 3/5] fix types issue caused due to mp Queue --- src/litdata/processing/data_processor.py | 3 ++- src/litdata/processing/functions.py | 3 ++- src/litdata/streaming/cache.py | 2 +- src/litdata/streaming/item_loader.py | 2 +- src/litdata/streaming/writer.py | 2 +- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 129a029b1..50d33addb 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -26,7 +26,8 @@ from abc import abstractmethod from contextlib import suppress from dataclasses import dataclass -from multiprocessing import Process, Queue +from multiprocessing import Process +from multiprocessing.queues import Queue from pathlib import Path from queue import Empty from time import sleep, time diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index f25932ba5..ab49dd1ac 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -22,6 +22,7 @@ from dataclasses import dataclass from datetime import datetime from functools import partial +from multiprocessing.queues import Queue from pathlib import Path from types import FunctionType from typing import TYPE_CHECKING, Any, Literal, Union @@ -388,7 +389,7 @@ def optimize( fn: Callable[[Any], Any], inputs: Sequence[Any] | StreamingDataLoader | None = None, output_dir: str = "optimized_data", - queue: mp.Queue | None = None, + queue: Queue | None = None, input_dir: str | None = None, weights: list[int] | None = None, chunk_size: int | None = None, diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 982a430f2..e3eb48074 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -13,7 +13,7 @@ import logging import os -from multiprocessing import Queue +from multiprocessing.queues import Queue from typing import Any from litdata.constants import ( diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index b7f886e6e..eefa69134 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -18,7 +18,7 @@ from copy import deepcopy from datetime import datetime from io import BytesIO, FileIO -from multiprocessing import Queue +from multiprocessing.queues import Queue from time import sleep, time from typing import Any diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index d17e5a184..31088aabf 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -16,7 +16,7 @@ import uuid import warnings from dataclasses import dataclass -from multiprocessing import Queue +from multiprocessing.queues import Queue from time import sleep, time from typing import Any From 00b21c913a4b103aa9ea4fb3e4927fae1946ac52 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 16 Nov 2025 20:06:28 +0545 Subject: [PATCH 4/5] Revert "fix types issue caused due to mp Queue" This reverts commit 2e141c9c8b9e621c4205ada037e440c7c717a447. 
--- src/litdata/processing/data_processor.py | 3 +-- src/litdata/processing/functions.py | 3 +-- src/litdata/streaming/cache.py | 2 +- src/litdata/streaming/item_loader.py | 2 +- src/litdata/streaming/writer.py | 2 +- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 50d33addb..129a029b1 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -26,8 +26,7 @@ from abc import abstractmethod from contextlib import suppress from dataclasses import dataclass -from multiprocessing import Process -from multiprocessing.queues import Queue +from multiprocessing import Process, Queue from pathlib import Path from queue import Empty from time import sleep, time diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index ab49dd1ac..f25932ba5 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -22,7 +22,6 @@ from dataclasses import dataclass from datetime import datetime from functools import partial -from multiprocessing.queues import Queue from pathlib import Path from types import FunctionType from typing import TYPE_CHECKING, Any, Literal, Union @@ -389,7 +388,7 @@ def optimize( fn: Callable[[Any], Any], inputs: Sequence[Any] | StreamingDataLoader | None = None, output_dir: str = "optimized_data", - queue: Queue | None = None, + queue: mp.Queue | None = None, input_dir: str | None = None, weights: list[int] | None = None, chunk_size: int | None = None, diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index e3eb48074..982a430f2 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -13,7 +13,7 @@ import logging import os -from multiprocessing.queues import Queue +from multiprocessing import Queue from typing import Any from litdata.constants import ( diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index eefa69134..b7f886e6e 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -18,7 +18,7 @@ from copy import deepcopy from datetime import datetime from io import BytesIO, FileIO -from multiprocessing.queues import Queue +from multiprocessing import Queue from time import sleep, time from typing import Any diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 31088aabf..d17e5a184 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -16,7 +16,7 @@ import uuid import warnings from dataclasses import dataclass -from multiprocessing.queues import Queue +from multiprocessing import Queue from time import sleep, time from typing import Any From a7befd02ffd2333ced407636ac09f6883674dd0b Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 16 Nov 2025 20:11:16 +0545 Subject: [PATCH 5/5] fix: update type hints for queue parameters in multiple files --- src/litdata/processing/data_processor.py | 2 +- src/litdata/processing/functions.py | 2 +- src/litdata/streaming/cache.py | 2 +- src/litdata/streaming/item_loader.py | 2 +- src/litdata/streaming/writer.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 129a029b1..0b6ff6776 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -482,7 +482,7 @@ def __init__( item_loader: BaseItemLoader | 
None = None, storage_options: dict[str, Any] = {}, keep_data_ordered: bool = True, - shared_queue: Queue | FakeQueue | None = None, + shared_queue: "Queue | FakeQueue | None" = None, using_queue_optimize: bool = False, # using queues as inputs for optimize fn ) -> None: """The BaseWorker is responsible to process the user data.""" diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index f25932ba5..9eb198a8e 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -388,7 +388,7 @@ def optimize( fn: Callable[[Any], Any], inputs: Sequence[Any] | StreamingDataLoader | None = None, output_dir: str = "optimized_data", - queue: mp.Queue | None = None, + queue: "mp.Queue | None" = None, input_dir: str | None = None, weights: list[int] | None = None, chunk_size: int | None = None, diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 982a430f2..123522c0c 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -49,7 +49,7 @@ def __init__( storage_options: dict | None = {}, session_options: dict | None = {}, max_pre_download: int = 2, - msg_queue: Queue | None = None, + msg_queue: "Queue | None" = None, on_demand_bytes: bool = False, ): """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index b7f886e6e..30d0bc710 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -53,7 +53,7 @@ def setup( chunks: list, serializers: dict[str, Serializer], region_of_interest: list[tuple[int, int]] | None = None, - force_download_queue: Queue | None = None, + force_download_queue: "Queue | None" = None, ) -> None: self._config = config self._chunks = chunks diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index d17e5a184..5191bd595 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -58,7 +58,7 @@ def __init__( serializers: dict[str, Serializer] | None = None, chunk_index: int | None = None, item_loader: BaseItemLoader | None = None, - msg_queue: Queue | None = None, + msg_queue: "Queue | None" = None, ): """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training.
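Note on the quoted queue annotations in the last three patches: `multiprocessing.Queue` is a factory bound to the default context rather than a class, so an eagerly evaluated PEP 604 union such as `Queue | None` raises a `TypeError` at import time, whereas the previous `Optional[Queue]` spelling tolerated the callable. Importing the real class from `multiprocessing.queues` (PATCH 3) or quoting the annotation so it stays a plain string (PATCH 5) both avoid the eager evaluation. Below is a minimal, self-contained sketch of the difference; the `enqueue` function is a hypothetical example, not code from the patch:

    import multiprocessing as mp

    # Eager PEP 604 union: mp.Queue is a context-bound factory, not a type,
    # so neither operand supports `|` and the expression raises at import time.
    try:
        _ = mp.Queue | None
    except TypeError as exc:
        print(f"eager union fails: {exc}")

    # Quoted annotation: stored as a literal string and never evaluated at
    # definition time, so the defining module imports cleanly.
    def enqueue(q: "mp.Queue | None" = None) -> None:
        if q is not None:
            q.put("item")

    print(enqueue.__annotations__["q"])  # prints the string: mp.Queue | None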