Skip to content

Commit 7faa682

Browse files
authored
Add downloader for R2 (#711)
1 parent 42961fa commit 7faa682

File tree

2 files changed

+353
-1
lines changed

2 files changed

+353
-1
lines changed

src/litdata/streaming/downloader.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
_OBSTORE_AVAILABLE,
3434
)
3535
from litdata.debugger import _get_log_msg
36-
from litdata.streaming.client import S3Client
36+
from litdata.streaming.client import R2Client, S3Client
3737

3838
logger = logging.getLogger("litdata.streaming.downloader")
3939

@@ -263,6 +263,161 @@ async def adownload_fileobj(self, remote_filepath: str) -> bytes:
263263
return bytes(bytes_object) # Convert obstore.Bytes to bytes
264264

265265

266+
class R2Downloader(Downloader):
267+
def __init__(
268+
self,
269+
remote_dir: str,
270+
cache_dir: str,
271+
chunks: list[dict[str, Any]],
272+
storage_options: Optional[dict] = {},
273+
**kwargs: Any,
274+
):
275+
super().__init__(remote_dir, cache_dir, chunks, storage_options)
276+
self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
277+
# check if kwargs contains session_options
278+
self.session_options = kwargs.get("session_options", {})
279+
280+
if not self._s5cmd_available or _DISABLE_S5CMD:
281+
self._client = R2Client(storage_options=self._storage_options, session_options=self.session_options)
282+
283+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
284+
obj = parse.urlparse(remote_filepath)
285+
286+
if obj.scheme != "r2":
287+
raise ValueError(f"Expected obj.scheme to be `r2`, instead, got {obj.scheme} for remote={remote_filepath}")
288+
289+
if os.path.exists(local_filepath):
290+
return
291+
292+
with (
293+
suppress(Timeout, FileNotFoundError),
294+
FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0),
295+
):
296+
if self._s5cmd_available and not _DISABLE_S5CMD:
297+
env = None
298+
if self._storage_options:
299+
env = os.environ.copy()
300+
env.update(self._storage_options)
301+
302+
aws_no_sign_request = self._storage_options.get("AWS_NO_SIGN_REQUEST", "no").lower() == "yes"
303+
# prepare the s5cmd command
304+
no_signed_option = "--no-sign-request" if aws_no_sign_request else None
305+
cmd_parts = ["s5cmd", no_signed_option, "cp", remote_filepath, local_filepath]
306+
cmd = " ".join(part for part in cmd_parts if part)
307+
308+
proc = subprocess.Popen(
309+
cmd,
310+
shell=True,
311+
stdout=subprocess.PIPE,
312+
stderr=subprocess.PIPE,
313+
env=env,
314+
)
315+
return_code = proc.wait()
316+
317+
if return_code != 0:
318+
stderr_output = proc.stderr.read().decode().strip() if proc.stderr else ""
319+
error_message = (
320+
f"Failed to execute command `{cmd}` (exit code: {return_code}). "
321+
"This might be due to an incorrect file path, insufficient permissions, or network issues. "
322+
"To resolve this issue, you can either:\n"
323+
"- Pass `storage_options` with the necessary credentials and endpoint. \n"
324+
"- Example:\n"
325+
" storage_options = {\n"
326+
' "AWS_ACCESS_KEY_ID": "your-key",\n'
327+
' "AWS_SECRET_ACCESS_KEY": "your-secret",\n'
328+
' "S3_ENDPOINT_URL": "https://s3.example.com" (Optional if using AWS)\n'
329+
" }\n"
330+
"- or disable `s5cmd` by setting `DISABLE_S5CMD=1` if `storage_options` do not work.\n"
331+
)
332+
if stderr_output:
333+
error_message += (
334+
f"For further debugging, please check the command output below:\n{stderr_output}"
335+
)
336+
raise RuntimeError(error_message)
337+
else:
338+
from boto3.s3.transfer import TransferConfig
339+
340+
extra_args: dict[str, Any] = {}
341+
342+
if not os.path.exists(local_filepath):
343+
# Issue: https://github.com/boto/boto3/issues/3113
344+
self._client.client.download_file(
345+
obj.netloc,
346+
obj.path.lstrip("/"),
347+
local_filepath,
348+
ExtraArgs=extra_args,
349+
Config=TransferConfig(use_threads=False),
350+
)
351+
352+
def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes:
353+
obj = parse.urlparse(remote_filepath)
354+
355+
if obj.scheme != "r2":
356+
raise ValueError(f"Expected obj.scheme to be `r2`, instead, got {obj.scheme} for remote={remote_filepath}")
357+
358+
if not hasattr(self, "client"):
359+
self._client = R2Client(storage_options=self._storage_options, session_options=self.session_options)
360+
361+
bucket = obj.netloc
362+
key = obj.path.lstrip("/")
363+
364+
byte_range = f"bytes={offset}-{offset + length - 1}"
365+
366+
response = self._client.client.get_object(Bucket=bucket, Key=key, Range=byte_range)
367+
368+
return response["Body"].read()
369+
370+
def download_fileobj(self, remote_filepath: str, fileobj: Any) -> None:
371+
"""Download a file from R2 directly to a file-like object."""
372+
obj = parse.urlparse(remote_filepath)
373+
374+
if obj.scheme != "r2":
375+
raise ValueError(f"Expected obj.scheme to be `r2`, instead, got {obj.scheme} for remote={remote_filepath}")
376+
377+
if not hasattr(self, "_client"):
378+
self._client = R2Client(storage_options=self._storage_options, session_options=self.session_options)
379+
380+
bucket = obj.netloc
381+
key = obj.path.lstrip("/")
382+
383+
self._client.client.download_fileobj(
384+
bucket,
385+
key,
386+
fileobj,
387+
)
388+
389+
def _get_store(self, bucket: str) -> Any:
390+
"""Return an obstore S3Store instance for the given bucket, initializing if needed."""
391+
if not hasattr(self, "_store"):
392+
if not _OBSTORE_AVAILABLE:
393+
raise ModuleNotFoundError(str(_OBSTORE_AVAILABLE))
394+
import boto3
395+
from obstore.auth.boto3 import Boto3CredentialProvider
396+
from obstore.store import S3Store
397+
398+
session = boto3.Session(**self._storage_options, **self.session_options)
399+
credential_provider = Boto3CredentialProvider(session)
400+
self._store = S3Store(bucket, credential_provider=credential_provider)
401+
return self._store
402+
403+
async def adownload_fileobj(self, remote_filepath: str) -> bytes:
404+
"""Download a file from R2 directly to a file-like object asynchronously."""
405+
import obstore as obs
406+
407+
obj = parse.urlparse(remote_filepath)
408+
409+
if obj.scheme != "r2":
410+
raise ValueError(f"Expected obj.scheme to be `r2`, instead, got {obj.scheme} for remote={remote_filepath}")
411+
412+
bucket = obj.netloc
413+
key = obj.path.lstrip("/")
414+
415+
store = self._get_store(bucket)
416+
resp = await obs.get_async(store, key)
417+
bytes_object = await resp.bytes_async()
418+
return bytes(bytes_object) # Convert obstore.Bytes to bytes
419+
420+
266421
class GCPDownloader(Downloader):
267422
def __init__(
268423
self,
@@ -550,6 +705,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
550705
"azure://": AzureDownloader,
551706
"hf://": HFDownloader,
552707
"local:": LocalDownloaderWithCache,
708+
"r2://": R2Downloader,
553709
}
554710

555711

0 commit comments

Comments
 (0)