Skip to content

Commit ab9982f

Browse files
aronphilandstuff
authored andcommitted
Support getting REPLICATE_API_TOKEN from cog context
This commit introduces support for the cog context into the Replicate SDK. The `current_scope` helper now makes per-prediction context available via the `current_scope().context` dict. A cog model can then provide a REPLICATE_API_TOKEN on a per-prediction basis to be used by the model. def predict(prompt: str) -> str: replicate = Replicate() output = replicate.run("anthropic/claude-3.5-haiku", {input: {"prompt": "prompt"}}) return output
1 parent 9f8d753 commit ab9982f

File tree

2 files changed

+163
-1
lines changed

2 files changed

+163
-1
lines changed

replicate/client.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,22 @@ def close(self) -> None:
348348
self._wrapped_transport.close() # type: ignore
349349

350350

351+
def _get_api_token_from_environment() -> Optional[str]:
352+
"""Get API token from cog current scope if available, otherwise from environment."""
353+
try:
354+
import cog
355+
356+
if hasattr(cog, "current_scope"):
357+
scope = cog.current_scope()
358+
if scope and hasattr(scope, "content") and isinstance(scope.content, dict):
359+
if "replicate_api_token" in scope.content:
360+
return scope.content["replicate_api_token"]
361+
except (ImportError, AttributeError, Exception):
362+
pass
363+
364+
return os.environ.get("REPLICATE_API_TOKEN")
365+
366+
351367
def _build_httpx_client(
352368
client_type: Type[Union[httpx.Client, httpx.AsyncClient]],
353369
api_token: Optional[str] = None,
@@ -359,7 +375,7 @@ def _build_httpx_client(
359375
if "User-Agent" not in headers:
360376
headers["User-Agent"] = f"replicate-python/{__version__}"
361377
if "Authorization" not in headers and (
362-
api_token := api_token or os.environ.get("REPLICATE_API_TOKEN")
378+
api_token := api_token or _get_api_token_from_environment()
363379
):
364380
headers["Authorization"] = f"Bearer {api_token}"
365381

tests/test_client.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import os
2+
import sys
23
from unittest import mock
34

45
import httpx
56
import pytest
67
import respx
78

9+
from replicate.client import _get_api_token_from_environment
10+
811

912
@pytest.mark.asyncio
1013
async def test_authorization_when_setting_environ_after_import():
@@ -114,3 +117,146 @@ def mock_send(request):
114117
pass
115118

116119
mock_send_wrapper.assert_called_once()
120+
121+
122+
class TestGetApiToken:
123+
"""Test cases for _get_api_token_from_environment function covering all import paths."""
124+
125+
def test_cog_not_available_falls_back_to_env(self):
126+
"""Test fallback to environment when cog package is not available."""
127+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
128+
with mock.patch.dict(sys.modules, {"cog": None}):
129+
token = _get_api_token_from_environment()
130+
assert token == "env-token"
131+
132+
def test_cog_import_error_falls_back_to_env(self):
133+
"""Test fallback to environment when cog import raises exception."""
134+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
135+
with mock.patch(
136+
"builtins.__import__",
137+
side_effect=ModuleNotFoundError("No module named 'cog'"),
138+
):
139+
token = _get_api_token_from_environment()
140+
assert token == "env-token"
141+
142+
def test_cog_no_current_scope_method_falls_back_to_env(self):
143+
"""Test fallback when cog exists but has no current_scope method."""
144+
mock_cog = mock.MagicMock()
145+
del mock_cog.current_scope # Remove the method
146+
147+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
148+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
149+
token = _get_api_token_from_environment()
150+
assert token == "env-token"
151+
152+
def test_cog_current_scope_returns_none_falls_back_to_env(self):
153+
"""Test fallback when current_scope() returns None."""
154+
mock_cog = mock.MagicMock()
155+
mock_cog.current_scope.return_value = None
156+
157+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
158+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
159+
token = _get_api_token_from_environment()
160+
assert token == "env-token"
161+
162+
def test_cog_scope_no_content_attr_falls_back_to_env(self):
163+
"""Test fallback when scope has no content attribute."""
164+
mock_scope = mock.MagicMock()
165+
del mock_scope.content # Remove the content attribute
166+
167+
mock_cog = mock.MagicMock()
168+
mock_cog.current_scope.return_value = mock_scope
169+
170+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
171+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
172+
token = _get_api_token_from_environment()
173+
assert token == "env-token"
174+
175+
def test_cog_scope_content_not_dict_falls_back_to_env(self):
176+
"""Test fallback when scope.content is not a dictionary."""
177+
mock_scope = mock.MagicMock()
178+
mock_scope.content = "not a dict"
179+
180+
mock_cog = mock.MagicMock()
181+
mock_cog.current_scope.return_value = mock_scope
182+
183+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
184+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
185+
token = _get_api_token_from_environment()
186+
assert token == "env-token"
187+
188+
def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
189+
"""Test fallback when replicate_api_token key is missing from content."""
190+
mock_scope = mock.MagicMock()
191+
mock_scope.content = {"other_key": "other_value"} # Missing replicate_api_token
192+
193+
mock_cog = mock.MagicMock()
194+
mock_cog.current_scope.return_value = mock_scope
195+
196+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
197+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
198+
token = _get_api_token_from_environment()
199+
assert token == "env-token"
200+
201+
def test_cog_scope_replicate_api_token_valid_string(self):
202+
"""Test successful retrieval of non-empty token from cog."""
203+
mock_scope = mock.MagicMock()
204+
mock_scope.content = {"replicate_api_token": "cog-token"}
205+
206+
mock_cog = mock.MagicMock()
207+
mock_cog.current_scope.return_value = mock_scope
208+
209+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
210+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
211+
token = _get_api_token_from_environment()
212+
assert token == "cog-token"
213+
214+
def test_cog_scope_replicate_api_token_empty_string(self):
215+
"""Test that empty string from cog is returned (not falling back to env)."""
216+
mock_scope = mock.MagicMock()
217+
mock_scope.content = {"replicate_api_token": ""} # Empty string
218+
219+
mock_cog = mock.MagicMock()
220+
mock_cog.current_scope.return_value = mock_scope
221+
222+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
223+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
224+
token = _get_api_token_from_environment()
225+
assert token == "" # Should return empty string, not env token
226+
227+
def test_cog_scope_replicate_api_token_none(self):
228+
"""Test that None from cog is returned (not falling back to env)."""
229+
mock_scope = mock.MagicMock()
230+
mock_scope.content = {"replicate_api_token": None}
231+
232+
mock_cog = mock.MagicMock()
233+
mock_cog.current_scope.return_value = mock_scope
234+
235+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
236+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
237+
token = _get_api_token_from_environment()
238+
assert token is None # Should return None, not env token
239+
240+
def test_cog_current_scope_raises_exception_falls_back_to_env(self):
241+
"""Test fallback when current_scope() raises an exception."""
242+
mock_cog = mock.MagicMock()
243+
mock_cog.current_scope.side_effect = RuntimeError("Scope error")
244+
245+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
246+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
247+
token = _get_api_token_from_environment()
248+
assert token == "env-token"
249+
250+
def test_no_env_token_returns_none(self):
251+
"""Test that None is returned when no environment token is set and cog unavailable."""
252+
with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars
253+
with mock.patch.dict(sys.modules, {"cog": None}):
254+
token = _get_api_token_from_environment()
255+
assert token is None
256+
257+
def test_env_token_empty_string(self):
258+
"""Test that empty string from environment is returned."""
259+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}):
260+
with mock.patch.dict(sys.modules, {"cog": None}):
261+
token = _get_api_token_from_environment()
262+
assert token == ""

0 commit comments

Comments
 (0)