Skip to content

Commit 918f7ba

Browse files
Local downloader w cache (#30)
1 parent 9f584d2 commit 918f7ba

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,16 @@ outputs = optimize(
391391
)
392392
```
393393

394+
## Network Drive On-Prem Support
395+
396+
On-prem compute nodes can mount and use a network drive. To reduce the load on the network, the `StreamingDataset` supports caching the chunks locally. Prefix the `input_dir` with `local:`, as in the example below, to use it.
397+
398+
```python
399+
from lightning.data import StreamingDataset
400+
401+
dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
402+
```
403+
394404
# ⚡ Contributors
395405

396406
We welcome any contributions, pull requests, or issues. If you use the Streaming Dataset for your own project, please reach out to us on Slack or Discord.

litdata/streaming/downloader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
9595
shutil.copy(remote_filepath, local_filepath)
9696

9797

98-
_DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}
98+
class LocalDownloaderWithCache(LocalDownloader):
    """Downloader for paths that carry the ``local:`` scheme marker.

    The marker is removed from the source path and the actual copy is
    delegated to ``LocalDownloader.download_file``.
    """

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
        """Copy ``remote_filepath`` (minus its ``local:`` marker) to ``local_filepath``.

        Args:
            remote_filepath: Source path, normally prefixed with ``local:``.
            local_filepath: Destination path handed to the parent downloader.
        """
        prefix = "local:"
        # Strip only the leading marker: a blanket str.replace would also
        # mangle any later "local:" that is legitimately part of the path.
        if remote_filepath.startswith(prefix):
            remote_filepath = remote_filepath[len(prefix):]
        super().download_file(remote_filepath, local_filepath)
102+
103+
104+
_DOWNLOADERS = {"s3://": S3Downloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}
99105

100106

101107
def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:

litdata/streaming/resolver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
5656
if dir_path.startswith("s3://"):
5757
return Dir(path=None, url=dir_path)
5858

59+
if dir_path.startswith("local:"):
60+
return Dir(path=None, url=dir_path)
61+
5962
dir_path = _resolve_time_template(dir_path)
6063

6164
dir_path_absolute = str(Path(dir_path).absolute().resolve())

tests/streaming/test_downloader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from unittest.mock import MagicMock
33

4-
from litdata.streaming.downloader import S3Downloader, subprocess
4+
from litdata.streaming.downloader import LocalDownloaderWithCache, S3Downloader, shutil, subprocess
55

66

77
def test_s3_downloader_fast(tmpdir, monkeypatch):
@@ -11,3 +11,18 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
1111
downloader = S3Downloader(tmpdir, tmpdir, [])
1212
downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
1313
popen_mock.wait.assert_called()
14+
15+
16+
def test_download_with_cache(tmpdir, monkeypatch):
    """Check that a ``local:``-prefixed path is forwarded to ``shutil.copy``."""
    # Keep the fixture file under tmpdir so the test never writes into the
    # working directory and needs no manual cleanup.
    source_dir = tmpdir.mkdir("remote")
    cache_dir = tmpdir.mkdir("cache")
    source_file = source_dir.join("a.txt")
    source_file.write("hello")

    # Replace the real copy so the test only observes that it was requested.
    shutil_mock = MagicMock()
    monkeypatch.setattr(shutil, "copy", shutil_mock)

    local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
    local_downloader.download_file(f"local:{source_file}", os.path.join(cache_dir, "a.txt"))
    shutil_mock.assert_called()

0 commit comments

Comments
 (0)