Skip to content

Commit 00930ee

Browse files
authored
Resolve s3 credentials wrongly defined (#27)
1 parent 6683b08 commit 00930ee

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

litdata/streaming/client.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from time import time
33
from typing import Any, Optional
44

5-
from litdata.constants import _BOTO3_AVAILABLE
5+
from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO
66

77
if _BOTO3_AVAILABLE:
88
import boto3
@@ -17,15 +17,14 @@ class S3Client:
1717
def __init__(self, refetch_interval: int = 3300) -> None:
1818
self._refetch_interval = refetch_interval
1919
self._last_time: Optional[float] = None
20-
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
2120
self._client: Optional[Any] = None
2221

2322
def _create_client(self) -> None:
2423
has_shared_credentials_file = (
2524
os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
2625
)
2726

28-
if has_shared_credentials_file:
27+
if has_shared_credentials_file or not _IS_IN_STUDIO:
2928
self._client = boto3.client(
3029
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
3130
)
@@ -42,10 +41,9 @@ def _create_client(self) -> None:
4241

4342
@property
4443
def client(self) -> Any:
45-
if not self._has_cloud_space_id:
46-
if self._client is None:
47-
self._create_client()
48-
return self._client
44+
if self._client is None:
45+
self._create_client()
46+
self._last_time = time()
4947

5048
# Re-generate credentials for EC2
5149
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:

tests/streaming/test_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,16 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
3131

3232

3333
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
34-
@pytest.mark.parametrize("use_shared_credentials", [False, True])
34+
@pytest.mark.parametrize("use_shared_credentials", [False, True, None])
3535
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
3636
boto3 = mock.MagicMock()
3737
monkeypatch.setattr(client, "boto3", boto3)
3838

3939
botocore = mock.MagicMock()
4040
monkeypatch.setattr(client, "botocore", botocore)
4141

42-
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
43-
44-
if use_shared_credentials:
42+
if isinstance(use_shared_credentials, bool):
43+
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
4544
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
4645
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")
4746

0 commit comments

Comments
 (0)