Skip to content

Commit 151e012

Browse files
authored
Fix errors when using compression and r2 in optimize() (#715)
1 parent c760eda commit 151e012

File tree

6 files changed: 34 additions (+34) and 2 deletions (-2).

src/litdata/processing/data_processor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,7 @@ def keep_path(path: str) -> bool:
217217
"s3_connections",
218218
"s3_folders",
219219
"snowflake_connections",
220+
"lightning_storage",
220221
]
221222
return all(p not in path for p in paths)
222223

src/litdata/streaming/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,12 @@ def __init__(
194194
self.num_workers: int = 1
195195
self.batch_size: int = 1
196196
self._encryption = encryption
197+
# Ensure data_connection_id is included in storage_options if available from input_dir
198+
if input_dir.data_connection_id and storage_options is not None:
199+
storage_options = storage_options.copy()
200+
storage_options["data_connection_id"] = input_dir.data_connection_id
201+
elif input_dir.data_connection_id and storage_options is None:
202+
storage_options = {"data_connection_id": input_dir.data_connection_id}
197203
self.storage_options = storage_options
198204
self.session_options = session_options
199205
self.max_pre_download = max_pre_download

src/litdata/streaming/resolver.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
6262
raise ValueError(f"`dir_path` must be either a string, Path, or Dir, got: {type(dir_path)}")
6363

6464
if isinstance(dir_path, str):
65-
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
65+
cloud_prefixes = ("s3://", "gs://", "r2://", "azure://", "hf://")
6666
if dir_path.startswith(cloud_prefixes):
6767
return Dir(path=None, url=dir_path)
6868

src/litdata/utilities/dataset_utilities.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,12 @@ def subsample_streaming_dataset(
8484
if index_path is not None:
8585
copy_index_to_cache_index_filepath(index_path, cache_index_filepath)
8686
else:
87-
downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options, session_options)
87+
# Merge data_connection_id from resolved directory into storage_options for R2 connections
88+
merged_storage_options = storage_options.copy() if storage_options is not None else {}
89+
if hasattr(input_dir, "data_connection_id") and input_dir.data_connection_id:
90+
merged_storage_options["data_connection_id"] = input_dir.data_connection_id
91+
92+
downloader = get_downloader(input_dir.url, input_dir.path, [], merged_storage_options, session_options)
8893
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)
8994

9095
def path_exists(p: str) -> bool:
@@ -159,6 +164,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
159164
or path.startswith("/teamspace/s3_folders/")
160165
or path.startswith("/teamspace/gcs_folders/")
161166
or path.startswith("/teamspace/gcs_connections/")
167+
or path.startswith("/teamspace/lightning_storage/")
162168
)
163169

164170

tests/streaming/test_dataset.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -867,6 +867,24 @@ def fn(remote_chunkpath: str, local_chunkpath: str):
867867
) # it won't be None, and a cache dir will be created
868868

869869

870+
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
871+
def test_r2_streaming_dataset(monkeypatch, tmpdir):
872+
"""Test that data_connection_id is properly merged into storage_options."""
873+
downloader = mock.MagicMock()
874+
875+
def fn(remote_chunkpath: str, local_chunkpath: str):
876+
with open(local_chunkpath, "w") as f:
877+
json.dump({"chunks": [{"chunk_size": 2, "filename": "0.bin"}]}, f)
878+
879+
downloader.download_file = fn
880+
881+
monkeypatch.setattr(dataset_utilities_module, "get_downloader", mock.MagicMock(return_value=downloader))
882+
883+
dataset = StreamingDataset(input_dir="r2://random_bucket/optimized_tiny_imagenet")
884+
assert dataset.input_dir.url == "r2://random_bucket/optimized_tiny_imagenet"
885+
assert dataset.input_dir.path.endswith("chunks/9537e9e392ad87a4d38d05dfe28c329a/9537e9e392ad87a4d38d05dfe28c329a")
886+
887+
870888
class EmulateS3StreamingDataset(StreamingDataset):
871889
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
872890
cache_dir = os.path.join(self.input_dir.path)

tests/utilities/test_dataset_utilities.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def test_should_replace_path():
2727
assert _should_replace_path("/teamspace/s3_folders/...")
2828
assert _should_replace_path("/teamspace/gcs_folders/...")
2929
assert _should_replace_path("/teamspace/gcs_connections/...")
30+
assert _should_replace_path("/teamspace/lightning_storage/...")
3031
assert not _should_replace_path("something_else")
3132

3233

0 commit comments

Comments
 (0)