Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/gradient/resources/knowledge_bases/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...types.knowledge_base_delete_response import KnowledgeBaseDeleteResponse
from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse
from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse
import time

__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"]

Expand Down Expand Up @@ -181,6 +182,45 @@ def retrieve(
cast_to=KnowledgeBaseRetrieveResponse,
)

def wait_for_database_online(
self,
uuid: str,
*,
timeout: float = 600.0,
interval: float = 5.0,
raise_on_failed: bool = True,
) -> KnowledgeBaseRetrieveResponse:
"""
Polls the knowledge base retrieve endpoint until the underlying database is ONLINE.

Args:
uuid: Knowledge base id
timeout: Maximum seconds to wait before raising TimeoutError (default 600s).
interval: Seconds between polls (default 5s).
raise_on_failed: If True, raise a RuntimeError when the database reaches a terminal failed state.

Returns:
The final `KnowledgeBaseRetrieveResponse` object when the database is ONLINE (or the terminal failed
response if `raise_on_failed` is False).
"""
if not uuid:
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

terminal_failed = {"UNHEALTHY", "DECOMMISSIONED"}
start = time.time()
while True:
resp = self.retrieve(uuid)
status = getattr(resp, "database_status", None)
if status == "ONLINE":
return resp
if status in terminal_failed:
if raise_on_failed:
raise RuntimeError(f"knowledge base {uuid} reached failed state: {status}")
return resp
if time.time() - start >= float(timeout):
raise TimeoutError(f"Timed out waiting for knowledge base {uuid} to become ONLINE")
self._sleep(float(interval))

def update(
self,
path_uuid: str,
Expand Down Expand Up @@ -469,6 +509,35 @@ async def retrieve(
cast_to=KnowledgeBaseRetrieveResponse,
)

async def wait_for_database_online(
self,
uuid: str,
*,
timeout: float = 600.0,
interval: float = 5.0,
raise_on_failed: bool = True,
) -> KnowledgeBaseRetrieveResponse:
"""
Async version of wait_for_database_online: polls until database is ONLINE.
"""
if not uuid:
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

terminal_failed = {"UNHEALTHY", "DECOMMISSIONED"}
start = time.time()
while True:
resp = await self.retrieve(uuid)
status = getattr(resp, "database_status", None)
if status == "ONLINE":
return resp
if status in terminal_failed:
if raise_on_failed:
raise RuntimeError(f"knowledge base {uuid} reached failed state: {status}")
return resp
if time.time() - start >= float(timeout):
raise TimeoutError(f"Timed out waiting for knowledge base {uuid} to become ONLINE")
await self._sleep(float(interval))

async def update(
self,
path_uuid: str,
Expand Down
85 changes: 85 additions & 0 deletions tests/api_resources/test_knowledge_bases_wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import types
import asyncio
import pytest

from gradient import Gradient


class DummyResp:
def __init__(self, status=None):
self.database_status = status


def test_wait_for_database_online_success(monkeypatch, client: Gradient):
calls = {"n": 0}

def fake_retrieve(uuid, **kwargs):
calls["n"] += 1
# become ONLINE after 3 calls
if calls["n"] >= 3:
return DummyResp("ONLINE")
return DummyResp("CREATING")

monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve)

resp = client.knowledge_bases.wait_for_database_online("kb-1", timeout=5, interval=0.01)
assert resp.database_status == "ONLINE"
assert calls["n"] >= 3


def test_wait_for_database_online_failed_raises(monkeypatch, client: Gradient):
def fake_retrieve(uuid, **kwargs):
return DummyResp("UNHEALTHY")

monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve)

with pytest.raises(RuntimeError):
client.knowledge_bases.wait_for_database_online("kb-2", timeout=1, interval=0.01)


def test_wait_for_database_online_timeout(monkeypatch, client: Gradient):
def fake_retrieve(uuid, **kwargs):
return DummyResp("CREATING")

monkeypatch.setattr(client.knowledge_bases, "retrieve", fake_retrieve)

with pytest.raises(TimeoutError):
client.knowledge_bases.wait_for_database_online("kb-3", timeout=0.05, interval=0.01)


@pytest.mark.asyncio
async def test_async_wait_for_database_online_success(monkeypatch, async_client):
calls = {"n": 0}

async def fake_retrieve(uuid, **kwargs):
calls["n"] += 1
if calls["n"] >= 3:
return DummyResp("ONLINE")
return DummyResp("CREATING")

monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve)

resp = await async_client.knowledge_bases.wait_for_database_online("kb-1", timeout=5, interval=0.01)
assert resp.database_status == "ONLINE"


@pytest.mark.asyncio
async def test_async_wait_for_database_online_failed_raises(monkeypatch, async_client):
async def fake_retrieve(uuid, **kwargs):
return DummyResp("UNHEALTHY")

monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve)

with pytest.raises(RuntimeError):
await async_client.knowledge_bases.wait_for_database_online("kb-2", timeout=1, interval=0.01)


@pytest.mark.asyncio
async def test_async_wait_for_database_online_timeout(monkeypatch, async_client):
async def fake_retrieve(uuid, **kwargs):
return DummyResp("CREATING")

monkeypatch.setattr(async_client.knowledge_bases, "retrieve", fake_retrieve)

with pytest.raises(TimeoutError):
await async_client.knowledge_bases.wait_for_database_online("kb-3", timeout=0.05, interval=0.01)