Commit 81fe0e0

refactor(client): remove initialize() Method from LlamaStackAsLibrary
Previously, client.initialize() had to be invoked explicitly by the user. To improve the developer experience and avoid runtime errors, this PR initializes LlamaStackAsLibraryClient implicitly when the client is constructed. It also prevents multiple initializations of the same client, while maintaining backward compatibility.

Signed-off-by: Mustafa Elbehery <[email protected]>
1 parent eb07a0f commit 81fe0e0
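For context, here is a minimal sketch of the usage change this commit enables. The "ollama" distro name is assumed for illustration; the provider_data example follows docs/source/distributions/importing_as_library.md, and everything else is illustrative rather than part of this commit.

```python
import os

from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Before this commit, a separate setup call was required:
#   client = LlamaStackAsLibraryClient("ollama")
#   client.initialize()
#
# After this commit, construction performs that setup implicitly.
client = LlamaStackAsLibraryClient(
    "ollama",  # distro name assumed for illustration
    # provider_data is optional; pass provider-specific settings here if needed.
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)

# The client is ready to use immediately; calling client.initialize() again is
# a no-op kept only for backward compatibility.
models = client.models.list()
```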

File tree

docs/source/distributions/importing_as_library.md
llama_stack/core/library_client.py
tests/integration/fixtures/common.py
tests/integration/non_ci/responses/fixtures/fixtures.py
tests/unit/distribution/test_library_client_initialization.py

5 files changed: +78, -89 lines changed


docs/source/distributions/importing_as_library.md

Lines changed: 0 additions & 2 deletions
@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-client.initialize()
 ```
 
 This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
 
 ```python
 client = LlamaStackAsLibraryClient(config_path)
-client.initialize()
 ```

llama_stack/core/library_client.py

Lines changed: 30 additions & 20 deletions
@@ -145,39 +145,27 @@ def __init__(
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_distro_name, custom_provider_registry, provider_data
+            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data
 
         self.loop = asyncio.new_event_loop()
 
-    def initialize(self):
-        if in_notebook():
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        if not self.skip_logger_removal:
-            self._remove_root_logger_handlers()
-
         # use a new event loop to avoid interfering with the main event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self.async_client.initialize())
+            loop.run_until_complete(self.async_client.initialize())
         finally:
             asyncio.set_event_loop(None)
 
-    def _remove_root_logger_handlers(self):
+    def initialize(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Deprecated method for backward compatibility.
         """
-        root_logger = logging.getLogger()
-
-        for handler in root_logger.handlers[:]:
-            root_logger.removeHandler(handler)
-            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+        pass
 
     def request(self, *args, **kwargs):
         loop = self.loop
@@ -215,13 +203,21 @@ def __init__(
         config_path_or_distro_name: str,
         custom_provider_registry: ProviderRegistry | None = None,
         provider_data: dict[str, Any] | None = None,
+        skip_logger_removal: bool = False,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
         os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
 
+        if in_notebook():
+            import nest_asyncio
+
+            nest_asyncio.apply()
+        if not skip_logger_removal:
+            self._remove_root_logger_handlers()
+
         if config_path_or_distro_name.endswith(".yaml"):
             config_path = Path(config_path_or_distro_name)
             if not config_path.exists():
@@ -238,7 +234,24 @@ def __init__(
         self.provider_data = provider_data
         self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
 
+    def _remove_root_logger_handlers(self):
+        """
+        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        """
+        root_logger = logging.getLogger()
+
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+
     async def initialize(self) -> bool:
+        """
+        Initialize the async client. Can be called multiple times safely.
+
+        Returns:
+            bool: True if initialization was successful
+        """
+
         try:
             self.route_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
@@ -298,9 +311,6 @@ async def request(
         stream=False,
         stream_cls=None,
     ):
-        if self.route_impls is None:
-            raise ValueError("Client not initialized. Please call initialize() first.")
-
         # Create headers with provider data if available
         headers = options.headers or {}
         if self.provider_data:
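The hunks above do not show how repeated initialization is avoided; the unit tests further down assert on an `_is_initialized` flag of the async client. The following is only a rough sketch of such a guard, assuming that flag and an early return; it is not the actual body of AsyncLlamaStackAsLibraryClient.

```python
class AsyncLlamaStackAsLibraryClient:  # illustrative skeleton only, not the real class
    def __init__(self, *args, **kwargs):
        self._is_initialized = False  # flag the unit tests assert on

    async def initialize(self) -> bool:
        # Return early if the stack was already constructed, so that calling
        # initialize() repeatedly (or implicitly via the sync wrapper) does
        # not rebuild providers or route implementations.
        if self._is_initialized:
            return True

        ...  # construct_stack(), route table setup, etc.

        self._is_initialized = True
        return True
```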

tests/integration/fixtures/common.py

Lines changed: 1 addition & 3 deletions
@@ -256,9 +256,7 @@ def instantiate_llama_stack_client(session):
         provider_data=get_provider_data(),
         skip_logger_removal=True,
     )
-    if not client.initialize():
-        raise RuntimeError("Initialization failed")
-
+    # Client is automatically initialized during construction
     return client
 
 
tests/integration/non_ci/responses/fixtures/fixtures.py

Lines changed: 1 addition & 2 deletions
@@ -113,8 +113,7 @@ def openai_client(base_url, api_key, provider):
             raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
         config = parts[1]
         client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
-        if not client.initialize():
-            raise RuntimeError("Initialization failed")
+        # Client is automatically initialized during construction
         return client
 
     return OpenAI(

tests/unit/distribution/test_library_client_initialization.py

Lines changed: 46 additions & 62 deletions
@@ -5,86 +5,70 @@
 # the root directory of this source tree.
 
 """
-Unit tests for LlamaStackAsLibraryClient initialization error handling.
+Unit tests for LlamaStackAsLibraryClient automatic initialization.
 
-These tests ensure that users get proper error messages when they forget to call
-initialize() on the library client, preventing AttributeError regressions.
+These tests ensure that the library client is automatically initialized
+and ready to use immediately after construction.
 """
 
-import pytest
-
 from llama_stack.core.library_client import (
     AsyncLlamaStackAsLibraryClient,
     LlamaStackAsLibraryClient,
 )
 
 
-class TestLlamaStackAsLibraryClientInitialization:
-    """Test proper error handling for uninitialized library clients."""
-
-    @pytest.mark.parametrize(
-        "api_call",
-        [
-            lambda client: client.models.list(),
-            lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
-            lambda client: next(
-                client.chat.completions.create(
-                    model="test", messages=[{"role": "user", "content": "test"}], stream=True
-                )
-            ),
-        ],
-        ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
-    )
-    def test_sync_client_proper_error_without_initialization(self, api_call):
-        """Test that sync client raises ValueError with helpful message when not initialized."""
+class TestLlamaStackAsLibraryClientAutoInitialization:
+    """Test automatic initialization of library clients."""
+
+    def test_sync_client_auto_initialization(self):
+        """Test that sync client is automatically initialized after construction."""
         client = LlamaStackAsLibraryClient("nvidia")
 
-        with pytest.raises(ValueError) as exc_info:
-            api_call(client)
-
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg
-
-    @pytest.mark.parametrize(
-        "api_call",
-        [
-            lambda client: client.models.list(),
-            lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
-        ],
-        ids=["models.list", "chat.completions.create"],
-    )
-    async def test_async_client_proper_error_without_initialization(self, api_call):
-        """Test that async client raises ValueError with helpful message when not initialized."""
+        # Client should be automatically initialized
+        assert client.async_client._is_initialized is True
+        assert client.async_client.route_impls is not None
+
+    async def test_async_client_auto_initialization(self):
+        """Test that async client can be initialized and works properly."""
         client = AsyncLlamaStackAsLibraryClient("nvidia")
 
-        with pytest.raises(ValueError) as exc_info:
-            await api_call(client)
+        # Initialize the client
+        result = await client.initialize()
+        assert result is True
+        assert client._is_initialized is True
+        assert client.route_impls is not None
+
+    def test_initialize_method_backward_compatibility(self):
+        """Test that initialize() method still works for backward compatibility."""
+        client = LlamaStackAsLibraryClient("nvidia")
+
+        # initialize() should return None (historical behavior) and not cause errors
+        result = client.initialize()
+        assert result is None
 
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg
+        # Multiple calls should be safe
+        result2 = client.initialize()
+        assert result2 is None
 
-    async def test_async_client_streaming_error_without_initialization(self):
-        """Test that async client streaming raises ValueError with helpful message when not initialized."""
+    async def test_async_initialize_method_idempotent(self):
+        """Test that async initialize() method can be called multiple times safely."""
         client = AsyncLlamaStackAsLibraryClient("nvidia")
 
-        with pytest.raises(ValueError) as exc_info:
-            stream = await client.chat.completions.create(
-                model="test", messages=[{"role": "user", "content": "test"}], stream=True
-            )
-            await anext(stream)
+        # First initialization
+        result1 = await client.initialize()
+        assert result1 is True
+        assert client._is_initialized is True
 
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg
+        # Second initialization should be safe and return True
+        result2 = await client.initialize()
+        assert result2 is True
+        assert client._is_initialized is True
 
-    def test_route_impls_initialized_to_none(self):
-        """Test that route_impls is initialized to None to prevent AttributeError."""
-        # Test sync client
+    def test_route_impls_automatically_set(self):
+        """Test that route_impls is automatically set during construction."""
+        # Test sync client - should be auto-initialized
         sync_client = LlamaStackAsLibraryClient("nvidia")
-        assert sync_client.async_client.route_impls is None
+        assert sync_client.async_client.route_impls is not None
 
-        # Test async client directly
-        async_client = AsyncLlamaStackAsLibraryClient("nvidia")
-        assert async_client.route_impls is None
+        # Test that the async client is marked as initialized
+        assert sync_client.async_client._is_initialized is True