Skip to content

Commit 8f0413e

Browse files
committed
improve agent metrics integration test and clean up fixtures
- Simplified the test to use telemetry.query_metrics for verification.
- The test now validates actual queryable metrics data.
- Verified via the query-metrics functionality added in llamastack#3074.
1 parent 69b692a commit 8f0413e

File tree

5 files changed

+393
-195
lines changed

5 files changed

+393
-195
lines changed

llama_stack/providers/inline/agents/meta_reference/agent_instance.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
UserMessage,
6464
)
6565
from llama_stack.apis.safety import Safety
66-
from llama_stack.apis.telemetry import MetricEvent, Telemetry
66+
from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
6767
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
6868
from llama_stack.apis.vector_io import VectorIO
6969
from llama_stack.core.datatypes import AccessRule
@@ -124,6 +124,9 @@ def __init__(
124124
output_shields=agent_config.output_shields,
125125
)
126126

127+
# Initialize workflow start time to None
128+
self._workflow_start_time: float | None = None
129+
127130
def turn_to_messages(self, turn: Turn) -> list[Message]:
128131
messages = []
129132

@@ -174,14 +177,23 @@ async def create_session(self, name: str) -> str:
174177
return await self.storage.create_session(name)
175178

176179
def _emit_metric(
177-
self, metric_name: str, value: int | float, unit: str, attributes: dict[str, str] | None = None
180+
self,
181+
metric_name: str,
182+
value: int | float,
183+
unit: str,
184+
attributes: dict[str, str] | None = None,
185+
metric_type: MetricType | None = None,
178186
) -> None:
179187
"""Emit a single metric event"""
188+
logger.info(f"_emit_metric called: {metric_name} = {value} {unit}")
189+
180190
if not self.telemetry_api:
191+
logger.warning(f"No telemetry_api available for metric {metric_name}")
181192
return
182193

183194
span = get_current_span()
184195
if not span:
196+
logger.warning(f"No current span available for metric {metric_name}")
185197
return
186198

187199
context = span.get_span_context()
@@ -193,22 +205,42 @@ def _emit_metric(
193205
timestamp=time.time(),
194206
unit=unit,
195207
attributes={"agent_id": self.agent_id, **(attributes or {})},
208+
metric_type=metric_type,
196209
)
197210

198-
# Create task with name for better debugging and potential cleanup
211+
# Create task with name for better debugging and capture any async errors
199212
task_name = f"metric-{metric_name}-{self.agent_id}"
200-
asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
213+
logger.info(f"Creating telemetry task: {task_name}")
214+
task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
215+
216+
def _on_metric_task_done(t: asyncio.Task) -> None:
217+
try:
218+
exc = t.exception()
219+
except asyncio.CancelledError:
220+
logger.debug("Metric task %s was cancelled", task_name)
221+
return
222+
if exc is not None:
223+
logger.warning("Metric task %s failed: %s", task_name, exc)
224+
225+
# Only add callback if task creation succeeded (not None from mocking)
226+
if task is not None:
227+
task.add_done_callback(_on_metric_task_done)
201228

202229
def _track_step(self):
203-
self._emit_metric("llama_stack_agent_steps_total", 1, "1")
230+
logger.info("_track_step called")
231+
self._emit_metric("llama_stack_agent_steps_total", 1, "1", metric_type=MetricType.COUNTER)
204232

205233
def _track_workflow(self, status: str, duration: float):
206-
self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status})
207-
self._emit_metric("llama_stack_agent_workflow_duration_seconds", duration, "s")
234+
logger.info(f"_track_workflow called: status={status}, duration={duration:.2f}s")
235+
self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status}, MetricType.COUNTER)
236+
self._emit_metric(
237+
"llama_stack_agent_workflow_duration_seconds", duration, "s", metric_type=MetricType.HISTOGRAM
238+
)
208239

209240
def _track_tool(self, tool_name: str):
241+
logger.info(f"_track_tool called: {tool_name}")
210242
normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
211-
self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name})
243+
self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name}, MetricType.COUNTER)
212244

213245
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
214246
messages = []
@@ -244,6 +276,9 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
244276
if self.agent_config.name:
245277
span.set_attribute("agent_name", self.agent_config.name)
246278

279+
# Set workflow start time for resume operations
280+
self._workflow_start_time = time.time()
281+
247282
await self._initialize_tools()
248283
async for chunk in self._run_turn(request):
249284
yield chunk
@@ -255,6 +290,9 @@ async def _run_turn(
255290
) -> AsyncGenerator:
256291
assert request.stream is True, "Non-streaming not supported"
257292

293+
# Track workflow start time for metrics
294+
self._workflow_start_time = time.time()
295+
258296
is_resume = isinstance(request, AgentTurnResumeRequest)
259297
session_info = await self.storage.get_session_info(request.session_id)
260298
if session_info is None:
@@ -356,6 +394,10 @@ async def _run_turn(
356394
)
357395
)
358396
else:
397+
# Track workflow completion when turn is actually complete
398+
workflow_duration = time.time() - (self._workflow_start_time or time.time())
399+
self._track_workflow("completed", workflow_duration)
400+
359401
chunk = AgentTurnResponseStreamChunk(
360402
event=AgentTurnResponseEvent(
361403
payload=AgentTurnResponseTurnCompletePayload(
@@ -771,6 +813,7 @@ async def _run(
771813

772814
# Track step execution metric
773815
self._track_step()
816+
self._track_tool(tool_call.tool_name)
774817

775818
# Add the result message to input_messages for the next iteration
776819
input_messages.append(result_message)

llama_stack/providers/inline/telemetry/meta_reference/telemetry.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
1313
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
1414
from opentelemetry.sdk.metrics import MeterProvider
15+
from opentelemetry.sdk.metrics._internal.aggregation import ExplicitBucketHistogramAggregation
1516
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
17+
from opentelemetry.sdk.metrics.view import View
1618
from opentelemetry.sdk.resources import Resource
1719
from opentelemetry.sdk.trace import TracerProvider
1820
from opentelemetry.sdk.trace.export import BatchSpanProcessor
@@ -110,7 +112,17 @@ def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
110112

111113
if TelemetrySink.OTEL_METRIC in self.config.sinks:
112114
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
113-
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
115+
116+
# decent default buckets for agent workflow timings
117+
hist_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
118+
views = [
119+
View(
120+
instrument_type=metrics.Histogram,
121+
aggregation=ExplicitBucketHistogramAggregation(boundaries=hist_buckets),
122+
)
123+
]
124+
125+
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader], views=views)
114126
metrics.set_meter_provider(metric_provider)
115127

116128
if TelemetrySink.SQLITE in self.config.sinks:
@@ -140,8 +152,6 @@ async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
140152
self._log_metric(event)
141153
elif isinstance(event, StructuredLogEvent):
142154
self._log_structured(event, ttl_seconds)
143-
else:
144-
raise ValueError(f"Unknown event type: {event}")
145155

146156
async def query_metrics(
147157
self,
@@ -211,7 +221,7 @@ def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
211221
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
212222
name=name,
213223
unit=unit,
214-
description=f"Counter for {name}",
224+
description=name.replace("_", " "),
215225
)
216226
return _GLOBAL_STORAGE["counters"][name]
217227

@@ -221,7 +231,7 @@ def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
221231
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
222232
name=name,
223233
unit=unit,
224-
description=f"Gauge for {name}",
234+
description=name.replace("_", " "),
225235
)
226236
return _GLOBAL_STORAGE["gauges"][name]
227237

@@ -265,7 +275,6 @@ def _log_metric(self, event: MetricEvent) -> None:
265275
histogram = self._get_or_create_histogram(
266276
event.metric,
267277
event.unit,
268-
[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0],
269278
)
270279
histogram.record(event.value, attributes=event.attributes)
271280
elif event.metric_type == MetricType.UP_DOWN_COUNTER:
@@ -281,17 +290,17 @@ def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDown
281290
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
282291
name=name,
283292
unit=unit,
284-
description=f"UpDownCounter for {name}",
293+
description=name.replace("_", " "),
285294
)
286295
return _GLOBAL_STORAGE["up_down_counters"][name]
287296

288-
def _get_or_create_histogram(self, name: str, unit: str, buckets: list[float] | None = None) -> metrics.Histogram:
297+
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
289298
assert self.meter is not None
290299
if name not in _GLOBAL_STORAGE["histograms"]:
291300
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
292301
name=name,
293302
unit=unit,
294-
description=f"Histogram for {name}",
303+
description=name.replace("_", " "),
295304
)
296305
return _GLOBAL_STORAGE["histograms"][name]
297306

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from collections.abc import AsyncGenerator, Callable
8+
from pathlib import Path
9+
from typing import Any
10+
from unittest.mock import Mock, patch
11+
12+
import pytest
13+
14+
from llama_stack.apis.inference import ToolDefinition
15+
from llama_stack.apis.tools import ToolInvocationResult
16+
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
17+
from llama_stack.providers.inline.telemetry.meta_reference.config import (
18+
TelemetryConfig,
19+
TelemetrySink,
20+
)
21+
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
22+
TelemetryAdapter,
23+
)
24+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
25+
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
26+
from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing
27+
28+
29+
@pytest.fixture
30+
def make_agent_fixture():
31+
def _make(telemetry, kvstore) -> ChatAgent:
32+
agent = ChatAgent(
33+
agent_id="test-agent",
34+
agent_config=Mock(),
35+
inference_api=Mock(),
36+
safety_api=Mock(),
37+
tool_runtime_api=Mock(),
38+
tool_groups_api=Mock(),
39+
vector_io_api=Mock(),
40+
telemetry_api=telemetry,
41+
persistence_store=kvstore,
42+
created_at="2025-01-01T00:00:00Z",
43+
policy=[],
44+
)
45+
agent.agent_config.client_tools = []
46+
agent.agent_config.max_infer_iters = 5
47+
agent.input_shields = []
48+
agent.output_shields = []
49+
agent.tool_defs = [
50+
ToolDefinition(tool_name="web_search", description="", parameters={}),
51+
ToolDefinition(tool_name="knowledge_search", description="", parameters={}),
52+
]
53+
agent.tool_name_to_args = {}
54+
55+
# Stub tool runtime invoke_tool
56+
async def _mock_invoke_tool(
57+
*args: Any,
58+
tool_name: str | None = None,
59+
kwargs: dict | None = None,
60+
**extra: Any,
61+
):
62+
return ToolInvocationResult(content="Tool execution result")
63+
64+
agent.tool_runtime_api.invoke_tool = _mock_invoke_tool
65+
return agent
66+
67+
return _make
68+
69+
70+
def _chat_stream(tool_name: str | None, content: str = ""):
71+
from llama_stack.apis.common.content_types import (
72+
TextDelta,
73+
ToolCallDelta,
74+
ToolCallParseStatus,
75+
)
76+
from llama_stack.apis.inference import (
77+
ChatCompletionResponseEvent,
78+
ChatCompletionResponseEventType,
79+
ChatCompletionResponseStreamChunk,
80+
StopReason,
81+
)
82+
from llama_stack.models.llama.datatypes import ToolCall
83+
84+
async def gen():
85+
# Start
86+
yield ChatCompletionResponseStreamChunk(
87+
event=ChatCompletionResponseEvent(
88+
event_type=ChatCompletionResponseEventType.start,
89+
delta=TextDelta(text=""),
90+
)
91+
)
92+
93+
# Content
94+
if content:
95+
yield ChatCompletionResponseStreamChunk(
96+
event=ChatCompletionResponseEvent(
97+
event_type=ChatCompletionResponseEventType.progress,
98+
delta=TextDelta(text=content),
99+
)
100+
)
101+
102+
# Tool call if specified
103+
if tool_name:
104+
yield ChatCompletionResponseStreamChunk(
105+
event=ChatCompletionResponseEvent(
106+
event_type=ChatCompletionResponseEventType.progress,
107+
delta=ToolCallDelta(
108+
tool_call=ToolCall(call_id="call_0", tool_name=tool_name, arguments={}),
109+
parse_status=ToolCallParseStatus.succeeded,
110+
),
111+
)
112+
)
113+
114+
# Complete
115+
yield ChatCompletionResponseStreamChunk(
116+
event=ChatCompletionResponseEvent(
117+
event_type=ChatCompletionResponseEventType.complete,
118+
delta=TextDelta(text=""),
119+
stop_reason=StopReason.end_of_turn,
120+
)
121+
)
122+
123+
return gen()
124+
125+
126+
@pytest.fixture
127+
async def telemetry(tmp_path: Path) -> AsyncGenerator[TelemetryAdapter, None]:
128+
db_path = tmp_path / "trace_store.db"
129+
cfg = TelemetryConfig(
130+
sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
131+
sqlite_db_path=str(db_path),
132+
)
133+
telemetry = TelemetryAdapter(cfg, deps={})
134+
telemetry_tracing.setup_logger(telemetry)
135+
try:
136+
yield telemetry
137+
finally:
138+
await telemetry.shutdown()
139+
140+
141+
@pytest.fixture
142+
async def kvstore(tmp_path: Path) -> SqliteKVStoreImpl:
143+
kv_path = tmp_path / "agent_kvstore.db"
144+
kv = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=str(kv_path)))
145+
await kv.initialize()
146+
return kv
147+
148+
149+
@pytest.fixture
150+
def span_patch():
151+
with (
152+
patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span,
153+
patch(
154+
"llama_stack.providers.utils.telemetry.tracing.generate_span_id",
155+
return_value="0000000000000abc",
156+
),
157+
):
158+
mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=0x123, span_id=0xABC)))
159+
yield
160+
161+
162+
@pytest.fixture
163+
def make_completion_fn() -> Callable[[str | None, str], Callable]:
164+
def _factory(tool_name: str | None = None, content: str = "") -> Callable:
165+
async def chat_completion(*args: Any, **kwargs: Any):
166+
return _chat_stream(tool_name, content)
167+
168+
return chat_completion
169+
170+
return _factory

0 commit comments

Comments (0)