1
1
import asyncio
2
+ import contextvars
2
3
import inspect
3
4
import logging
4
5
import os
10
11
Dict ,
11
12
Generator ,
12
13
Iterable ,
14
+ List ,
13
15
Optional ,
14
16
Tuple ,
15
17
TypeVar ,
21
23
from opentelemetry .util ._decorator import _AgnosticContextManager
22
24
from typing_extensions import ParamSpec
23
25
24
- from langfuse ._client .environment_variables import (
25
- LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
26
- )
27
-
28
26
from langfuse ._client .constants import (
29
27
ObservationTypeLiteralNoEvent ,
30
28
get_observation_types_list ,
31
29
)
30
+ from langfuse ._client .environment_variables import (
31
+ LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
32
+ )
32
33
from langfuse ._client .get_client import _set_current_public_key , get_client
33
34
from langfuse ._client .span import (
34
- LangfuseGeneration ,
35
- LangfuseSpan ,
36
35
LangfuseAgent ,
37
- LangfuseTool ,
38
36
LangfuseChain ,
39
- LangfuseRetriever ,
40
- LangfuseEvaluator ,
41
37
LangfuseEmbedding ,
38
+ LangfuseEvaluator ,
39
+ LangfuseGeneration ,
42
40
LangfuseGuardrail ,
41
+ LangfuseRetriever ,
42
+ LangfuseSpan ,
43
+ LangfuseTool ,
43
44
)
44
45
from langfuse .types import TraceContext
45
46
@@ -468,29 +469,54 @@ def _wrap_sync_generator_result(
468
469
generator : Generator ,
469
470
transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
470
471
) -> Any :
471
- items = []
472
+ preserved_context = contextvars . copy_context ()
472
473
473
- try :
474
- for item in generator :
475
- items .append (item )
474
+ return _ContextPreservedSyncGeneratorWrapper (
475
+ generator ,
476
+ preserved_context ,
477
+ langfuse_span_or_generation ,
478
+ transform_to_string ,
479
+ )
480
+
481
+ def _wrap_async_generator_result (
482
+ self ,
483
+ langfuse_span_or_generation : Union [
484
+ LangfuseSpan ,
485
+ LangfuseGeneration ,
486
+ LangfuseAgent ,
487
+ LangfuseTool ,
488
+ LangfuseChain ,
489
+ LangfuseRetriever ,
490
+ LangfuseEvaluator ,
491
+ LangfuseEmbedding ,
492
+ LangfuseGuardrail ,
493
+ ],
494
+ generator : AsyncGenerator ,
495
+ transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
496
+ ) -> Any :
497
+ preserved_context = contextvars .copy_context ()
476
498
477
- yield item
499
+ return _ContextPreservedAsyncGeneratorWrapper (
500
+ generator ,
501
+ preserved_context ,
502
+ langfuse_span_or_generation ,
503
+ transform_to_string ,
504
+ )
478
505
479
- finally :
480
- output : Any = items
481
506
482
- if transform_to_string is not None :
483
- output = transform_to_string (items )
507
+ _decorator = LangfuseDecorator ()
508
+
509
+ observe = _decorator .observe
484
510
485
- elif all (isinstance (item , str ) for item in items ):
486
- output = "" .join (items )
487
511
488
- langfuse_span_or_generation . update ( output = output )
489
- langfuse_span_or_generation . end ()
512
+ class _ContextPreservedSyncGeneratorWrapper :
513
+ """Sync generator wrapper that ensures each iteration runs in preserved context."""
490
514
491
- async def _wrap_async_generator_result (
515
+ def __init__ (
492
516
self ,
493
- langfuse_span_or_generation : Union [
517
+ generator : Generator ,
518
+ context : contextvars .Context ,
519
+ span : Union [
494
520
LangfuseSpan ,
495
521
LangfuseGeneration ,
496
522
LangfuseAgent ,
@@ -501,30 +527,105 @@ async def _wrap_async_generator_result(
501
527
LangfuseEmbedding ,
502
528
LangfuseGuardrail ,
503
529
],
504
- generator : AsyncGenerator ,
505
- transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
506
- ) -> AsyncGenerator :
507
- items = []
530
+ transform_fn : Optional [Callable [[Iterable ], str ]],
531
+ ) -> None :
532
+ self .generator = generator
533
+ self .context = context
534
+ self .items : List [Any ] = []
535
+ self .span = span
536
+ self .transform_fn = transform_fn
537
+
538
+ def __iter__ (self ) -> "_ContextPreservedSyncGeneratorWrapper" :
539
+ return self
540
+
541
+ def __next__ (self ) -> Any :
542
+ try :
543
+ # Run the generator's __next__ in the preserved context
544
+ item = self .context .run (next , self .generator )
545
+ self .items .append (item )
546
+
547
+ return item
548
+
549
+ except StopIteration :
550
+ # Handle output and span cleanup when generator is exhausted
551
+ output : Any = self .items
552
+
553
+ if self .transform_fn is not None :
554
+ output = self .transform_fn (self .items )
555
+
556
+ elif all (isinstance (item , str ) for item in self .items ):
557
+ output = "" .join (self .items )
558
+
559
+ self .span .update (output = output ).end ()
560
+
561
+ raise # Re-raise StopIteration
562
+
563
+ except Exception as e :
564
+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
508
565
566
+ raise
567
+
568
+
569
+ class _ContextPreservedAsyncGeneratorWrapper :
570
+ """Async generator wrapper that ensures each iteration runs in preserved context."""
571
+
572
+ def __init__ (
573
+ self ,
574
+ generator : AsyncGenerator ,
575
+ context : contextvars .Context ,
576
+ span : Union [
577
+ LangfuseSpan ,
578
+ LangfuseGeneration ,
579
+ LangfuseAgent ,
580
+ LangfuseTool ,
581
+ LangfuseChain ,
582
+ LangfuseRetriever ,
583
+ LangfuseEvaluator ,
584
+ LangfuseEmbedding ,
585
+ LangfuseGuardrail ,
586
+ ],
587
+ transform_fn : Optional [Callable [[Iterable ], str ]],
588
+ ) -> None :
589
+ self .generator = generator
590
+ self .context = context
591
+ self .items : List [Any ] = []
592
+ self .span = span
593
+ self .transform_fn = transform_fn
594
+
595
+ def __aiter__ (self ) -> "_ContextPreservedAsyncGeneratorWrapper" :
596
+ return self
597
+
598
+ async def __anext__ (self ) -> Any :
509
599
try :
510
- async for item in generator :
511
- items .append (item )
600
+ # Run the generator's __anext__ in the preserved context
601
+ try :
602
+ # Python 3.10+ approach with context parameter
603
+ item = await asyncio .create_task (
604
+ self .generator .__anext__ (), # type: ignore
605
+ context = self .context ,
606
+ ) # type: ignore
607
+ except TypeError :
608
+ # Python < 3.10 fallback - context parameter not supported
609
+ item = await self .generator .__anext__ ()
512
610
513
- yield item
611
+ self . items . append ( item )
514
612
515
- finally :
516
- output : Any = items
613
+ return item
517
614
518
- if transform_to_string is not None :
519
- output = transform_to_string (items )
615
+ except StopAsyncIteration :
616
+ # Handle output and span cleanup when generator is exhausted
617
+ output : Any = self .items
520
618
521
- elif all ( isinstance ( item , str ) for item in items ) :
522
- output = "" . join ( items )
619
+ if self . transform_fn is not None :
620
+ output = self . transform_fn ( self . items )
523
621
524
- langfuse_span_or_generation . update ( output = output )
525
- langfuse_span_or_generation . end ( )
622
+ elif all ( isinstance ( item , str ) for item in self . items ):
623
+ output = "" . join ( self . items )
526
624
625
+ self .span .update (output = output ).end ()
527
626
528
- _decorator = LangfuseDecorator ()
627
+ raise # Re-raise StopAsyncIteration
628
+ except Exception as e :
629
+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
529
630
530
- observe = _decorator . observe
631
+ raise
0 commit comments