|
128 | 128 | validate_feedback_create_req, |
129 | 129 | validate_feedback_purge_req, |
130 | 130 | ) |
| 131 | +from weave.trace_server.feedback_payload_schema import ( |
| 132 | + build_feedback_payload_sample_query, |
| 133 | + discover_payload_schema, |
| 134 | +) |
| 135 | +from weave.trace_server.feedback_stats_query_builder import ( |
| 136 | + build_feedback_stats_query, |
| 137 | + build_feedback_stats_window_query, |
| 138 | +) |
131 | 139 | from weave.trace_server.file_storage import ( |
132 | 140 | FileStorageClient, |
133 | 141 | FileStorageReadError, |
|
147 | 155 | _build_choices_array, |
148 | 156 | _build_completion_response, |
149 | 157 | get_custom_provider_info, |
150 | | - lite_llm_completion, |
151 | 158 | lite_llm_completion_stream, |
152 | 159 | resolve_and_apply_prompt, |
153 | 160 | ) |
|
223 | 230 | # num_pools: Number of distinct connection pools (for different hosts/configs) |
224 | 231 | _CH_POOL_MANAGER = get_pool_manager(maxsize=50, num_pools=2) |
225 | 232 |
|
| 233 | +# Aggregation prefixes for window_stats column parsing, ordered longest-first so that |
| 234 | +# "count_true" and "count_false" are matched before the shorter "count" prefix. |
| 235 | +_WINDOW_STAT_AGG_PREFIXES = ( |
| 236 | + "count_true", |
| 237 | + "count_false", |
| 238 | + "avg", |
| 239 | + "sum", |
| 240 | + "min", |
| 241 | + "max", |
| 242 | + "count", |
| 243 | +) |
| 244 | + |
| 245 | + |
| 246 | +def _parse_window_stat_col(col: str) -> tuple[str, str] | None: |
| 247 | + """Parse a window-stats column alias into (agg_key, metric_slug). |
| 248 | +
|
| 249 | + Column aliases are produced by aggregation_selects_for_metric, e.g.: |
| 250 | + avg_output_score → ("avg", "output_score") |
| 251 | + count_true_output_score → ("count_true", "output_score") |
| 252 | + p95_output_score → ("p95", "output_score") |
| 253 | +
|
| 254 | + Returns None if the column does not match any known pattern. |
| 255 | + """ |
| 256 | + for prefix in _WINDOW_STAT_AGG_PREFIXES: |
| 257 | + if col.startswith(prefix + "_"): |
| 258 | + return prefix, col[len(prefix) + 1 :] |
| 259 | + # Percentile columns: p5_slug, p95_slug, p99_slug, etc. |
| 260 | + if ( |
| 261 | + "_" in col |
| 262 | + and col[0] == "p" |
| 263 | + and col[1 : col.index("_")].replace(".", "").isdigit() |
| 264 | + ): |
| 265 | + idx = col.index("_") |
| 266 | + return col[:idx], col[idx + 1 :] |
| 267 | + return None |
| 268 | + |
| 269 | + |
226 | 270 | # Precomputed list of (column_index, field_name) for every sentinel field that appears |
227 | 271 | # in ALL_CALL_COMPLETE_INSERT_COLUMNS. Used by _insert_call_complete_batch to enforce |
228 | 272 | # sentinel conversion as a last line of defense — preventing "Invalid None value in |
@@ -1138,6 +1182,68 @@ def call_stats(self, req: tsi.CallStatsReq) -> tsi.CallStatsRes: |
1138 | 1182 | call_buckets=call_buckets, |
1139 | 1183 | ) |
1140 | 1184 |
|
| 1185 | + def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes: |
| 1186 | + """Return aggregated feedback statistics over time buckets. |
| 1187 | +
|
| 1188 | + Extracts numeric values from payload_dump via json_path and aggregates |
| 1189 | + by time bucket. Filters by project_id, optional feedback_type and trigger_ref. |
| 1190 | + Also includes window-level stats (min, max, avg, percentiles) over the full range. |
| 1191 | + """ |
| 1192 | + end = req.end or datetime.datetime.now(datetime.timezone.utc) |
| 1193 | + if not req.metrics: |
| 1194 | + return tsi.FeedbackStatsRes( |
| 1195 | + start=req.start, |
| 1196 | + end=end, |
| 1197 | + granularity=3600, |
| 1198 | + timezone=req.timezone or "UTC", |
| 1199 | + buckets=[], |
| 1200 | + ) |
| 1201 | + pb = ParamBuilder() |
| 1202 | + query_result = build_feedback_stats_query(req, pb) |
| 1203 | + result = self._query(query_result.sql, query_result.parameters) |
| 1204 | + buckets = rows_to_bucket_dicts(query_result.columns, result.result_rows) |
| 1205 | + |
| 1206 | + window_stats: dict[str, dict[str, float | None]] | None = None |
| 1207 | + pb_window = ParamBuilder() |
| 1208 | + window_query = build_feedback_stats_window_query(req, pb_window) |
| 1209 | + if window_query is not None: |
| 1210 | + window_result = self._query(window_query.sql, window_query.parameters) |
| 1211 | + if window_result.result_rows: |
| 1212 | + row = window_result.result_rows[0] |
| 1213 | + raw_stats: dict[str, dict[str, float | None]] = {} |
| 1214 | + for idx, col in enumerate(window_query.columns): |
| 1215 | + if idx >= len(row): |
| 1216 | + break |
| 1217 | + parsed = _parse_window_stat_col(col) |
| 1218 | + if parsed is None: |
| 1219 | + continue |
| 1220 | + key, slug = parsed |
| 1221 | + val = row[idx] |
| 1222 | + raw_stats.setdefault(slug, {})[key] = ( |
| 1223 | + float(val) if val is not None else None |
| 1224 | + ) |
| 1225 | + window_stats = raw_stats |
| 1226 | + |
| 1227 | + return tsi.FeedbackStatsRes( |
| 1228 | + start=query_result.start, |
| 1229 | + end=query_result.end, |
| 1230 | + granularity=query_result.granularity_seconds, |
| 1231 | + timezone=req.timezone or "UTC", |
| 1232 | + buckets=buckets, |
| 1233 | + window_stats=window_stats, |
| 1234 | + ) |
| 1235 | + |
| 1236 | + def feedback_payload_schema( |
| 1237 | + self, req: tsi.FeedbackPayloadSchemaReq |
| 1238 | + ) -> tsi.FeedbackPayloadSchemaRes: |
| 1239 | + """Discover feedback payload schema from sample rows.""" |
| 1240 | + pb = ParamBuilder() |
| 1241 | + sql, params = build_feedback_payload_sample_query(req, pb) |
| 1242 | + result = self._query(sql, params) |
| 1243 | + payload_strs = [row[0] for row in result.result_rows if row and row[0]] |
| 1244 | + paths = discover_payload_schema(payload_strs) |
| 1245 | + return tsi.FeedbackPayloadSchemaRes(paths=paths) |
| 1246 | + |
1141 | 1247 | @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.trace_usage") |
1142 | 1248 | def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: |
1143 | 1249 | """Compute per-call usage for a trace, with descendant rollup. |
@@ -5549,8 +5655,8 @@ def completions_create( |
5549 | 5655 |
|
5550 | 5656 | # Build summary with usage info if available |
5551 | 5657 | summary: tsi.SummaryInsertMap = {} |
5552 | | - if "usage" in res.response: |
5553 | | - summary["usage"] = {model_name: res.response["usage"]} |
| 5658 | + # # if "usage" in res.response: |
| 5659 | + # summary["usage"] = {model_name: res.response["usage"]} |
5554 | 5660 |
|
5555 | 5661 | # Check for exception |
5556 | 5662 | exception = res.response.get("error") |
|
0 commit comments