Commit e3a62c9

Authored by pwgardipee (Peyton Gardipee), pre-commit-ci[bot], and tchaton

Add support for direct upload to r2 buckets (#705)

Co-authored-by: Peyton Gardipee <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
1 parent 0b19ba6 commit e3a62c9

15 files changed: +1018 −35 lines changed

.github/workflows/ci-testing.yml
Lines changed: 4 additions & 4 deletions

@@ -48,16 +48,16 @@ jobs:
       - name: Run fast tests in parallel
         run: |
           pytest tests \
-            --ignore=tests/processing \
-            --ignore=tests/raw \
-            -n 2 --cov=litdata --durations=120
+            --ignore=tests/processing \
+            --ignore=tests/raw \
+            -n 2 --cov=litdata --durations=0 --timeout=120 --capture=no --verbose


       - name: Run processing tests sequentially
         run: |
           # note that the listed test should match ignored in the previous step
           pytest \
             tests/processing tests/raw \
-            --cov=litdata --cov-append --durations=90
+            --cov=litdata --cov-append --durations=0 --timeout=120 --capture=no --verbose

       - name: Statistics
         continue-on-error: true

src/litdata/constants.py
Lines changed: 5 additions & 1 deletion

@@ -24,7 +24,11 @@
 _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
 _DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")
 _LITDATA_CACHE_DIR = os.getenv("LITDATA_CACHE_DIR", None)
-_SUPPORTED_PROVIDERS = ("s3", "gs")  # cloud providers supported by litdata for uploading (optimize, map, merge, etc)
+_SUPPORTED_PROVIDERS = (
+    "s3",
+    "gs",
+    "r2",
+)  # cloud providers supported by litdata for uploading (optimize, map, merge, etc)

 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
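
With "r2" added to _SUPPORTED_PROVIDERS, URLs using the r2:// scheme now pass the provider check performed before uploads. A minimal sketch of that check (the helper name below is illustrative and not part of the commit):

    from urllib import parse

    from litdata.constants import _SUPPORTED_PROVIDERS

    def is_supported_cloud_url(url: str) -> bool:  # illustrative helper, not in litdata
        # "r2://my-bucket/dataset" parses to scheme "r2", which is now accepted.
        return parse.urlparse(url).scheme in _SUPPORTED_PROVIDERS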

src/litdata/processing/data_processor.py
Lines changed: 16 additions & 11 deletions

@@ -45,7 +45,7 @@
     _TQDM_AVAILABLE,
 )
 from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
-from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename
+from litdata.processing.utilities import _create_dataset, construct_storage_options, remove_uuid_from_filename
 from litdata.streaming import Cache
 from litdata.streaming.cache import Dir
 from litdata.streaming.dataloader import StreamingDataLoader

@@ -168,7 +168,8 @@ def _download_data_target(
             dirpath = os.path.dirname(local_path)
             os.makedirs(dirpath, exist_ok=True)
             if fs_provider is None:
-                fs_provider = _get_fs_provider(input_dir.url, storage_options)
+                merged_storage_options = construct_storage_options(storage_options, input_dir)
+                fs_provider = _get_fs_provider(input_dir.url, merged_storage_options)
             fs_provider.download_file(path, local_path)

         elif os.path.isfile(path):

@@ -233,7 +234,8 @@ def _upload_fn(
     obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path)

     if obj.scheme in _SUPPORTED_PROVIDERS:
-        fs_provider = _get_fs_provider(output_dir.url, storage_options)
+        merged_storage_options = construct_storage_options(storage_options, output_dir)
+        fs_provider = _get_fs_provider(output_dir.url, merged_storage_options)

     while True:
         data: Optional[Union[str, tuple[str, str]]] = upload_queue.get()

@@ -1022,7 +1024,8 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         local_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

         if obj.scheme in _SUPPORTED_PROVIDERS:
-            fs_provider = _get_fs_provider(output_dir.url, self.storage_options)
+            merged_storage_options = construct_storage_options(self.storage_options, output_dir)
+            fs_provider = _get_fs_provider(output_dir.url, merged_storage_options)
             fs_provider.upload_file(
                 local_filepath,
                 os.path.join(output_dir.url, os.path.basename(local_filepath)),

@@ -1044,8 +1047,9 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
             remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
             node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
             if obj.scheme in _SUPPORTED_PROVIDERS:
-                _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options)
-                fs_provider = _get_fs_provider(remote_filepath, self.storage_options)
+                merged_storage_options = construct_storage_options(self.storage_options, output_dir)
+                _wait_for_file_to_exist(remote_filepath, storage_options=merged_storage_options)
+                fs_provider = _get_fs_provider(remote_filepath, merged_storage_options)
                 fs_provider.download_file(remote_filepath, node_index_filepath)
             elif output_dir.path and os.path.isdir(output_dir.path):
                 shutil.copyfile(remote_filepath, node_index_filepath)

@@ -1499,8 +1503,8 @@ def _cleanup_checkpoints(self) -> None:

         prefix = self.output_dir.url.rstrip("/") + "/"
         checkpoint_prefix = os.path.join(prefix, ".checkpoints")
-
-        fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
+        merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
+        fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)
         fs_provider.delete_file_or_directory(checkpoint_prefix)

     def _save_current_config(self, workers_user_items: list[list[Any]]) -> None:

@@ -1529,8 +1533,8 @@ def _save_current_config(self, workers_user_items: list[list[Any]]) -> None:

         if obj.scheme not in _SUPPORTED_PROVIDERS:
             not_supported_provider(self.output_dir.url)
-
-        fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
+        merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
+        fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)

         prefix = self.output_dir.url.rstrip("/") + "/" + ".checkpoints/"

@@ -1601,7 +1605,8 @@ def _load_checkpoint_config(self, workers_user_items: list[list[Any]]) -> None:

         # download all the checkpoint files in tempdir and read them
         with tempfile.TemporaryDirectory() as temp_dir:
-            fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
+            merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
+            fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)
             saved_file_dir = fs_provider.download_directory(prefix, temp_dir)

             if not os.path.exists(os.path.join(saved_file_dir, "config.json")):
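
The pattern repeated across these hunks is the same: before a filesystem provider is created, the user-supplied storage options are merged with the directory's data_connection_id via construct_storage_options, so an R2 destination can carry its connection id through download, upload, index, and checkpoint handling. From the user's side this is what enables targeting R2 directly; a hedged sketch of such a call follows (the bucket name, connection id, and the exact keyword arguments shown are placeholders, and optimize accepts more options than listed here):

    from litdata import optimize

    def to_sample(index):
        return {"index": index}

    # Placeholders: the bucket name and data_connection_id are hypothetical values.
    optimize(
        fn=to_sample,
        inputs=list(range(100)),
        output_dir="r2://my-bucket/my-dataset",
        chunk_bytes="64MB",
        storage_options={"data_connection_id": "dc_12345"},
    )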

src/litdata/processing/utilities.py
Lines changed: 7 additions & 0 deletions

@@ -272,3 +272,10 @@ def remove_uuid_from_filename(filepath: str) -> str:

     # uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character
     return filepath[:-38] + ".json"
+
+
+def construct_storage_options(storage_options: dict[str, Any], input_dir: Dir) -> dict[str, Any]:
+    merged_storage_options = storage_options.copy()
+    if hasattr(input_dir, "data_connection_id") and input_dir.data_connection_id:
+        merged_storage_options["data_connection_id"] = input_dir.data_connection_id
+    return merged_storage_options
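
construct_storage_options leaves the caller's options untouched and only adds the data_connection_id when the Dir-like object carries one, so non-R2 paths behave exactly as before. A small illustrative call using a stand-in object (values are hypothetical):

    from types import SimpleNamespace

    from litdata.processing.utilities import construct_storage_options

    # Stand-in for a Dir whose resolver attached a data connection id.
    output_dir = SimpleNamespace(url="r2://my-bucket/my-dataset", data_connection_id="dc_12345")

    merged = construct_storage_options({"region_name": "auto"}, output_dir)
    # merged == {"region_name": "auto", "data_connection_id": "dc_12345"}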

src/litdata/streaming/client.py
Lines changed: 141 additions & 0 deletions

@@ -11,17 +11,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 from time import time
 from typing import Any, Optional

 import boto3
 import botocore
+import requests
 from botocore.credentials import InstanceMetadataProvider
 from botocore.utils import InstanceMetadataFetcher
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry

 from litdata.constants import _IS_IN_STUDIO

+# Constants for the retry adapter. Docs: https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html
+# Maximum number of total connection retry attempts (e.g., 2880 retries = 24 hours with 30s timeout per request)
+_CONNECTION_RETRY_TOTAL = 2880
+# Backoff factor for connection retries (wait time increases by this factor after each failure)
+_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
+# Default timeout for each HTTP request in seconds
+_DEFAULT_REQUEST_TIMEOUT = 30  # seconds
+
+
+class _CustomRetryAdapter(HTTPAdapter):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
+        super().__init__(*args, **kwargs)
+
+    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
+        kwargs["timeout"] = kwargs.get("timeout", self.timeout)
+        return super().send(request, **kwargs)
+

 class S3Client:
     # TODO: Generalize to support more cloud providers.

@@ -76,3 +98,122 @@ def client(self) -> Any:
             self._last_time = time()

         return self._client
+
+
+class R2Client(S3Client):
+    """R2 client with refreshable credentials for Cloudflare R2 storage."""
+
+    def __init__(
+        self,
+        refetch_interval: int = 3600,  # 1 hour - this is the default refresh interval for R2 credentials
+        storage_options: Optional[dict] = {},
+        session_options: Optional[dict] = {},
+    ) -> None:
+        # Store R2-specific options before calling super()
+        self._base_storage_options: dict = storage_options or {}
+
+        # Call parent constructor with R2-specific refetch interval
+        super().__init__(
+            refetch_interval=refetch_interval,
+            storage_options={},  # storage options handled in _create_client
+            session_options=session_options,
+        )
+
+    def get_r2_bucket_credentials(self, data_connection_id: str) -> dict[str, str]:
+        """Fetch temporary R2 credentials for the current lightning storage connection."""
+        # Create session with retry logic
+        retry_strategy = Retry(
+            total=_CONNECTION_RETRY_TOTAL,
+            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
+            status_forcelist=[
+                408,  # Request Timeout
+                429,  # Too Many Requests
+                500,  # Internal Server Error
+                502,  # Bad Gateway
+                503,  # Service Unavailable
+                504,  # Gateway Timeout
+            ],
+        )
+        adapter = _CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
+        session = requests.Session()
+        session.mount("http://", adapter)
+        session.mount("https://", adapter)
+
+        try:
+            # Get Lightning Cloud API token
+            cloud_url = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
+            api_key = os.getenv("LIGHTNING_API_KEY")
+            username = os.getenv("LIGHTNING_USERNAME")
+            project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
+
+            if not all([api_key, username, project_id]):
+                raise RuntimeError("Missing required environment variables")
+
+            # Login to get token
+            payload = {"apiKey": api_key, "username": username}
+            login_url = f"{cloud_url}/v1/auth/login"
+            response = session.post(login_url, data=json.dumps(payload))
+
+            if "token" not in response.json():
+                raise RuntimeError("Failed to get authentication token")
+
+            token = response.json()["token"]
+
+            # Get temporary bucket credentials
+            headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
+            credentials_url = (
+                f"{cloud_url}/v1/projects/{project_id}/data-connections/{data_connection_id}/temp-bucket-credentials"
+            )
+
+            credentials_response = session.get(credentials_url, headers=headers, timeout=10)
+
+            if credentials_response.status_code != 200:
+                raise RuntimeError(f"Failed to get credentials: {credentials_response.status_code}")
+
+            temp_credentials = credentials_response.json()
+
+            endpoint_url = f"https://{temp_credentials['accountId']}.r2.cloudflarestorage.com"
+
+            # Format credentials for S3Client
+            return {
+                "aws_access_key_id": temp_credentials["accessKeyId"],
+                "aws_secret_access_key": temp_credentials["secretAccessKey"],
+                "aws_session_token": temp_credentials["sessionToken"],
+                "endpoint_url": endpoint_url,
+            }
+
+        except Exception as e:
+            # No fallback credentials are available, so log the failure and re-raise.
+            print(f"Failed to get R2 credentials from API: {e}")
+            raise RuntimeError(f"Failed to get R2 credentials and no fallback available: {e}")
+
+    def _create_client(self) -> None:
+        """Create a new R2 client with fresh credentials."""
+        # Get data connection ID from storage options
+        data_connection_id = self._base_storage_options.get("data_connection_id")
+        if not data_connection_id:
+            raise RuntimeError("data_connection_id is required in storage_options for R2 client")
+
+        # Get fresh R2 credentials
+        r2_credentials = self.get_r2_bucket_credentials(data_connection_id)
+
+        # Filter out metadata keys that shouldn't be passed to boto3
+        filtered_storage_options = {
+            k: v for k, v in self._base_storage_options.items() if k not in ["data_connection_id"]
+        }

+        # Combine filtered storage options with fresh credentials
+        combined_storage_options = {**filtered_storage_options, **r2_credentials}
+
+        # Update the inherited storage options with R2 credentials
+        self._storage_options = combined_storage_options
+
+        # Create session and client
+        session = boto3.Session(**self._session_options)
+        self._client = session.client(
+            "s3",
+            **{
+                "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
+                **combined_storage_options,
+            },
+        )
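
R2Client relies on the time-based refresh inherited from S3Client: the client property shown above tracks _last_time and re-runs _create_client once refetch_interval has elapsed, so the temporary R2 credentials are re-fetched roughly once an hour. A hedged usage sketch (the connection id and object names are placeholders, and the Lightning environment variables listed above must be set):

    from litdata.streaming.client import R2Client

    # Placeholder connection id; in practice it identifies a platform data connection.
    r2 = R2Client(storage_options={"data_connection_id": "dc_12345"})

    # The underlying boto3 client targets the account-specific R2 endpoint returned by
    # get_r2_bucket_credentials(); credentials refresh transparently when they expire.
    r2.client.upload_file("index.json", "my-bucket", "my-dataset/index.json")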

src/litdata/streaming/config.py
Lines changed: 1 addition & 0 deletions

@@ -134,6 +134,7 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -

         if os.path.exists(local_chunkpath):
             self.try_decompress(local_chunkpath)
+
         if self._downloader is not None and not skip_lock:
             # We don't want to redownload the base, but we should mark
             # it as having been requested by something

src/litdata/streaming/fs_provider.py
Lines changed: 74 additions & 1 deletion

@@ -16,7 +16,7 @@
 from urllib import parse

 from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _SUPPORTED_PROVIDERS
-from litdata.streaming.client import S3Client
+from litdata.streaming.client import R2Client, S3Client


 class FsProvider(ABC):

@@ -224,6 +224,77 @@ def is_empty(self, path: str) -> bool:
         return not objects["KeyCount"] > 0


+class R2FsProvider(S3FsProvider):
+    def __init__(self, storage_options: Optional[dict[str, Any]] = {}):
+        super().__init__(storage_options=storage_options)
+
+        # Create R2Client with refreshable credentials
+        self.client = R2Client(storage_options=storage_options)
+
+    def upload_file(self, local_path: str, remote_path: str) -> None:
+        bucket_name, blob_path = get_bucket_and_path(remote_path, "r2")
+        self.client.client.upload_file(local_path, bucket_name, blob_path)
+
+    def download_file(self, remote_path: str, local_path: str) -> None:
+        bucket_name, blob_path = get_bucket_and_path(remote_path, "r2")
+        with open(local_path, "wb") as f:
+            self.client.client.download_fileobj(bucket_name, blob_path, f)
+
+    def download_directory(self, remote_path: str, local_directory_name: str) -> str:
+        """Download all objects under a given S3 prefix (directory) using the existing client."""
+        bucket_name, remote_directory_name = get_bucket_and_path(remote_path, "r2")
+
+        # Ensure local directory exists
+        local_directory_name = os.path.abspath(local_directory_name)
+        os.makedirs(local_directory_name, exist_ok=True)
+
+        saved_file_dir = "."
+
+        # List objects under the given prefix
+        objects = self.client.client.list_objects_v2(Bucket=bucket_name, Prefix=remote_directory_name)
+
+        # Check if objects exist
+        if "Contents" in objects:
+            for obj in objects["Contents"]:
+                local_filename = os.path.join(local_directory_name, obj["Key"])
+
+                # Ensure parent directories exist
+                os.makedirs(os.path.dirname(local_filename), exist_ok=True)
+
+                # Download each file
+                with open(local_filename, "wb") as f:
+                    self.client.client.download_fileobj(bucket_name, obj["Key"], f)
+                saved_file_dir = os.path.dirname(local_filename)
+
+        return saved_file_dir
+
+    def delete_file_or_directory(self, path: str) -> None:
+        """Delete the file or the directory."""
+        bucket_name, blob_path = get_bucket_and_path(path, "r2")
+
+        # List objects under the given path
+        objects = self.client.client.list_objects_v2(Bucket=bucket_name, Prefix=blob_path)
+
+        # Check if objects exist
+        if "Contents" in objects:
+            for obj in objects["Contents"]:
+                self.client.client.delete_object(Bucket=bucket_name, Key=obj["Key"])
+
+    def exists(self, path: str) -> bool:
+        import botocore
+
+        bucket_name, blob_path = get_bucket_and_path(path, "r2")
+        try:
+            _ = self.client.client.head_object(Bucket=bucket_name, Key=blob_path)
+            return True
+        except botocore.exceptions.ClientError as e:
+            if "the HeadObject operation: Not Found" in str(e):
+                return False
+            raise e
+        except Exception as e:
+            raise e
+
+
 def get_bucket_and_path(remote_filepath: str, expected_scheme: str = "s3") -> tuple[str, str]:
     """Parse the remote filepath and return the bucket name and the blob path.

@@ -259,6 +330,8 @@ def _get_fs_provider(remote_filepath: str, storage_options: Optional[dict[str, A
         return GCPFsProvider(storage_options=storage_options)
     if obj.scheme == "s3":
         return S3FsProvider(storage_options=storage_options)
+    if obj.scheme == "r2":
+        return R2FsProvider(storage_options=storage_options)
     raise ValueError(f"Unsupported scheme: {obj.scheme}")
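
With the new dispatch branch, an r2:// URL resolves to R2FsProvider, and get_bucket_and_path splits it using the expected "r2" scheme. A short sketch using hypothetical values:

    from litdata.streaming.fs_provider import _get_fs_provider, get_bucket_and_path

    # Placeholder bucket and connection id.
    provider = _get_fs_provider(
        "r2://my-bucket/my-dataset/index.json",
        storage_options={"data_connection_id": "dc_12345"},
    )
    # isinstance(provider, R2FsProvider) -> True

    bucket, key = get_bucket_and_path("r2://my-bucket/my-dataset/index.json", "r2")
    # bucket == "my-bucket", key == "my-dataset/index.json"

Note that R2FsProvider wraps an R2Client, which needs the Lightning credentials described in client.py above before any transfer is performed.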
