diff --git a/src/openlayer/lib/integrations/langchain_callback.py b/src/openlayer/lib/integrations/langchain_callback.py index f907beb1..10ff3982 100644 --- a/src/openlayer/lib/integrations/langchain_callback.py +++ b/src/openlayer/lib/integrations/langchain_callback.py @@ -2,20 +2,21 @@ # pylint: disable=unused-argument import time -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING, Callable +from typing import Any, Dict, List, Optional, Union, Callable from uuid import UUID try: - from langchain import schema as langchain_schema - from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler + try: + from langchain_core import messages as langchain_schema + from langchain_core.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler + except ImportError: + from langchain import schema as langchain_schema + from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler HAVE_LANGCHAIN = True except ImportError: HAVE_LANGCHAIN = False -if TYPE_CHECKING: - from langchain import schema as langchain_schema - from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler from ..tracing import tracer, steps, traces, enums from .. import utils @@ -50,6 +51,8 @@ def __init__(self, **kwargs: Any) -> None: self.metadata: Dict[str, Any] = kwargs or {} self.steps: Dict[UUID, steps.Step] = {} self.root_steps: set[UUID] = set() # Track which steps are root + # Track standalone traces (consistent with async handler) + self._traces_by_root: Dict[UUID, traces.Trace] = {} # Extract inference_id from kwargs if provided self._inference_id = kwargs.get("inference_id") # Extract metadata_transformer from kwargs if provided @@ -94,17 +97,17 @@ def _start_step( current_trace = tracer.get_current_trace() if current_step is not None: - # We're inside a @trace() decorated function - add as nested step + # We're inside an existing step context - add as nested current_step.add_nested_step(step) elif current_trace is not None: - # There's an existing trace but no current step + # Existing trace but no current step - add to trace current_trace.add_step(step) + # Don't track in _traces_by_root since we're using external trace else: - # No existing trace - create new one (standalone mode) - current_trace = traces.Trace() - tracer._current_trace.set(current_trace) - tracer._rag_context.set(None) - current_trace.add_step(step) + # No existing context - create standalone trace + trace = traces.Trace() + trace.add_step(step) + self._traces_by_root[run_id] = trace # Track root steps (those without parent_run_id) if parent_run_id is None: @@ -151,23 +154,22 @@ def _end_step( if hasattr(step, key): setattr(step, key, value) - # Only upload trace if this was a root step and we're not in a @trace() context - if is_root_step and tracer.get_current_step() is None: - self._process_and_upload_trace(step) + # Only upload if this is a standalone trace (not integrated with external trace) + # If current_step is set, we're part of a larger trace and shouldn't upload + if is_root_step and run_id in self._traces_by_root and tracer.get_current_step() is None: + trace = self._traces_by_root.pop(run_id) + self._process_and_upload_trace(trace) - def _process_and_upload_trace(self, root_step: steps.Step) -> None: + def _process_and_upload_trace(self, trace: traces.Trace) -> None: """Process and upload the completed trace (only for standalone root steps).""" - current_trace = tracer.get_current_trace() - if not current_trace: + if not trace: return # Convert all LangChain objects in the trace once at the end - self._convert_step_objects_recursively(root_step) - for step in current_trace.steps: - if step != root_step: # Avoid converting root_step twice - self._convert_step_objects_recursively(step) + for step in trace.steps: + self._convert_step_objects_recursively(step) - trace_data, input_variable_names = tracer.post_process_trace(current_trace) + trace_data, input_variable_names = tracer.post_process_trace(trace) config = dict( tracer.ConfigLlmData( @@ -1043,6 +1045,10 @@ def __init__( self._ignore_agent = ignore_agent # For async: manage our own trace mapping since context vars are unreliable self._traces_by_root: Dict[UUID, traces.Trace] = {} + # Detect if an external trace context exists at initialization time + # If true, we'll create standalone traces for external system integration + # instead of uploading them independently + self._has_external_trace: bool = tracer.get_current_trace() is not None @property def ignore_llm(self) -> bool: @@ -1098,15 +1104,37 @@ def _start_step( parent_step = self.steps[parent_run_id] parent_step.add_nested_step(step) else: - # This is a root step - create a new trace - trace = traces.Trace() - trace.add_step(step) - self._traces_by_root[run_id] = trace - self.root_steps.add(run_id) - - # Override step ID with custom inference_id if provided - if self._inference_id is not None: - step.id = self._inference_id + # Check if we're in an existing trace context via ContextVars + current_step = tracer.get_current_step() + current_trace = tracer.get_current_trace() + + if current_step is not None: + # We're inside an existing step context - add as nested + current_step.add_nested_step(step) + elif current_trace is not None: + # Have trace but no current step + # If it's an external trace, we should NOT add at root - external system will integrate + # If it's a ContextVar trace with no current step, add to trace + if not self._has_external_trace: + # ContextVar-detected trace - add directly + current_trace.add_step(step) + else: + # External trace without current step - create temp standalone for later integration + trace = traces.Trace() + trace.add_step(step) + self._traces_by_root[run_id] = trace + else: + # No existing context - create standalone trace + trace = traces.Trace() + trace.add_step(step) + self._traces_by_root[run_id] = trace + + # Track root steps + if parent_run_id is None: + self.root_steps.add(run_id) + # Override step ID with custom inference_id if provided + if self._inference_id is not None: + step.id = self._inference_id self.steps[run_id] = step return step @@ -1146,8 +1174,11 @@ def _end_step( if hasattr(step, key): setattr(step, key, value) - # If this is a root step, process and upload the trace - if is_root_step and run_id in self._traces_by_root: + # Only upload if this is a standalone trace (not integrated with external trace) + has_standalone_trace = run_id in self._traces_by_root + + # Only upload if: root step + has standalone trace + not part of external trace + if is_root_step and has_standalone_trace and not self._has_external_trace: trace = self._traces_by_root.pop(run_id) self._process_and_upload_async_trace(trace)