diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index 00fa0659..02dc6f1e 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -39,6 +39,7 @@ from ...types.knowledge_base_delete_response import KnowledgeBaseDeleteResponse from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse +import time __all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"] @@ -181,6 +182,45 @@ def retrieve( cast_to=KnowledgeBaseRetrieveResponse, ) + def wait_for_database_online( + self, + uuid: str, + *, + timeout: float = 600.0, + interval: float = 5.0, + raise_on_failed: bool = True, + ) -> KnowledgeBaseRetrieveResponse: + """ + Polls the knowledge base retrieve endpoint until the underlying database is ONLINE. + + Args: + uuid: Knowledge base id + timeout: Maximum seconds to wait before raising TimeoutError (default 600s). + interval: Seconds between polls (default 5s). + raise_on_failed: If True, raise a RuntimeError when the database reaches a terminal failed state. + + Returns: + The final `KnowledgeBaseRetrieveResponse` object when the database is ONLINE (or the terminal failed + response if `raise_on_failed` is False). + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + terminal_failed = {"UNHEALTHY", "DECOMMISSIONED"} + start = time.time() + while True: + resp = self.retrieve(uuid) + status = getattr(resp, "database_status", None) + if status == "ONLINE": + return resp + if status in terminal_failed: + if raise_on_failed: + raise RuntimeError(f"knowledge base {uuid} reached failed state: {status}") + return resp + if time.time() - start >= float(timeout): + raise TimeoutError(f"Timed out waiting for knowledge base {uuid} to become ONLINE") + self._sleep(float(interval)) + def update( self, path_uuid: str, @@ -469,6 +509,35 @@ async def retrieve( cast_to=KnowledgeBaseRetrieveResponse, ) + async def wait_for_database_online( + self, + uuid: str, + *, + timeout: float = 600.0, + interval: float = 5.0, + raise_on_failed: bool = True, + ) -> KnowledgeBaseRetrieveResponse: + """ + Async version of wait_for_database_online: polls until database is ONLINE. + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + terminal_failed = {"UNHEALTHY", "DECOMMISSIONED"} + start = time.time() + while True: + resp = await self.retrieve(uuid) + status = getattr(resp, "database_status", None) + if status == "ONLINE": + return resp + if status in terminal_failed: + if raise_on_failed: + raise RuntimeError(f"knowledge base {uuid} reached failed state: {status}") + return resp + if time.time() - start >= float(timeout): + raise TimeoutError(f"Timed out waiting for knowledge base {uuid} to become ONLINE") + await self._sleep(float(interval)) + async def update( self, path_uuid: str, diff --git a/tests/api_resources/test_knowledge_bases_wait.py b/tests/api_resources/test_knowledge_bases_wait.py new file mode 100644 index 00000000..7c4f65a3 --- /dev/null +++ b/tests/api_resources/test_knowledge_bases_wait.py @@ -0,0 +1,85 @@ +import types +import asyncio +import pytest + +from gradient import Gradient + + +class DummyResp: + def __init__(self, status=None): + self.database_status = status + + +def test_wait_for_database_online_success(monkeypatch, client: Gradient): + calls = {"n": 0} + + def fake_retrieve(uuid, **kwargs): + calls["n"] += 1 + # become ONLINE after 3 calls + if calls["n"] >= 3: + return DummyResp("ONLINE") + return DummyResp("CREATING") + + monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve) + + resp = client.knowledge_bases.wait_for_database_online("kb-1", timeout=5, interval=0.01) + assert resp.database_status == "ONLINE" + assert calls["n"] >= 3 + + +def test_wait_for_database_online_failed_raises(monkeypatch, client: Gradient): + def fake_retrieve(uuid, **kwargs): + return DummyResp("UNHEALTHY") + + monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve) + + with pytest.raises(RuntimeError): + client.knowledge_bases.wait_for_database_online("kb-2", timeout=1, interval=0.01) + + +def test_wait_for_database_online_timeout(monkeypatch, client: Gradient): + def fake_retrieve(uuid, **kwargs): + return DummyResp("CREATING") + + monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve) + + with pytest.raises(TimeoutError): + client.knowledge_bases.wait_for_database_online("kb-3", timeout=0.05, interval=0.01) + + +@pytest.mark.asyncio +async def test_async_wait_for_database_online_success(monkeypatch, async_client): + calls = {"n": 0} + + async def fake_retrieve(uuid, **kwargs): + calls["n"] += 1 + if calls["n"] >= 3: + return DummyResp("ONLINE") + return DummyResp("CREATING") + + monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve) + + resp = await async_client.knowledge_bases.wait_for_database_online("kb-1", timeout=5, interval=0.01) + assert resp.database_status == "ONLINE" + + +@pytest.mark.asyncio +async def test_async_wait_for_database_online_failed_raises(monkeypatch, async_client): + async def fake_retrieve(uuid, **kwargs): + return DummyResp("UNHEALTHY") + + monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve) + + with pytest.raises(RuntimeError): + await async_client.knowledge_bases.wait_for_database_online("kb-2", timeout=1, interval=0.01) + + +@pytest.mark.asyncio +async def test_async_wait_for_database_online_timeout(monkeypatch, async_client): + async def fake_retrieve(uuid, **kwargs): + return DummyResp("CREATING") + + monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve) + + with pytest.raises(TimeoutError): + await async_client.knowledge_bases.wait_for_database_online("kb-3", timeout=0.05, interval=0.01)