2 changes: 0 additions & 2 deletions docs/source/distributions/importing_as_library.md
@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
client.initialize()
```

This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/

```python
client = LlamaStackAsLibraryClient(config_path)
client.initialize()
```
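
With this change, the documented flow no longer calls `initialize()`: constructing the client is enough. A minimal usage sketch under that assumption (the `ci-tests` distribution name and the `models.list()` call are borrowed from the unit tests further down; substitute your own distribution name or config path):

```python
import os

from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Construction now performs initialization automatically; there is no
# separate client.initialize() step anymore.
client = LlamaStackAsLibraryClient(
    "ci-tests",  # a distribution name, or a path to a run-config YAML file
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)

# The client is usable immediately, e.g. to list the registered models.
models = client.models.list()
```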
48 changes: 30 additions & 18 deletions llama_stack/core/library_client.py
@@ -146,39 +146,26 @@ def __init__(
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data

self.loop = asyncio.new_event_loop()

def initialize(self):
if in_notebook():
import nest_asyncio

nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()

# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)

def _remove_root_logger_handlers(self):
def initialize(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Deprecated method for backward compatibility.
"""
root_logger = logging.getLogger()

for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
pass

def request(self, *args, **kwargs):
loop = self.loop
@@ -216,13 +203,21 @@ def __init__(
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")

if in_notebook():
import nest_asyncio

nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()

if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@@ -239,7 +234,24 @@ self.provider_data = provider_data
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError

def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()

for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

async def initialize(self) -> bool:
"""
Initialize the async client.

Returns:
bool: True if initialization was successful
"""

try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
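
For the async client, `initialize()` is still awaited explicitly and returns a `bool`; a short sketch matching the constructor and `initialize()` signatures in this diff (the distribution name is again taken from the unit tests):

```python
import asyncio

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient


async def main() -> None:
    # skip_logger_removal is now handled in the async client's constructor,
    # so notebook detection and logger cleanup happen at construction time.
    client = AsyncLlamaStackAsLibraryClient("ci-tests", skip_logger_removal=True)

    # The async client still awaits initialize(), which returns True on success.
    ok = await client.initialize()
    assert ok is True


asyncio.run(main())
```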
3 changes: 0 additions & 3 deletions tests/integration/fixtures/common.py
@@ -256,9 +256,6 @@ def instantiate_llama_stack_client(session):
provider_data=get_provider_data(),
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed")

return client


2 changes: 0 additions & 2 deletions tests/integration/non_ci/responses/fixtures/fixtures.py
@@ -113,8 +113,6 @@ def openai_client(base_url, api_key, provider):
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
config = parts[1]
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
if not client.initialize():
raise RuntimeError("Initialization failed")
return client

return OpenAI(
183 changes: 109 additions & 74 deletions tests/unit/distribution/test_library_client_initialization.py
@@ -5,86 +5,121 @@
# the root directory of this source tree.

"""
Unit tests for LlamaStackAsLibraryClient initialization error handling.
Unit tests for LlamaStackAsLibraryClient automatic initialization.

These tests ensure that users get proper error messages when they forget to call
initialize() on the library client, preventing AttributeError regressions.
These tests ensure that the library client is automatically initialized
and ready to use immediately after construction.
"""

import pytest

from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
from llama_stack.core.server.routes import RouteImpls


class TestLlamaStackAsLibraryClientAutoInitialization:
"""Test automatic initialization of library clients."""

def test_sync_client_auto_initialization(self, monkeypatch):
"""Test that sync client is automatically initialized after construction."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})

async def mock_construct_stack(config, custom_provider_registry):
return mock_impls

def mock_initialize_route_impls(impls):
return mock_route_impls

monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

client = LlamaStackAsLibraryClient("ci-tests")

assert client.async_client.route_impls is not None

async def test_async_client_auto_initialization(self, monkeypatch):
"""Test that async client can be initialized and works properly."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})

async def mock_construct_stack(config, custom_provider_registry):
return mock_impls

def mock_initialize_route_impls(impls):
return mock_route_impls

monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

client = AsyncLlamaStackAsLibraryClient("ci-tests")

# Initialize the client
result = await client.initialize()
assert result is True
assert client.route_impls is not None

def test_initialize_method_backward_compatibility(self, monkeypatch):
"""Test that initialize() method still works for backward compatibility."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})

async def mock_construct_stack(config, custom_provider_registry):
return mock_impls

def mock_initialize_route_impls(impls):
return mock_route_impls

monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

client = LlamaStackAsLibraryClient("ci-tests")

result = client.initialize()
assert result is None

result2 = client.initialize()
assert result2 is None

async def test_async_initialize_method_idempotent(self, monkeypatch):
"""Test that async initialize() method can be called multiple times safely."""
mock_impls = {}
mock_route_impls = RouteImpls({})

async def mock_construct_stack(config, custom_provider_registry):
return mock_impls

def mock_initialize_route_impls(impls):
return mock_route_impls

monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

client = AsyncLlamaStackAsLibraryClient("ci-tests")

result1 = await client.initialize()
assert result1 is True

result2 = await client.initialize()
assert result2 is True

def test_route_impls_automatically_set(self, monkeypatch):
"""Test that route_impls is automatically set during construction."""
mock_impls = {}
mock_route_impls = RouteImpls({})

async def mock_construct_stack(config, custom_provider_registry):
return mock_impls

def mock_initialize_route_impls(impls):
return mock_route_impls

monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

class TestLlamaStackAsLibraryClientInitialization:
"""Test proper error handling for uninitialized library clients."""

@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia")

with pytest.raises(ValueError) as exc_info:
api_call(client)

error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg

@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")

with pytest.raises(ValueError) as exc_info:
await api_call(client)

error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg

async def test_async_client_streaming_error_without_initialization(self):
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")

with pytest.raises(ValueError) as exc_info:
stream = await client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
await anext(stream)

error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg

def test_route_impls_initialized_to_none(self):
"""Test that route_impls is initialized to None to prevent AttributeError."""
# Test sync client
sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None

# Test async client directly
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
assert async_client.route_impls is None
sync_client = LlamaStackAsLibraryClient("ci-tests")
assert sync_client.async_client.route_impls is not None
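
Each test above relies on the same monkeypatch setup; a hedged sketch of that shared mocking expressed as a pytest fixture (the fixture name is hypothetical, while the patched paths and `RouteImpls` usage are the ones used in the tests):

```python
import pytest

from llama_stack.core.server.routes import RouteImpls


@pytest.fixture
def mocked_stack(monkeypatch):
    """Patch stack construction so library clients can be built without real providers."""
    mock_impls = {}
    mock_route_impls = RouteImpls({})

    async def mock_construct_stack(config, custom_provider_registry):
        return mock_impls

    def mock_initialize_route_impls(impls):
        return mock_route_impls

    monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
    monkeypatch.setattr(
        "llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls
    )
    return mock_impls
```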