15
15
"""
16
16
17
17
import json
18
+ import io
18
19
19
20
from wrapt import ObjectProxy
21
+ from itertools import tee
20
22
from .stream_body_wrapper import BufferedStreamBody
21
23
from functools import wraps
22
24
from langtrace .trace_attributes import (
43
45
set_span_attributes ,
44
46
set_usage_attributes ,
45
47
)
48
+ from langtrace_python_sdk .utils import set_event_prompt
46
49
47
50
48
51
def converse_stream (original_method , version , tracer ):
@@ -128,7 +131,9 @@ def traced_method(*args, **kwargs):
128
131
response = original_method (* args , ** kwargs )
129
132
130
133
if span .is_recording ():
131
- set_span_streaming_response (span , response )
134
+ stream1 , stream2 = tee (response ["stream" ])
135
+ set_span_streaming_response (span , stream1 )
136
+ response ["stream" ] = stream2
132
137
return response
133
138
134
139
return traced_method
@@ -167,12 +172,29 @@ def traced_method(*args, **kwargs):
167
172
return traced_method
168
173
169
174
175
def parse_vendor_and_model_name_from_model_id(model_id):
    """Split a Bedrock ``modelId`` into ``(vendor, model_name)``.

    Accepted forms:
      * ``vendor.model-name`` (optionally prefixed, e.g. ``us.vendor.model-name``)
      * ``arn:aws:bedrock:region:account-id:foundation-model/vendor.model-name``
      * ``arn:aws:bedrock:region:account-id:custom-model/vendor.model-name/model-id``
      * a bare name with no dot, returned for both fields

    :param model_id: the ``modelId`` string passed to the Bedrock client.
    :returns: ``(vendor, model_name)`` tuple of strings.
    """
    if model_id.startswith("arn:aws:bedrock:"):
        # The model reference is the first path segment after the ARN
        # resource type, e.g. ".../foundation-model/vendor.model-name".
        parts = model_id.split("/")
        identifiers = parts[1].split(".")
        if len(identifiers) == 1:
            # Imported/custom model ARNs may carry no "vendor." prefix;
            # fall back to the bare name for both fields (mirrors the
            # non-ARN branch below) instead of raising IndexError.
            return identifiers[0], identifiers[0]
        return identifiers[0], identifiers[1]
    parts = model_id.split(".")
    if len(parts) == 1:
        return parts[0], parts[0]
    # Prefixed ids such as "us.anthropic.claude-...": vendor and model are
    # the last two dot-separated segments.
    return parts[-2], parts[-1]
188
+
189
+
170
190
def patch_invoke_model (original_method , tracer , version ):
171
191
def traced_method (* args , ** kwargs ):
172
192
modelId = kwargs .get ("modelId" )
173
- ( vendor , _ ) = modelId . split ( "." )
193
+ vendor , _ = parse_vendor_and_model_name_from_model_id ( modelId )
174
194
span_attributes = {
175
195
** get_langtrace_attributes (version , vendor , vendor_type = "framework" ),
196
+ SpanAttributes .LLM_PATH : APIS ["INVOKE_MODEL" ]["ENDPOINT" ],
197
+ SpanAttributes .LLM_IS_STREAMING : False ,
176
198
** get_extra_attributes (),
177
199
}
178
200
with tracer .start_as_current_span (
@@ -193,9 +215,11 @@ def patch_invoke_model_with_response_stream(original_method, tracer, version):
193
215
@wraps (original_method )
194
216
def traced_method (* args , ** kwargs ):
195
217
modelId = kwargs .get ("modelId" )
196
- ( vendor , _ ) = modelId . split ( "." )
218
+ vendor , _ = parse_vendor_and_model_name_from_model_id ( modelId )
197
219
span_attributes = {
198
220
** get_langtrace_attributes (version , vendor , vendor_type = "framework" ),
221
+ SpanAttributes .LLM_PATH : APIS ["INVOKE_MODEL_WITH_RESPONSE_STREAM" ]["ENDPOINT" ],
222
+ SpanAttributes .LLM_IS_STREAMING : True ,
199
223
** get_extra_attributes (),
200
224
}
201
225
span = tracer .start_span (
@@ -217,7 +241,7 @@ def handle_streaming_call(span, kwargs, response):
217
241
def stream_finished (response_body ):
218
242
request_body = json .loads (kwargs .get ("body" ))
219
243
220
- ( vendor , model ) = kwargs .get ("modelId" ). split ( "." )
244
+ vendor , model = parse_vendor_and_model_name_from_model_id ( kwargs .get ("modelId" ))
221
245
222
246
set_span_attribute (span , SpanAttributes .LLM_REQUEST_MODEL , model )
223
247
set_span_attribute (span , SpanAttributes .LLM_RESPONSE_MODEL , model )
@@ -241,18 +265,22 @@ def stream_finished(response_body):
241
265
242
266
def handle_call (span , kwargs , response ):
243
267
modelId = kwargs .get ("modelId" )
244
- (vendor , model_name ) = modelId .split ("." )
268
+ vendor , model_name = parse_vendor_and_model_name_from_model_id (modelId )
269
+ read_response_body = response .get ("body" ).read ()
270
+ request_body = json .loads (kwargs .get ("body" ))
271
+ response_body = json .loads (read_response_body )
245
272
response ["body" ] = BufferedStreamBody (
246
- response [ "body" ]. _raw_stream , response [ "body" ]. _content_length
273
+ io . BytesIO ( read_response_body ), len ( read_response_body )
247
274
)
248
- request_body = json .loads (kwargs .get ("body" ))
249
- response_body = json .loads (response .get ("body" ).read ())
250
275
251
276
set_span_attribute (span , SpanAttributes .LLM_RESPONSE_MODEL , modelId )
252
277
set_span_attribute (span , SpanAttributes .LLM_REQUEST_MODEL , modelId )
253
278
254
279
if vendor == "amazon" :
255
- set_amazon_attributes (span , request_body , response_body )
280
+ if model_name .startswith ("titan-embed-text" ):
281
+ set_amazon_embedding_attributes (span , request_body , response_body )
282
+ else :
283
+ set_amazon_attributes (span , request_body , response_body )
256
284
257
285
if vendor == "anthropic" :
258
286
if "prompt" in request_body :
@@ -356,6 +384,27 @@ def set_amazon_attributes(span, request_body, response_body):
356
384
set_event_completion (span , completions )
357
385
358
386
387
def set_amazon_embedding_attributes(span, request_body, response_body):
    """Record prompt, token usage, and model attributes for an Amazon
    Titan embedding call on the given span.

    :param span: the active tracing span to annotate.
    :param request_body: decoded JSON request body sent to Bedrock.
    :param response_body: decoded JSON response body from Bedrock.
    """
    set_event_prompt(span, request_body.get("inputText"))

    vector = response_body.get("embedding", [])
    usage = {
        "input_tokens": response_body.get("inputTextTokenCount"),
        "output": len(vector),
    }
    set_usage_attributes(span, usage)

    # NOTE(review): "modelId" is read from the request *body* here; verify
    # callers actually include it there rather than only in kwargs.
    model_id = request_body.get("modelId")
    set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model_id)
    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model_id)
406
+
407
+
359
408
def set_anthropic_completions_attributes (span , request_body , response_body ):
360
409
set_span_attribute (
361
410
span ,
@@ -442,10 +491,10 @@ def _set_response_attributes(span, kwargs, result):
442
491
)
443
492
444
493
445
- def set_span_streaming_response (span , response ):
494
+ def set_span_streaming_response (span , response_stream ):
446
495
streaming_response = ""
447
496
role = None
448
- for event in response [ "stream" ] :
497
+ for event in response_stream :
449
498
if "messageStart" in event :
450
499
role = event ["messageStart" ]["role" ]
451
500
elif "contentBlockDelta" in event :
@@ -475,13 +524,15 @@ def __init__(
475
524
stream_done_callback = None ,
476
525
):
477
526
super ().__init__ (response )
478
-
479
527
self ._stream_done_callback = stream_done_callback
480
528
self ._accumulating_body = {"generation" : "" }
529
+ self .last_chunk = None
481
530
482
531
def __iter__(self):
    """Iterate the wrapped stream, folding each event into the
    accumulating body before passing it through unchanged."""
    for chunk in self.__wrapped__:
        self._process_event(chunk)
        yield chunk
486
537
487
538
def _process_event (self , event ):
@@ -496,7 +547,11 @@ def _process_event(self, event):
496
547
self ._stream_done_callback (decoded_chunk )
497
548
return
498
549
if "generation" in decoded_chunk :
499
- self ._accumulating_body ["generation" ] += decoded_chunk .get ("generation" )
550
+ generation = decoded_chunk .get ("generation" )
551
+ if self .last_chunk == generation :
552
+ return
553
+ self .last_chunk = generation
554
+ self ._accumulating_body ["generation" ] += generation
500
555
501
556
if type == "message_start" :
502
557
self ._accumulating_body = decoded_chunk .get ("message" )
@@ -505,9 +560,11 @@ def _process_event(self, event):
505
560
decoded_chunk .get ("content_block" )
506
561
)
507
562
elif type == "content_block_delta" :
508
- self ._accumulating_body ["content" ][- 1 ]["text" ] += decoded_chunk .get (
509
- "delta"
510
- ).get ("text" )
563
+ text = decoded_chunk .get ("delta" ).get ("text" )
564
+ if self .last_chunk == text :
565
+ return
566
+ self .last_chunk = text
567
+ self ._accumulating_body ["content" ][- 1 ]["text" ] += text
511
568
512
569
elif self .has_finished (type , decoded_chunk ):
513
570
self ._accumulating_body ["invocation_metrics" ] = decoded_chunk .get (
0 commit comments