-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[ENH] For chroma cloud efs, extract api key from header if available to authenticate #5914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from typing import ClassVar, Dict | ||
| from typing import ClassVar, Dict, Optional | ||
| import uuid | ||
|
|
||
| from chromadb.api import ServerAPI | ||
|
|
@@ -94,3 +94,50 @@ def _system(self) -> System: | |
| def _submit_client_start_event(self) -> None: | ||
| telemetry_client = self._system.instance(ProductTelemetryClient) | ||
| telemetry_client.capture(ClientStartEvent()) | ||
|
|
||
| @staticmethod | ||
| def get_chroma_cloud_api_key_from_clients() -> Optional[str]: | ||
| """ | ||
| try to extract api key from existing client instaces by checking httpx session headers | ||
| if available. | ||
|
|
||
| Requirements to pull api key: | ||
| - must be a FastAPI instance (ignore RustBindingsAPI and SegmentAPI) | ||
| - must have a "api.trychroma.com" in the _api_url (ignore local/self-hosted instances) | ||
| - must have "x-chroma-token" or "X-Chroma-Token" in the headers | ||
|
|
||
| Returns: | ||
| The first api key found, or None if no client instances have api keys set. | ||
| """ | ||
| # check FastAPI instance session headers bc this is where both cloudclient and httpclient paths converge | ||
| for system in SharedSystemClient._identifier_to_system.values(): | ||
| try: | ||
| # get the ServerAPI instance (which is FastAPI for HTTP clients) | ||
| server_api = system.instance(ServerAPI) | ||
|
|
||
| # check if it's a FastAPI instance with a _session attribute | ||
| # RustBindingsAPI and SegmentAPI don't have a session attribute | ||
| if hasattr(server_api, "_session") and hasattr( | ||
| server_api._session, "headers" | ||
| ): | ||
| # only pull api key if the url contains the chroma cloud url | ||
| if ( | ||
| not hasattr(server_api, "_api_url") | ||
| or "api.trychroma.com" not in server_api._api_url | ||
| ): | ||
| continue | ||
|
|
||
| # pull api key from the chroma token header | ||
| headers = server_api._session.headers | ||
| api_key = headers.get("X-Chroma-Token") or headers.get( | ||
| "x-chroma-token" | ||
| ) | ||
| if api_key: | ||
| # header value might be a string or bytes, convert to string | ||
| return str(api_key) | ||
| except Exception: | ||
| # if we can't access the ServerAPI instance or it doesn't have _session, | ||
| # continue to the next system instance | ||
| continue | ||
|
Comment on lines
+138
to
+141
Contributor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Reliability] The broad `except Exception:` handler hides unexpected failures. Consider handling the expected case separately and logging anything else:
```python
except AttributeError as e:
    # ServerAPI doesn't have expected attributes (_session or _api_url)
    logger.debug(f"Skipping system {system_id}: {e}")
    continue
except Exception as e:
    # Unexpected errors should be logged for investigation
    logger.warning(f"Unexpected error extracting API key from system: {e}")
    continue
```
This distinguishes expected structural variations from genuine errors that need attention.
Context for Agents |
||
|
|
||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| import pytest | ||
| from unittest.mock import MagicMock | ||
| from chromadb.api.shared_system_client import SharedSystemClient | ||
| from chromadb.config import System | ||
| from chromadb.api import ServerAPI | ||
| from typing import Optional, Dict, Generator | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def clear_cache() -> Generator[None, None, None]: | ||
| """Automatically clear the system cache before and after each test.""" | ||
| SharedSystemClient.clear_system_cache() | ||
| yield | ||
| SharedSystemClient.clear_system_cache() | ||
|
|
||
|
|
||
| def create_mock_server_api( | ||
| api_url: Optional[str] = None, | ||
| headers: Optional[Dict[str, str]] = None, | ||
| has_session: bool = True, | ||
| has_headers_attr: bool = True, | ||
| ) -> MagicMock: | ||
| """Create a mock ServerAPI instance with the specified configuration.""" | ||
| mock_server_api = MagicMock(spec=ServerAPI) | ||
|
|
||
| if api_url: | ||
| mock_server_api._api_url = api_url | ||
|
|
||
| if has_session: | ||
| mock_session = MagicMock() | ||
| if has_headers_attr: | ||
| mock_session.headers = headers or {} | ||
| else: | ||
| # Create a mock without headers attribute | ||
| del mock_session.headers | ||
| mock_server_api._session = mock_session | ||
| else: | ||
| if hasattr(mock_server_api, "_session"): | ||
| del mock_server_api._session | ||
|
|
||
| return mock_server_api | ||
|
|
||
|
|
||
| def register_mock_system(system_id: str, mock_server_api: MagicMock) -> MagicMock: | ||
| """Register a mock system with the given ID and server API.""" | ||
| mock_system = MagicMock(spec=System) | ||
| mock_system.instance.return_value = mock_server_api | ||
| SharedSystemClient._identifier_to_system[system_id] = mock_system | ||
| return mock_system | ||
|
|
||
|
|
||
| def test_extracts_api_key_from_chroma_cloud_client() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| headers={"X-Chroma-Token": "test-api-key-123"}, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key == "test-api-key-123" | ||
|
|
||
|
|
||
| def test_extracts_api_key_with_lowercase_header() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| headers={"x-chroma-token": "test-api-key-456"}, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key == "test-api-key-456" | ||
|
|
||
|
|
||
| def test_skips_non_chroma_cloud_clients() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://localhost:8000/api/v2", | ||
| headers={"X-Chroma-Token": "local-api-key"}, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_skips_clients_without_session() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| has_session=False, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_skips_clients_without_api_url() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url=None, | ||
| headers={"X-Chroma-Token": "test-api-key"}, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_returns_none_when_no_api_key_in_headers() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| headers={}, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_returns_first_api_key_found_from_multiple_clients() -> None: | ||
| mock_server_api_1 = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| headers={"X-Chroma-Token": "first-key"}, | ||
| ) | ||
| mock_server_api_2 = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| headers={"X-Chroma-Token": "second-key"}, | ||
| ) | ||
| register_mock_system("test-id-1", mock_server_api_1) | ||
| register_mock_system("test-id-2", mock_server_api_2) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key == "first-key" | ||
|
|
||
|
|
||
| def test_handles_exception_gracefully() -> None: | ||
| mock_system = MagicMock(spec=System) | ||
| mock_system.instance.side_effect = Exception("Test exception") | ||
| SharedSystemClient._identifier_to_system["test-id"] = mock_system | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_returns_none_when_no_clients_exist() -> None: | ||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None | ||
|
|
||
|
|
||
| def test_skips_chroma_cloud_client_without_headers_attribute() -> None: | ||
| mock_server_api = create_mock_server_api( | ||
| api_url="https://api.trychroma.com/api/v2", | ||
| has_headers_attr=False, | ||
| ) | ||
| register_mock_system("test-id", mock_server_api) | ||
|
|
||
| api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
|
|
||
| assert api_key is None |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,9 +56,19 @@ def __init__( | |
| ) | ||
|
|
||
| self.api_key_env_var = api_key_env_var | ||
| # First, try to get API key from environment variable | ||
| self.api_key = os.getenv(api_key_env_var) | ||
| # If not found in env var, try to get it from existing client instances | ||
| if not self.api_key: | ||
| raise ValueError(f"The {api_key_env_var} environment variable is not set.") | ||
| from chromadb.api.shared_system_client import SharedSystemClient | ||
|
Collaborator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do not do inline imports. Why is this needed? |
||
|
|
||
| self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients() | ||
| # Raise error if still no API key found | ||
| if not self.api_key: | ||
| raise ValueError( | ||
| f"API key not found in environment variable {api_key_env_var} " | ||
| f"or in any existing client instances" | ||
| ) | ||
|
|
||
| self.model = model | ||
| self.task = task | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Logic] Race condition: the `_identifier_to_system` dictionary is accessed without synchronization. If multiple threads call `get_chroma_cloud_api_key_from_clients()` while another thread modifies the dictionary (via `__init__` or `clear_system_cache()`), this can raise `RuntimeError: dictionary changed size during iteration` or return inconsistent results.
Alternatively, protect dictionary access with a lock if thread-safety is required.
Context for Agents