Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 66 additions & 35 deletions src/openlayer/lib/integrations/langchain_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

# pylint: disable=unused-argument
import time
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING, Callable
from typing import Any, Dict, List, Optional, Union, Callable
from uuid import UUID

try:
from langchain import schema as langchain_schema
from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler
try:
from langchain_core import messages as langchain_schema
from langchain_core.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler
except ImportError:
from langchain import schema as langchain_schema
from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler

HAVE_LANGCHAIN = True
except ImportError:
HAVE_LANGCHAIN = False

if TYPE_CHECKING:
from langchain import schema as langchain_schema
from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler

from ..tracing import tracer, steps, traces, enums
from .. import utils
Expand Down Expand Up @@ -50,6 +51,8 @@ def __init__(self, **kwargs: Any) -> None:
self.metadata: Dict[str, Any] = kwargs or {}
self.steps: Dict[UUID, steps.Step] = {}
self.root_steps: set[UUID] = set() # Track which steps are root
# Track standalone traces (consistent with async handler)
self._traces_by_root: Dict[UUID, traces.Trace] = {}
# Extract inference_id from kwargs if provided
self._inference_id = kwargs.get("inference_id")
# Extract metadata_transformer from kwargs if provided
Expand Down Expand Up @@ -94,17 +97,17 @@ def _start_step(
current_trace = tracer.get_current_trace()

if current_step is not None:
# We're inside a @trace() decorated function - add as nested step
# We're inside an existing step context - add as nested
current_step.add_nested_step(step)
elif current_trace is not None:
# There's an existing trace but no current step
# Existing trace but no current step - add to trace
current_trace.add_step(step)
# Don't track in _traces_by_root since we're using external trace
else:
# No existing trace - create new one (standalone mode)
current_trace = traces.Trace()
tracer._current_trace.set(current_trace)
tracer._rag_context.set(None)
current_trace.add_step(step)
# No existing context - create standalone trace
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace

# Track root steps (those without parent_run_id)
if parent_run_id is None:
Expand Down Expand Up @@ -151,23 +154,22 @@ def _end_step(
if hasattr(step, key):
setattr(step, key, value)

# Only upload trace if this was a root step and we're not in a @trace() context
if is_root_step and tracer.get_current_step() is None:
self._process_and_upload_trace(step)
# Only upload if this is a standalone trace (not integrated with external trace)
# If current_step is set, we're part of a larger trace and shouldn't upload
if is_root_step and run_id in self._traces_by_root and tracer.get_current_step() is None:
trace = self._traces_by_root.pop(run_id)
self._process_and_upload_trace(trace)

def _process_and_upload_trace(self, root_step: steps.Step) -> None:
def _process_and_upload_trace(self, trace: traces.Trace) -> None:
"""Process and upload the completed trace (only for standalone root steps)."""
current_trace = tracer.get_current_trace()
if not current_trace:
if not trace:
return

# Convert all LangChain objects in the trace once at the end
self._convert_step_objects_recursively(root_step)
for step in current_trace.steps:
if step != root_step: # Avoid converting root_step twice
self._convert_step_objects_recursively(step)
for step in trace.steps:
self._convert_step_objects_recursively(step)

trace_data, input_variable_names = tracer.post_process_trace(current_trace)
trace_data, input_variable_names = tracer.post_process_trace(trace)

config = dict(
tracer.ConfigLlmData(
Expand Down Expand Up @@ -1043,6 +1045,10 @@ def __init__(
self._ignore_agent = ignore_agent
# For async: manage our own trace mapping since context vars are unreliable
self._traces_by_root: Dict[UUID, traces.Trace] = {}
# Detect if an external trace context exists at initialization time
# If true, we'll create standalone traces for external system integration
# instead of uploading them independently
self._has_external_trace: bool = tracer.get_current_trace() is not None

@property
def ignore_llm(self) -> bool:
Expand Down Expand Up @@ -1098,15 +1104,37 @@ def _start_step(
parent_step = self.steps[parent_run_id]
parent_step.add_nested_step(step)
else:
# This is a root step - create a new trace
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace
self.root_steps.add(run_id)

# Override step ID with custom inference_id if provided
if self._inference_id is not None:
step.id = self._inference_id
# Check if we're in an existing trace context via ContextVars
current_step = tracer.get_current_step()
current_trace = tracer.get_current_trace()

if current_step is not None:
# We're inside an existing step context - add as nested
current_step.add_nested_step(step)
elif current_trace is not None:
# Have trace but no current step
# If it's an external trace, we should NOT add at root - external system will integrate
# If it's a ContextVar trace with no current step, add to trace
if not self._has_external_trace:
# ContextVar-detected trace - add directly
current_trace.add_step(step)
else:
# External trace without current step - create temp standalone for later integration
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace
else:
# No existing context - create standalone trace
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace

# Track root steps
if parent_run_id is None:
self.root_steps.add(run_id)
# Override step ID with custom inference_id if provided
if self._inference_id is not None:
step.id = self._inference_id

self.steps[run_id] = step
return step
Expand Down Expand Up @@ -1146,8 +1174,11 @@ def _end_step(
if hasattr(step, key):
setattr(step, key, value)

# If this is a root step, process and upload the trace
if is_root_step and run_id in self._traces_by_root:
# Only upload if this is a standalone trace (not integrated with external trace)
has_standalone_trace = run_id in self._traces_by_root

# Only upload if: root step + has standalone trace + not part of external trace
if is_root_step and has_standalone_trace and not self._has_external_trace:
trace = self._traces_by_root.pop(run_id)
self._process_and_upload_async_trace(trace)

Expand Down