|
11 | 11 | # See the License for the specific language governing permissions and |
12 | 12 | # limitations under the License. |
13 | 13 |
|
| 14 | +import json |
14 | 15 | import os |
15 | 16 | from time import time |
16 | 17 | from typing import Any, Optional |
17 | 18 |
|
18 | 19 | import boto3 |
19 | 20 | import botocore |
| 21 | +import requests |
20 | 22 | from botocore.credentials import InstanceMetadataProvider |
21 | 23 | from botocore.utils import InstanceMetadataFetcher |
| 24 | +from requests.adapters import HTTPAdapter |
| 25 | +from urllib3.util.retry import Retry |
22 | 26 |
|
23 | 27 | from litdata.constants import _IS_IN_STUDIO |
24 | 28 |
|
# Constants for the retry adapter. Docs: https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html
# Maximum number of total connection retry attempts (2880 retries * 30s timeout per request = 24 hours)
_CONNECTION_RETRY_TOTAL = 2880
# Backoff factor for connection retries (urllib3 sleeps backoff_factor * 2**(retry - 1) between attempts)
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
# Default per-request HTTP timeout applied by _CustomRetryAdapter
_DEFAULT_REQUEST_TIMEOUT = 30  # seconds
| 36 | + |
| 37 | + |
class _CustomRetryAdapter(HTTPAdapter):
    """HTTP adapter that applies a default timeout to every outgoing request.

    ``requests.Session.request`` defaults ``timeout`` to ``None`` and always
    forwards it to ``send`` as an explicit keyword, so the default must kick in
    whenever the value is ``None`` — not only when the key is absent.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pop our extra kwarg before delegating; HTTPAdapter.__init__ does not accept it.
        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
        # Fix: the previous ``kwargs.get("timeout", self.timeout)`` kept an explicit
        # ``timeout=None`` (what Session passes when the caller gave no timeout),
        # so the adapter default never applied. Treat ``None`` as "unset".
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        # Fix: forward positional args too; the old code silently dropped *args.
        return super().send(request, *args, **kwargs)
25 | 47 |
|
26 | 48 | class S3Client: |
27 | 49 | # TODO: Generalize to support more cloud providers. |
@@ -76,3 +98,122 @@ def client(self) -> Any: |
76 | 98 | self._last_time = time() |
77 | 99 |
|
78 | 100 | return self._client |
| 101 | + |
| 102 | + |
class R2Client(S3Client):
    """R2 client with refreshable credentials for Cloudflare R2 storage.

    Temporary credentials are fetched from the Lightning Cloud API and the
    boto3 client is rebuilt by ``_create_client`` (driven by the parent
    ``S3Client`` refresh machinery every ``refetch_interval`` seconds).
    """

    def __init__(
        self,
        refetch_interval: int = 3600,  # 1 hour - default refresh interval for R2 credentials
        storage_options: Optional[dict] = None,
        session_options: Optional[dict] = None,
    ) -> None:
        """Initialize the R2 client.

        Args:
            refetch_interval: Seconds between credential refreshes (default: 1 hour).
            storage_options: Must contain ``data_connection_id``; all other keys are
                forwarded to ``boto3`` when the client is created.
            session_options: Forwarded to ``boto3.Session``.
        """
        # Store R2-specific options before calling super(). ``None`` defaults fix
        # the shared-mutable-default pitfall of the previous ``={}`` defaults.
        self._base_storage_options: dict = storage_options or {}

        # Call parent constructor with the R2-specific refetch interval; the
        # real storage options are applied in _create_client.
        super().__init__(
            refetch_interval=refetch_interval,
            storage_options={},
            session_options=session_options or {},
        )

    @staticmethod
    def _create_retry_session() -> requests.Session:
        """Return a ``requests`` session that retries transient HTTP failures."""
        retry_strategy = Retry(
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=[
                408,  # Request Timeout
                429,  # Too Many Requests
                500,  # Internal Server Error
                502,  # Bad Gateway
                503,  # Service Unavailable
                504,  # Gateway Timeout
            ],
        )
        adapter = _CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def get_r2_bucket_credentials(self, data_connection_id: str) -> dict[str, str]:
        """Fetch temporary R2 credentials for the given lightning data connection.

        Returns:
            A dict with ``aws_access_key_id``, ``aws_secret_access_key``,
            ``aws_session_token`` and ``endpoint_url``, ready to pass to boto3.

        Raises:
            RuntimeError: If required environment variables are missing or any
                step of the credential exchange fails.
        """
        cloud_url = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
        api_key = os.getenv("LIGHTNING_API_KEY")
        username = os.getenv("LIGHTNING_USERNAME")
        project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")

        # Validate before the try block so this clear error is not re-wrapped below.
        if not all([api_key, username, project_id]):
            raise RuntimeError("Missing required environment variables")

        session = self._create_retry_session()

        try:
            # Log in to get an API token. ``json=`` sets the JSON Content-Type
            # header, which the old ``data=json.dumps(...)`` call did not.
            login_url = f"{cloud_url}/v1/auth/login"
            response = session.post(login_url, json={"apiKey": api_key, "username": username})
            login_payload = response.json()
            if "token" not in login_payload:
                raise RuntimeError("Failed to get authentication token")
            token = login_payload["token"]

            # Exchange the token for temporary bucket credentials.
            headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
            credentials_url = (
                f"{cloud_url}/v1/projects/{project_id}/data-connections/{data_connection_id}/temp-bucket-credentials"
            )
            credentials_response = session.get(credentials_url, headers=headers, timeout=10)
            if credentials_response.status_code != 200:
                raise RuntimeError(f"Failed to get credentials: {credentials_response.status_code}")
            temp_credentials = credentials_response.json()

            endpoint_url = f"https://{temp_credentials['accountId']}.r2.cloudflarestorage.com"

            # Format credentials for S3Client / boto3.
            return {
                "aws_access_key_id": temp_credentials["accessKeyId"],
                "aws_secret_access_key": temp_credentials["secretAccessKey"],
                "aws_session_token": temp_credentials["sessionToken"],
                "endpoint_url": endpoint_url,
            }
        except Exception as e:
            # There is no fallback credential path (the old message claiming one
            # was misleading). Chain the cause for debuggability.
            raise RuntimeError(f"Failed to get R2 credentials and no fallback available: {e}") from e

    def _create_client(self) -> None:
        """Create a new boto3 S3 client for R2 using freshly fetched credentials."""
        data_connection_id = self._base_storage_options.get("data_connection_id")
        if not data_connection_id:
            raise RuntimeError("data_connection_id is required in storage_options for R2 client")

        r2_credentials = self.get_r2_bucket_credentials(data_connection_id)

        # Drop metadata keys boto3 does not understand, then overlay the fresh
        # credentials (credential keys win on collision).
        filtered_storage_options = {
            k: v for k, v in self._base_storage_options.items() if k != "data_connection_id"
        }
        combined_storage_options = {**filtered_storage_options, **r2_credentials}

        # Keep the inherited attribute in sync so the parent refresh logic sees
        # the credentials currently in use.
        self._storage_options = combined_storage_options

        session = boto3.Session(**self._session_options)
        self._client = session.client(
            "s3",
            **{
                # Callers may override this config via storage options.
                "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
                **combined_storage_options,
            },
        )
0 commit comments