Skip to content

Commit 404f704

Browse files
authored
feat: Add track_error to mirror track_success (#33)
Additionally, emit new `$ld:ai:generation:(success|error)` events on success or failure.
1 parent 80e1845 commit 404f704

File tree

2 files changed

+185
-15
lines changed

2 files changed

+185
-15
lines changed

ldai/testing/test_tracker.py

Lines changed: 136 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from time import sleep
12
from unittest.mock import MagicMock, call
23

34
import pytest
@@ -60,6 +61,43 @@ def test_tracks_duration(client: LDClient):
6061
assert tracker.get_summary().duration == 100
6162

6263

64+
def test_tracks_duration_of(client: LDClient):
    """track_duration_of should emit exactly one total-duration event."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
    tracker.track_duration_of(lambda: sleep(0.01))

    calls = client.track.mock_calls  # type: ignore

    assert len(calls) == 1
    event_name, event_context, event_data, measured = calls[0].args
    assert event_name == '$ld:ai:duration:total'
    assert event_context == context
    assert event_data == {'variationKey': 'variation-key', 'configKey': 'config-key'}
    # Very loose tolerance: we only care that a plausible millisecond value landed.
    assert measured == pytest.approx(10, rel=10)
76+
77+
78+
def test_tracks_duration_of_with_exception(client: LDClient):
    """Duration is still tracked when the wrapped function raises, and the
    exception propagates to the caller."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

    def sleep_and_throw():
        sleep(0.01)
        raise ValueError("Something went wrong")

    # pytest.raises replaces the try/except + `assert False` anti-pattern: it
    # both asserts the exception escapes and keeps the test body flat.
    with pytest.raises(ValueError, match="Something went wrong"):
        tracker.track_duration_of(sleep_and_throw)

    calls = client.track.mock_calls  # type: ignore

    # The duration event must be emitted even though the call failed.
    assert len(calls) == 1
    assert calls[0].args[0] == '$ld:ai:duration:total'
    assert calls[0].args[1] == context
    assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
    assert calls[0].args[3] == pytest.approx(10, rel=10)
99+
100+
63101
def test_tracks_token_usage(client: LDClient):
64102
context = Context.create('user-key')
65103
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -97,6 +135,7 @@ def test_tracks_bedrock_metrics(client: LDClient):
97135

98136
calls = [
99137
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
138+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
100139
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
101140
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
102141
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
@@ -110,6 +149,39 @@ def test_tracks_bedrock_metrics(client: LDClient):
110149
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
111150

112151

152+
def test_tracks_bedrock_metrics_with_error(client: LDClient):
    """A non-2xx Bedrock response reports an error event alongside the
    duration and token metrics, and the summary reflects the failure."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

    tracker.track_bedrock_converse_metrics({
        '$metadata': {'httpStatusCode': 500},
        'usage': {
            'totalTokens': 330,
            'inputTokens': 220,
            'outputTokens': 110,
        },
        'metrics': {
            'latencyMs': 50,
        },
    })

    expected_data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    expected = [
        call(event, context, expected_data, value)
        for event, value in [
            ('$ld:ai:generation', 1),
            ('$ld:ai:generation:error', 1),
            ('$ld:ai:duration:total', 50),
            ('$ld:ai:tokens:total', 330),
            ('$ld:ai:tokens:input', 220),
            ('$ld:ai:tokens:output', 110),
        ]
    ]

    client.track.assert_has_calls(expected)  # type: ignore

    summary = tracker.get_summary()
    assert summary.success is False
    assert summary.duration == 50
    assert summary.usage == TokenUsage(330, 220, 110)
183+
184+
113185
def test_tracks_openai_metrics(client: LDClient):
114186
context = Context.create('user-key')
115187
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -129,6 +201,8 @@ def to_dict(self):
129201
tracker.track_openai_metrics(lambda: Result())
130202

131203
calls = [
204+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
205+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
132206
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
133207
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
134208
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
@@ -139,6 +213,29 @@ def to_dict(self):
139213
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
140214

141215

216+
def test_tracks_openai_metrics_with_exception(client: LDClient):
    """A failing OpenAI call tracks the generation + error events and records
    no token usage."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

    def raise_exception():
        raise ValueError("Something went wrong")

    # pytest.raises replaces the try/except + `assert False` anti-pattern and
    # asserts the exception propagates out of track_openai_metrics.
    with pytest.raises(ValueError, match="Something went wrong"):
        tracker.track_openai_metrics(raise_exception)

    calls = [
        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
        call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
    ]

    client.track.assert_has_calls(calls, any_order=False)  # type: ignore

    # A failed call must not record any token usage.
    assert tracker.get_summary().usage is None
237+
238+
142239
@pytest.mark.parametrize(
143240
"kind,label",
144241
[
@@ -166,11 +263,44 @@ def test_tracks_success(client: LDClient):
166263
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
167264
tracker.track_success()
168265

169-
client.track.assert_called_with( # type: ignore
170-
'$ld:ai:generation',
171-
context,
172-
{'variationKey': 'variation-key', 'configKey': 'config-key'},
173-
1
174-
)
266+
calls = [
267+
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
268+
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
269+
]
270+
271+
client.track.assert_has_calls(calls) # type: ignore
175272

176273
assert tracker.get_summary().success is True
274+
275+
276+
def test_tracks_error(client: LDClient):
    """track_error should emit both the generic and the error-specific
    generation events and mark the summary unsuccessful."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
    tracker.track_error()

    data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    client.track.assert_has_calls([  # type: ignore
        call('$ld:ai:generation', context, data, 1),
        call('$ld:ai:generation:error', context, data, 1),
    ])

    assert tracker.get_summary().success is False
289+
290+
291+
def test_error_overwrites_success(client: LDClient):
    """An error recorded after a success leaves the summary marked
    unsuccessful; all four events are still emitted in order."""
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
    tracker.track_success()
    tracker.track_error()

    data = {'variationKey': 'variation-key', 'configKey': 'config-key'}
    expected = [
        call('$ld:ai:generation', context, data, 1),
        call('$ld:ai:generation:success', context, data, 1),
        call('$ld:ai:generation', context, data, 1),
        call('$ld:ai:generation:error', context, data, 1),
    ]

    client.track.assert_has_calls(expected)  # type: ignore

    assert tracker.get_summary().success is False

ldai/tracker.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
def track_duration_of(self, func):
    """
    Automatically track the duration of an AI operation.

    An exception occurring during the execution of the function will still
    track the duration. The exception will be re-thrown.

    :param func: Function to track.
    :return: Result of the tracked function.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments (NTP, DST) the way time.time() can.
    start_time = time.perf_counter()
    try:
        result = func()
    finally:
        end_time = time.perf_counter()
        duration = int((end_time - start_time) * 1000)  # duration in milliseconds
        self.track_duration(duration)

    return result
118124

119125
def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
@@ -146,32 +152,66 @@ def track_success(self) -> None:
146152
self._ld_client.track(
147153
'$ld:ai:generation', self._context, self.__get_track_data(), 1
148154
)
155+
self._ld_client.track(
156+
'$ld:ai:generation:success', self._context, self.__get_track_data(), 1
157+
)
158+
159+
def track_error(self) -> None:
    """
    Track an unsuccessful AI generation attempt.

    Marks the summary as failed and emits the generic generation event
    followed by the error-specific event, mirroring track_success's ordering.
    """
    self._summary._success = False
    for event_name in ('$ld:ai:generation', '$ld:ai:generation:error'):
        self._ld_client.track(
            event_name, self._context, self.__get_track_data(), 1
        )
149170

150171
def track_openai_metrics(self, func):
    """
    Track OpenAI-specific operations.

    This function will track the duration of the operation, the token
    usage, and the success or error status.

    If the provided function throws, then this method will also throw.

    In the case the provided function throws, this function will record the
    duration and an error.

    A failed operation will not have any token usage data.

    :param func: Function to track.
    :return: Result of the tracked function.
    """
    try:
        response = self.track_duration_of(func)
        self.track_success()
        # Only record tokens when the provider response exposes a usage
        # object with a to_dict serializer.
        usage = getattr(response, 'usage', None)
        if usage is not None and hasattr(usage, 'to_dict'):
            self.track_tokens(_openai_to_token_usage(usage.to_dict()))
    except Exception:
        self.track_error()
        raise
    return response
161198

162199
def track_bedrock_converse_metrics(self, res: dict) -> dict:
163200
"""
164201
Track AWS Bedrock conversation operations.
165202
203+
204+
This function will track the duration of the operation, the token
205+
usage, and the success or error status.
206+
166207
:param res: Response dictionary from Bedrock.
167208
:return: The original response dictionary.
168209
"""
169210
status_code = res.get('$metadata', {}).get('httpStatusCode', 0)
170211
if status_code == 200:
171212
self.track_success()
172213
elif status_code >= 400:
173-
# Potentially add error tracking in the future.
174-
pass
214+
self.track_error()
175215
if res.get('metrics', {}).get('latencyMs'):
176216
self.track_duration(res['metrics']['latencyMs'])
177217
if res.get('usage'):

0 commit comments

Comments
 (0)