Skip to content

Commit f800a99

Browse files
committed
feat(weave): feedback stats query
1 parent 5bb04e3 commit f800a99

File tree

10 files changed

+1288
-20
lines changed

10 files changed

+1288
-20
lines changed

tests/trace_server/query_builder/test_feedback_stats.py

Lines changed: 459 additions & 0 deletions
Large diffs are not rendered by default.

weave/trace_server/calls_query_builder/stats_query_base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,21 @@ class StatsQueryTimeBounds:
3939

4040

4141
@dataclass(frozen=True)
class SqlQueryResult:
    """Base result for parameterized SQL queries."""

    # Rendered SQL text containing parameter placeholders.
    sql: str
    # Select-column aliases, in output order.
    columns: list[str]
    # Parameter name -> bound value, passed alongside the SQL at execution.
    parameters: dict[str, Any]


@dataclass(frozen=True)
class StatsQueryBuildResult(SqlQueryResult):
    """Query result with time-bucketed granularity metadata."""

    # Bucket width in seconds for the time series.
    granularity_seconds: int = 0
    # Window start; datetime.min acts as an "unset" sentinel —
    # NOTE(review): confirm callers treat the defaults as sentinels.
    start: datetime.datetime = datetime.datetime.min
    # Window end; datetime.min acts as an "unset" sentinel.
    end: datetime.datetime = datetime.datetime.min
4957

5058

5159
def auto_select_granularity_seconds(delta: datetime.timedelta) -> int:
@@ -142,6 +150,10 @@ def aggregation_selects_for_metric(
142150
results.append((f"maxOrNull({col})", f"max_{metric}"))
143151
elif agg == AggregationType.COUNT:
144152
results.append((f"countOrNull({col})", f"count_{metric}"))
153+
elif agg == AggregationType.COUNT_TRUE:
154+
results.append((f"countIf({col} = 1)", f"count_true_{metric}"))
155+
elif agg == AggregationType.COUNT_FALSE:
156+
results.append((f"countIf({col} = 0)", f"count_false_{metric}"))
145157
else:
146158
raise ValueError(f"Unsupported aggregation type: {agg}")
147159

weave/trace_server/clickhouse_trace_server_batched.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@
128128
validate_feedback_create_req,
129129
validate_feedback_purge_req,
130130
)
131+
from weave.trace_server.feedback_payload_schema import (
132+
build_feedback_payload_sample_query,
133+
discover_payload_schema,
134+
)
135+
from weave.trace_server.feedback_stats_query_builder import (
136+
build_feedback_stats_query,
137+
build_feedback_stats_window_query,
138+
)
131139
from weave.trace_server.file_storage import (
132140
FileStorageClient,
133141
FileStorageReadError,
@@ -147,7 +155,6 @@
147155
_build_choices_array,
148156
_build_completion_response,
149157
get_custom_provider_info,
150-
lite_llm_completion,
151158
lite_llm_completion_stream,
152159
resolve_and_apply_prompt,
153160
)
@@ -223,6 +230,43 @@
223230
# num_pools: Number of distinct connection pools (for different hosts/configs)
224231
_CH_POOL_MANAGER = get_pool_manager(maxsize=50, num_pools=2)
225232

233+
# Aggregation prefixes for window_stats column parsing, ordered longest-first so that
234+
# "count_true" and "count_false" are matched before the shorter "count" prefix.
235+
_WINDOW_STAT_AGG_PREFIXES = (
236+
"count_true",
237+
"count_false",
238+
"avg",
239+
"sum",
240+
"min",
241+
"max",
242+
"count",
243+
)
244+
245+
246+
def _parse_window_stat_col(col: str) -> tuple[str, str] | None:
247+
"""Parse a window-stats column alias into (agg_key, metric_slug).
248+
249+
Column aliases are produced by aggregation_selects_for_metric, e.g.:
250+
avg_output_score → ("avg", "output_score")
251+
count_true_output_score → ("count_true", "output_score")
252+
p95_output_score → ("p95", "output_score")
253+
254+
Returns None if the column does not match any known pattern.
255+
"""
256+
for prefix in _WINDOW_STAT_AGG_PREFIXES:
257+
if col.startswith(prefix + "_"):
258+
return prefix, col[len(prefix) + 1 :]
259+
# Percentile columns: p5_slug, p95_slug, p99_slug, etc.
260+
if (
261+
"_" in col
262+
and col[0] == "p"
263+
and col[1 : col.index("_")].replace(".", "").isdigit()
264+
):
265+
idx = col.index("_")
266+
return col[:idx], col[idx + 1 :]
267+
return None
268+
269+
226270
# Precomputed list of (column_index, field_name) for every sentinel field that appears
227271
# in ALL_CALL_COMPLETE_INSERT_COLUMNS. Used by _insert_call_complete_batch to enforce
228272
# sentinel conversion as a last line of defense — preventing "Invalid None value in
@@ -1138,6 +1182,68 @@ def call_stats(self, req: tsi.CallStatsReq) -> tsi.CallStatsRes:
11381182
call_buckets=call_buckets,
11391183
)
11401184

1185+
def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes:
    """Return aggregated feedback statistics over time buckets.

    Extracts numeric values from payload_dump via json_path and aggregates
    by time bucket. Filters by project_id, optional feedback_type and trigger_ref.
    Also includes window-level stats (min, max, avg, percentiles) over the full range.
    """
    # Default the window end to "now" (UTC) when the request leaves it open.
    end = req.end or datetime.datetime.now(datetime.timezone.utc)
    # No metrics requested: return an empty series without querying.
    # Granularity is reported as 3600 (one hour) — NOTE(review): confirm this
    # matches the auto-selected default used by the query builder.
    if not req.metrics:
        return tsi.FeedbackStatsRes(
            start=req.start,
            end=end,
            granularity=3600,
            timezone=req.timezone or "UTC",
            buckets=[],
        )
    pb = ParamBuilder()
    query_result = build_feedback_stats_query(req, pb)
    result = self._query(query_result.sql, query_result.parameters)
    buckets = rows_to_bucket_dicts(query_result.columns, result.result_rows)

    # Whole-window (non-bucketed) stats come from a second query, built with a
    # fresh ParamBuilder — presumably so the two parameter sets cannot collide.
    window_stats: dict[str, dict[str, float | None]] | None = None
    pb_window = ParamBuilder()
    window_query = build_feedback_stats_window_query(req, pb_window)
    if window_query is not None:
        window_result = self._query(window_query.sql, window_query.parameters)
        if window_result.result_rows:
            # Only the first row is consulted: one aggregate value per column.
            row = window_result.result_rows[0]
            raw_stats: dict[str, dict[str, float | None]] = {}
            for idx, col in enumerate(window_query.columns):
                # Defensive guard: never index past the returned row.
                if idx >= len(row):
                    break
                # Aliases encode "<agg>_<metric_slug>"; skip columns that
                # don't parse into a known aggregation pattern.
                parsed = _parse_window_stat_col(col)
                if parsed is None:
                    continue
                key, slug = parsed
                val = row[idx]
                # Normalize to float for the response; NULL aggregates stay None.
                raw_stats.setdefault(slug, {})[key] = (
                    float(val) if val is not None else None
                )
            window_stats = raw_stats

    return tsi.FeedbackStatsRes(
        start=query_result.start,
        end=query_result.end,
        granularity=query_result.granularity_seconds,
        timezone=req.timezone or "UTC",
        buckets=buckets,
        window_stats=window_stats,
    )
1235+
1236+
def feedback_payload_schema(
    self, req: tsi.FeedbackPayloadSchemaReq
) -> tsi.FeedbackPayloadSchemaRes:
    """Discover feedback payload schema from sample rows."""
    param_builder = ParamBuilder()
    query_sql, query_params = build_feedback_payload_sample_query(req, param_builder)
    query_result = self._query(query_sql, query_params)
    # Keep only non-empty payload strings from the first column of each row.
    samples = []
    for row in query_result.result_rows:
        if row and row[0]:
            samples.append(row[0])
    return tsi.FeedbackPayloadSchemaRes(paths=discover_payload_schema(samples))
1246+
11411247
@ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.trace_usage")
11421248
def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes:
11431249
"""Compute per-call usage for a trace, with descendant rollup.
@@ -5549,8 +5655,8 @@ def completions_create(
55495655

55505656
# Build summary with usage info if available
55515657
summary: tsi.SummaryInsertMap = {}
5552-
if "usage" in res.response:
5553-
summary["usage"] = {model_name: res.response["usage"]}
5658+
# # if "usage" in res.response:
5659+
# summary["usage"] = {model_name: res.response["usage"]}
55545660

55555661
# Check for exception
55565662
exception = res.response.get("error")

weave/trace_server/external_to_internal_trace_server_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,16 @@ def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRe
390390
res.wb_user_id = original_user_id
391391
return res
392392

393+
def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes:
    """Translate the external project id and delegate to the internal server."""
    # req is mutated in place before delegation — matches the convention of
    # the sibling adapter methods in this class.
    req.project_id = self._idc.ext_to_int_project_id(req.project_id)
    return self._ref_apply(self._internal_trace_server.feedback_stats, req)
396+
397+
def feedback_payload_schema(
    self, req: tsi.FeedbackPayloadSchemaReq
) -> tsi.FeedbackPayloadSchemaRes:
    """Translate the external project id and delegate schema discovery."""
    # req is mutated in place before delegation — matches the convention of
    # the sibling adapter methods in this class.
    req.project_id = self._idc.ext_to_int_project_id(req.project_id)
    return self._ref_apply(self._internal_trace_server.feedback_payload_schema, req)
402+
393403
def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
394404
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
395405
return self._ref_apply(self._internal_trace_server.cost_create, req)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""Feedback payload schema discovery from sample rows."""
2+
3+
from __future__ import annotations
4+
5+
import datetime
6+
import json
7+
import logging
8+
from collections import defaultdict
9+
from typing import Any
10+
11+
from weave.trace_server.calls_query_builder.utils import param_slot, safely_format_sql
12+
from weave.trace_server.feedback_stats_query_builder import (
13+
JSON_PATH_PATTERN,
14+
trigger_ref_where_clause,
15+
)
16+
from weave.trace_server.orm import ParamBuilder
17+
from weave.trace_server.trace_server_interface import (
18+
FeedbackPayloadPath,
19+
FeedbackPayloadSchemaReq,
20+
)
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def _discover_paths(obj: Any, prefix: str = "") -> dict[str, set[type]]:
26+
"""Recursively discover leaf paths and collect value types.
27+
28+
Returns:
29+
Mapping from dot path to set of Python types seen at that path.
30+
"""
31+
out: dict[str, set[type]] = defaultdict(set)
32+
if obj is None:
33+
return out
34+
if isinstance(obj, dict):
35+
for k, v in obj.items():
36+
if not isinstance(k, str) or "." in k or not k.strip():
37+
continue
38+
path = f"{prefix}.{k}" if prefix else k
39+
if isinstance(v, (dict, list)) and v is not None:
40+
out.update(_discover_paths(v, path))
41+
else:
42+
out[path].add(type(v))
43+
return out
44+
if isinstance(obj, list):
45+
for i, v in enumerate(obj):
46+
if isinstance(v, (dict, list)) and v is not None:
47+
out.update(_discover_paths(v, f"{prefix}[{i}]"))
48+
else:
49+
out[prefix].add(type(v))
50+
return out
51+
out[prefix].add(type(obj))
52+
return out
53+
54+
55+
def _infer_value_type(types_seen: set[type]) -> str:
56+
"""Infer value_type from observed Python types."""
57+
if not types_seen:
58+
return "numeric"
59+
# bool must be checked before int (bool is subclass of int)
60+
if types_seen <= {bool, type(None)}:
61+
return "boolean"
62+
if types_seen <= {int, float, type(None)}:
63+
return "numeric"
64+
return "categorical"
65+
66+
67+
def discover_payload_schema(payload_strs: list[str]) -> list[FeedbackPayloadPath]:
    """Discover schema from raw payload JSON strings.

    Parses each string as JSON, recursively discovers leaf paths, infers
    value_type from observed types, and returns unique paths.

    Args:
        payload_strs: List of JSON strings (payload_dump from feedback rows).

    Returns:
        Sorted list of FeedbackPayloadPath, deduplicated by json_path.

    Examples:
        >>> discover_payload_schema(['{"output": {"score": 0.9}}'])
        [FeedbackPayloadPath(json_path='output.score', value_type='numeric')]
    """
    observed: dict[str, set[type]] = defaultdict(set)
    for raw in payload_strs:
        # Ignore blank strings and anything that is not valid JSON.
        if not raw or not raw.strip():
            continue
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        for path, leaf_types in _discover_paths(parsed).items():
            # Array-index paths like "a[0]" are excluded from schema output,
            # as are paths feedback_stats would reject (e.g. with spaces).
            if "[" in path or not JSON_PATH_PATTERN.match(path):
                continue
            observed[path].update(leaf_types)

    return [
        FeedbackPayloadPath(json_path=p, value_type=_infer_value_type(observed[p]))
        for p in sorted(observed)
    ]
105+
106+
107+
def _payload_epoch_utc(dt: datetime.datetime) -> float:
    """Return *dt* as a UTC epoch timestamp.

    Naive datetimes are interpreted as UTC; aware datetimes are converted
    via .timestamp() so a non-UTC offset keeps the same instant instead of
    being reinterpreted as UTC wall time.
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=datetime.timezone.utc)
    return dt.timestamp()


def build_feedback_payload_sample_query(
    req: FeedbackPayloadSchemaReq,
    pb: ParamBuilder,
) -> tuple[str, dict[str, Any]]:
    """Build parameterized ClickHouse SQL to fetch sample payload_dump.

    Uses same filters as feedback_stats (project_id, created_at, feedback_type,
    trigger_ref). Returns one payload per unique trigger_ref (most recent per
    ref), since each trigger_ref has a unique payload schema.

    Args:
        req: Schema-discovery request (project, time range, optional filters).
        pb: Parameter builder that accumulates the query's bind parameters.

    Returns:
        Tuple of (formatted SQL string, parameter dict) ready to execute.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    start = req.start
    end = req.end if req.end is not None else now_utc
    limit = req.sample_limit

    project_param = pb.add_param(req.project_id)
    # Bug fix: a bare .replace(tzinfo=UTC) would silently shift the instant
    # for aware, non-UTC datetimes; convert through a tz-aware helper instead.
    start_param = pb.add_param(_payload_epoch_utc(start))
    end_param = pb.add_param(_payload_epoch_utc(end))
    limit_param = pb.add_param(limit)

    where_clauses: list[str] = [
        f"project_id = {param_slot(project_param, 'String')}",
        f"created_at >= toDateTime({param_slot(start_param, 'Float64')}, 'UTC')",
        f"created_at < toDateTime({param_slot(end_param, 'Float64')}, 'UTC')",
        "payload_dump != ''",
        "payload_dump IS NOT NULL",
    ]
    if req.feedback_type is not None:
        feedback_type_param = pb.add_param(req.feedback_type)
        where_clauses.append(
            f"feedback_type = {param_slot(feedback_type_param, 'String')}"
        )
    if req.trigger_ref is not None:
        where_clauses.append(trigger_ref_where_clause(req.trigger_ref, pb))
    where_sql = " AND ".join(where_clauses)

    # argMax picks the most recent payload per trigger_ref group.
    raw_sql = f"""
    SELECT argMax(payload_dump, created_at) AS payload_sample
    FROM feedback
    WHERE {where_sql}
    GROUP BY trigger_ref
    LIMIT {param_slot(limit_param, 'Int64')}
    """
    return safely_format_sql(raw_sql, logger), pb.get_params()

0 commit comments

Comments
 (0)