Skip to content

Commit d02d940

Browse files
authored
fix(observe): handle generator context propagation (#1383)
1 parent 3ce7abe commit d02d940

File tree

2 files changed

+423
-41
lines changed

2 files changed

+423
-41
lines changed

langfuse/_client/observe.py

Lines changed: 142 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import inspect
34
import logging
45
import os
@@ -10,6 +11,7 @@
1011
Dict,
1112
Generator,
1213
Iterable,
14+
List,
1315
Optional,
1416
Tuple,
1517
TypeVar,
@@ -21,25 +23,24 @@
2123
from opentelemetry.util._decorator import _AgnosticContextManager
2224
from typing_extensions import ParamSpec
2325

24-
from langfuse._client.environment_variables import (
25-
LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
26-
)
27-
2826
from langfuse._client.constants import (
2927
ObservationTypeLiteralNoEvent,
3028
get_observation_types_list,
3129
)
30+
from langfuse._client.environment_variables import (
31+
LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
32+
)
3233
from langfuse._client.get_client import _set_current_public_key, get_client
3334
from langfuse._client.span import (
34-
LangfuseGeneration,
35-
LangfuseSpan,
3635
LangfuseAgent,
37-
LangfuseTool,
3836
LangfuseChain,
39-
LangfuseRetriever,
40-
LangfuseEvaluator,
4137
LangfuseEmbedding,
38+
LangfuseEvaluator,
39+
LangfuseGeneration,
4240
LangfuseGuardrail,
41+
LangfuseRetriever,
42+
LangfuseSpan,
43+
LangfuseTool,
4344
)
4445
from langfuse.types import TraceContext
4546

@@ -468,29 +469,54 @@ def _wrap_sync_generator_result(
468469
generator: Generator,
469470
transform_to_string: Optional[Callable[[Iterable], str]] = None,
470471
) -> Any:
471-
items = []
472+
preserved_context = contextvars.copy_context()
472473

473-
try:
474-
for item in generator:
475-
items.append(item)
474+
return _ContextPreservedSyncGeneratorWrapper(
475+
generator,
476+
preserved_context,
477+
langfuse_span_or_generation,
478+
transform_to_string,
479+
)
480+
481+
def _wrap_async_generator_result(
482+
self,
483+
langfuse_span_or_generation: Union[
484+
LangfuseSpan,
485+
LangfuseGeneration,
486+
LangfuseAgent,
487+
LangfuseTool,
488+
LangfuseChain,
489+
LangfuseRetriever,
490+
LangfuseEvaluator,
491+
LangfuseEmbedding,
492+
LangfuseGuardrail,
493+
],
494+
generator: AsyncGenerator,
495+
transform_to_string: Optional[Callable[[Iterable], str]] = None,
496+
) -> Any:
497+
preserved_context = contextvars.copy_context()
476498

477-
yield item
499+
return _ContextPreservedAsyncGeneratorWrapper(
500+
generator,
501+
preserved_context,
502+
langfuse_span_or_generation,
503+
transform_to_string,
504+
)
478505

479-
finally:
480-
output: Any = items
481506

482-
if transform_to_string is not None:
483-
output = transform_to_string(items)
507+
_decorator = LangfuseDecorator()
508+
509+
observe = _decorator.observe
484510

485-
elif all(isinstance(item, str) for item in items):
486-
output = "".join(items)
487511

488-
langfuse_span_or_generation.update(output=output)
489-
langfuse_span_or_generation.end()
512+
class _ContextPreservedSyncGeneratorWrapper:
513+
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
490514

491-
async def _wrap_async_generator_result(
515+
def __init__(
492516
self,
493-
langfuse_span_or_generation: Union[
517+
generator: Generator,
518+
context: contextvars.Context,
519+
span: Union[
494520
LangfuseSpan,
495521
LangfuseGeneration,
496522
LangfuseAgent,
@@ -501,30 +527,105 @@ async def _wrap_async_generator_result(
501527
LangfuseEmbedding,
502528
LangfuseGuardrail,
503529
],
504-
generator: AsyncGenerator,
505-
transform_to_string: Optional[Callable[[Iterable], str]] = None,
506-
) -> AsyncGenerator:
507-
items = []
530+
transform_fn: Optional[Callable[[Iterable], str]],
531+
) -> None:
532+
self.generator = generator
533+
self.context = context
534+
self.items: List[Any] = []
535+
self.span = span
536+
self.transform_fn = transform_fn
537+
538+
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
539+
return self
540+
541+
def __next__(self) -> Any:
542+
try:
543+
# Run the generator's __next__ in the preserved context
544+
item = self.context.run(next, self.generator)
545+
self.items.append(item)
546+
547+
return item
548+
549+
except StopIteration:
550+
# Handle output and span cleanup when generator is exhausted
551+
output: Any = self.items
552+
553+
if self.transform_fn is not None:
554+
output = self.transform_fn(self.items)
555+
556+
elif all(isinstance(item, str) for item in self.items):
557+
output = "".join(self.items)
558+
559+
self.span.update(output=output).end()
560+
561+
raise # Re-raise StopIteration
562+
563+
except Exception as e:
564+
self.span.update(level="ERROR", status_message=str(e)).end()
508565

566+
raise
567+
568+
569+
class _ContextPreservedAsyncGeneratorWrapper:
570+
"""Async generator wrapper that ensures each iteration runs in preserved context."""
571+
572+
def __init__(
573+
self,
574+
generator: AsyncGenerator,
575+
context: contextvars.Context,
576+
span: Union[
577+
LangfuseSpan,
578+
LangfuseGeneration,
579+
LangfuseAgent,
580+
LangfuseTool,
581+
LangfuseChain,
582+
LangfuseRetriever,
583+
LangfuseEvaluator,
584+
LangfuseEmbedding,
585+
LangfuseGuardrail,
586+
],
587+
transform_fn: Optional[Callable[[Iterable], str]],
588+
) -> None:
589+
self.generator = generator
590+
self.context = context
591+
self.items: List[Any] = []
592+
self.span = span
593+
self.transform_fn = transform_fn
594+
595+
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
596+
return self
597+
598+
async def __anext__(self) -> Any:
509599
try:
510-
async for item in generator:
511-
items.append(item)
600+
# Run the generator's __anext__ in the preserved context
601+
try:
602+
# Python 3.10+ approach with context parameter
603+
item = await asyncio.create_task(
604+
self.generator.__anext__(), # type: ignore
605+
context=self.context,
606+
) # type: ignore
607+
except TypeError:
608+
# Python < 3.10 fallback - context parameter not supported
609+
item = await self.generator.__anext__()
512610

513-
yield item
611+
self.items.append(item)
514612

515-
finally:
516-
output: Any = items
613+
return item
517614

518-
if transform_to_string is not None:
519-
output = transform_to_string(items)
615+
except StopAsyncIteration:
616+
# Handle output and span cleanup when generator is exhausted
617+
output: Any = self.items
520618

521-
elif all(isinstance(item, str) for item in items):
522-
output = "".join(items)
619+
if self.transform_fn is not None:
620+
output = self.transform_fn(self.items)
523621

524-
langfuse_span_or_generation.update(output=output)
525-
langfuse_span_or_generation.end()
622+
elif all(isinstance(item, str) for item in self.items):
623+
output = "".join(self.items)
526624

625+
self.span.update(output=output).end()
527626

528-
_decorator = LangfuseDecorator()
627+
raise # Re-raise StopAsyncIteration
628+
except Exception as e:
629+
self.span.update(level="ERROR", status_message=str(e)).end()
529630

530-
observe = _decorator.observe
631+
raise

0 commit comments

Comments
 (0)