diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index e58a1aba5..b1163081b 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -66,6 +66,8 @@ def _infer_model_name(llm: BaseLanguageModel):
 async def llm_call(
     llm: BaseLanguageModel,
     prompt: Union[str, List[dict]],
+    model_name: Optional[str] = None,
+    model_provider: Optional[str] = None,
     stop: Optional[List[str]] = None,
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
 ) -> str:
@@ -76,7 +78,8 @@ async def llm_call(
     llm_call_info = LLMCallInfo()
     llm_call_info_var.set(llm_call_info)

-    llm_call_info.llm_model_name = _infer_model_name(llm)
+    llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
+    llm_call_info.llm_provider_name = model_provider

     if custom_callback_handlers and custom_callback_handlers != [None]:
         all_callbacks = BaseCallbackManager(
@@ -172,15 +175,15 @@ def get_colang_history(
             history += f'user "{event["text"]}"\n'
         elif event["type"] == "UserIntent":
             if include_texts:
-                history += f'  {event["intent"]}\n'
+                history += f"  {event['intent']}\n"
             else:
-                history += f'user {event["intent"]}\n'
+                history += f"user {event['intent']}\n"
         elif event["type"] == "BotIntent":
             # If we have instructions, we add them before the bot message.
             # But we only do that for the last bot message.
             if "instructions" in event and idx == last_bot_intent_idx:
                 history += f"# {event['instructions']}\n"
-            history += f'bot {event["intent"]}\n'
+            history += f"bot {event['intent']}\n"
         elif event["type"] == "StartUtteranceBotAction" and include_texts:
             history += f'  "{event["script"]}"\n'
         # We skip system actions from this log
@@ -349,9 +352,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
         if "_type" not in element:
             raise Exception("bla")

         if element["_type"] == "UserIntent":
-            colang_flow += f'user {element["intent_name"]}\n'
+            colang_flow += f"user {element['intent_name']}\n"
         elif element["_type"] == "run_action" and element["action_name"] == "utter":
-            colang_flow += f'bot {element["action_params"]["value"]}\n'
+            colang_flow += f"bot {element['action_params']['value']}\n"

     return colang_flow
diff --git a/nemoguardrails/logging/explain.py b/nemoguardrails/logging/explain.py
index f6e3b5bc0..d9c282d15 100644
--- a/nemoguardrails/logging/explain.py
+++ b/nemoguardrails/logging/explain.py
@@ -59,6 +59,10 @@ class LLMCallInfo(LLMCallSummary):
         default="unknown",
         description="The name of the model use for the LLM call.",
     )
+    llm_provider_name: Optional[str] = Field(
+        default="unknown",
+        description="The provider of the model used for the LLM call, e.g. 'openai', 'nvidia'.",
+    )


 class ExplainInfo(BaseModel):
@@ -100,7 +104,7 @@ def print_llm_calls_summary(self):
         for i in range(len(self.llm_calls)):
             llm_call = self.llm_calls[i]
             msg = (
-                f"{i+1}. Task `{llm_call.task}` took {llm_call.duration:.2f} seconds "
+                f"{i + 1}. Task `{llm_call.task}` took {llm_call.duration:.2f} seconds "
                 + (
                     f"and used {llm_call.total_tokens} tokens."
                     if total_tokens
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index ffdd10220..c43711fd2 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -364,6 +364,18 @@ class TracingConfig(BaseModel):
         default_factory=lambda: [LogAdapterConfig()],
         description="The list of tracing adapters to use. If not specified, the default adapters are used.",
     )
+    span_format: str = Field(
+        default="opentelemetry",
+        description="The span format to use. Options are 'legacy' (simple metrics) or 'opentelemetry' (OpenTelemetry semantic conventions).",
+    )
+    enable_content_capture: bool = Field(
+        default=False,
+        description=(
+            "Capture prompts and responses (user/assistant/tool message content) in tracing/telemetry events. "
+            "Disabled by default for privacy and alignment with OpenTelemetry GenAI semantic conventions. "
+            "WARNING: Enabling this may include PII and sensitive data in your telemetry backend."
+        ),
+    )


 class EmbeddingsCacheConfig(BaseModel):
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 98ccd2dea..97f2d33c6 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -240,6 +240,8 @@ def __init__(
             from nemoguardrails.tracing import create_log_adapters

             self._log_adapters = create_log_adapters(config.tracing)
+        else:
+            self._log_adapters = None

         # We run some additional checks on the config
         self._validate_config()
@@ -1149,9 +1151,19 @@ async def generate_async(
             # lazy import to avoid circular dependency
             from nemoguardrails.tracing import Tracer

-            # Create a Tracer instance with instantiated adapters
+            span_format = getattr(
+                self.config.tracing, "span_format", "opentelemetry"
+            )
+            enable_content_capture = getattr(
+                self.config.tracing, "enable_content_capture", False
+            )
+            # Create a Tracer instance with instantiated adapters and span configuration
             tracer = Tracer(
-                input=messages, response=res, adapters=self._log_adapters
+                input=messages,
+                response=res,
+                adapters=self._log_adapters,
+                span_format=span_format,
+                enable_content_capture=enable_content_capture,
             )

             await tracer.export_async()
diff --git a/nemoguardrails/tracing/__init__.py b/nemoguardrails/tracing/__init__.py
index d99d29e56..69492c40d 100644
--- a/nemoguardrails/tracing/__init__.py
+++ b/nemoguardrails/tracing/__init__.py
@@ -13,4 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .tracer import InteractionLog, Tracer, create_log_adapters
+from .interaction_types import InteractionLog, InteractionOutput
+from .span_extractors import (
+    SpanExtractor,
+    SpanExtractorV1,
+    SpanExtractorV2,
+    create_span_extractor,
+)
+from .spans import SpanEvent, SpanLegacy, SpanOpentelemetry
+from .tracer import Tracer, create_log_adapters
+
+__all__ = [
+    "SpanExtractor",
+    "SpanExtractorV1",
+    "SpanExtractorV2",
+    "create_span_extractor",
+    "Tracer",
+    "create_log_adapters",
+    "SpanEvent",
+    "SpanLegacy",
+    "SpanOpentelemetry",
+]
diff --git a/nemoguardrails/tracing/adapters/base.py b/nemoguardrails/tracing/adapters/base.py
index 6c355b0f3..5b4a2ad04 100644
--- a/nemoguardrails/tracing/adapters/base.py
+++ b/nemoguardrails/tracing/adapters/base.py
@@ -16,7 +16,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional

-from nemoguardrails.eval.models import InteractionLog
+from nemoguardrails.tracing.interaction_types import InteractionLog


 class InteractionLogAdapter(ABC):
diff --git a/nemoguardrails/tracing/adapters/filesystem.py b/nemoguardrails/tracing/adapters/filesystem.py
index 3e99398b8..bd6c967e1 100644
--- a/nemoguardrails/tracing/adapters/filesystem.py
+++ b/nemoguardrails/tracing/adapters/filesystem.py
@@ -24,6 +24,10 @@

 from nemoguardrails.tracing import InteractionLog
 from nemoguardrails.tracing.adapters.base import InteractionLogAdapter
+from nemoguardrails.tracing.span_formatting import (
+    format_span_for_filesystem,
+    get_schema_version_for_filesystem,
+)


 class FileSystemAdapter(InteractionLogAdapter):
@@ -38,56 +42,46 @@ def __init__(self, filepath: Optional[str] = None):

     def transform(self, interaction_log: "InteractionLog"):
         """Transforms the InteractionLog into a JSON string."""
-        spans = []
-
-        for span_data in interaction_log.trace:
-            span_dict = {
-                "name": span_data.name,
-                "span_id": span_data.span_id,
-                "parent_id": span_data.parent_id,
-                "trace_id": interaction_log.id,
-                "start_time": span_data.start_time,
-                "end_time": span_data.end_time,
-                "duration": span_data.duration,
-                "metrics": span_data.metrics,
-            }
-            spans.append(span_dict)
+        spans = [
+            format_span_for_filesystem(span_data) for span_data in interaction_log.trace
+        ]
+
+        if not interaction_log.trace:
+            schema_version = None
+        else:
+            schema_version = get_schema_version_for_filesystem(interaction_log.trace[0])

         log_dict = {
+            "schema_version": schema_version,
             "trace_id": interaction_log.id,
             "spans": spans,
         }

-        with open(self.filepath, "a") as f:
-            f.write(json.dumps(log_dict, indent=2) + "\n")
+        with open(self.filepath, "a", encoding="utf-8") as f:
+            f.write(json.dumps(log_dict) + "\n")

     async def transform_async(self, interaction_log: "InteractionLog"):
         try:
             import aiofiles
         except ImportError:
             raise ImportError(
-                "aiofiles is required for async file writing. Please install it using `pip install aiofiles"
+                "aiofiles is required for async file writing. Please install it using `pip install aiofiles`"
             )

-        spans = []
-
-        for span_data in interaction_log.trace:
-            span_dict = {
-                "name": span_data.name,
-                "span_id": span_data.span_id,
-                "parent_id": span_data.parent_id,
-                "trace_id": interaction_log.id,
-                "start_time": span_data.start_time,
-                "end_time": span_data.end_time,
-                "duration": span_data.duration,
-                "metrics": span_data.metrics,
-            }
-            spans.append(span_dict)
+        spans = [
+            format_span_for_filesystem(span_data) for span_data in interaction_log.trace
+        ]
+
+        if not interaction_log.trace:
+            schema_version = None
+        else:
+            schema_version = get_schema_version_for_filesystem(interaction_log.trace[0])

         log_dict = {
+            "schema_version": schema_version,
             "trace_id": interaction_log.id,
             "spans": spans,
         }

-        async with aiofiles.open(self.filepath, "a") as f:
-            await f.write(json.dumps(log_dict, indent=2) + "\n")
+        async with aiofiles.open(self.filepath, "a", encoding="utf-8") as f:
+            await f.write(json.dumps(log_dict) + "\n")
diff --git a/nemoguardrails/tracing/adapters/opentelemetry.py b/nemoguardrails/tracing/adapters/opentelemetry.py
index 6044b3cfe..00456954c 100644
--- a/nemoguardrails/tracing/adapters/opentelemetry.py
+++ b/nemoguardrails/tracing/adapters/opentelemetry.py
@@ -55,13 +55,13 @@

 import warnings
 from importlib.metadata import version
-from typing import TYPE_CHECKING, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict

 if TYPE_CHECKING:
     from nemoguardrails.tracing import InteractionLog

 try:
-    from opentelemetry import trace
-    from opentelemetry.trace import NoOpTracerProvider
+    from opentelemetry import trace  # type: ignore
+    from opentelemetry.trace import NoOpTracerProvider  # type: ignore
 except ImportError:
     raise ImportError(
@@ -70,34 +70,7 @@
     )

 from nemoguardrails.tracing.adapters.base import InteractionLogAdapter
-
-# DEPRECATED: global dictionary to store registered exporters
-# will be removed in v0.16.0
-_exporter_name_cls_map: dict[str, Type] = {}
-
-
-def register_otel_exporter(name: str, exporter_cls: Type):
-    """Register a new exporter.
-
-    Args:
-        name: The name to register the exporter under.
-        exporter_cls: The exporter class to register.
-
-    Deprecated:
-        This function is deprecated and will be removed in version 0.16.0.
-        Please configure OpenTelemetry exporters directly in your application code.
-        See the migration guide at:
-        https://github.com/NVIDIA/NeMo-Guardrails/blob/main/examples/configs/tracing/README.md#migration-guide
-    """
-    warnings.warn(
-        "register_otel_exporter is deprecated and will be removed in version 0.16.0. "
-        "Please configure OpenTelemetry exporters directly in your application code. "
-        "See the migration guide at: "
-        "https://github.com/NVIDIA/NeMo-Guardrails/blob/develop/examples/configs/tracing/README.md#migration-guide",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    _exporter_name_cls_map[name] = exporter_cls
+from nemoguardrails.tracing.span_formatting import extract_span_attributes


 class OpenTelemetryAdapter(InteractionLogAdapter):
@@ -114,40 +87,20 @@ class OpenTelemetryAdapter(InteractionLogAdapter):
     def __init__(
         self,
         service_name: str = "nemo_guardrails",
-        **kwargs,
     ):
         """
         Initialize the OpenTelemetry adapter.

         Args:
             service_name: Service name for instrumentation scope (not used for resource)
-            **kwargs: Additional arguments (for backward compatibility)

         Note:
             Applications must configure the OpenTelemetry SDK before using this adapter.
             The adapter will use the globally configured tracer provider.
""" - # check for deprecated parameters and warn users - deprecated_params = [ - "exporter", - "exporter_cls", - "resource_attributes", - "span_processor", - ] - used_deprecated = [param for param in deprecated_params if param in kwargs] - - if used_deprecated: - warnings.warn( - f"OpenTelemetry configuration parameters {used_deprecated} in YAML/config are deprecated " - "and will be ignored. Please configure OpenTelemetry in your application code. " - "See the migration guide at: " - "https://github.com/NVIDIA/NeMo-Guardrails/blob/main/examples/configs/tracing/README.md#migration-guide", - DeprecationWarning, - stacklevel=2, - ) # validate that OpenTelemetry is properly configured - provider = trace.get_tracer_provider() + provider = trace.get_tracer_provider() # type: ignore if provider is None or isinstance(provider, NoOpTracerProvider): warnings.warn( "No OpenTelemetry TracerProvider configured. Traces will not be exported. " @@ -158,7 +111,7 @@ def __init__( stacklevel=2, ) - self.tracer = trace.get_tracer( + self.tracer = trace.get_tracer( # type: ignore service_name, instrumenting_library_version=version("nemoguardrails"), schema_url="https://opentelemetry.io/schemas/1.26.0", @@ -166,10 +119,16 @@ def __init__( def transform(self, interaction_log: "InteractionLog"): """Transforms the InteractionLog into OpenTelemetry spans.""" - spans = {} + # get the actual interaction start time from the first rail + # all span times are relative offsets from this timestamp + base_time_ns = _get_base_time_ns(interaction_log) + + spans: Dict[str, Any] = {} for span_data in interaction_log.trace: - parent_span = spans.get(span_data.parent_id) + parent_span = ( + spans.get(span_data.parent_id) if span_data.parent_id else None + ) parent_context = ( trace.set_span_in_context(parent_span) if parent_span else None ) @@ -178,14 +137,21 @@ def transform(self, interaction_log: "InteractionLog"): span_data, parent_context, spans, - interaction_log.id, # trace_id + base_time_ns, ) async def transform_async(self, interaction_log: "InteractionLog"): """Transforms the InteractionLog into OpenTelemetry spans asynchronously.""" - spans = {} + # get the actual interaction start time from the first rail + # all span times are relative offsets from this timestamp + base_time_ns = _get_base_time_ns(interaction_log) + + spans: Dict[str, Any] = {} + for span_data in interaction_log.trace: - parent_span = spans.get(span_data.parent_id) + parent_span = ( + spans.get(span_data.parent_id) if span_data.parent_id else None + ) parent_context = ( trace.set_span_in_context(parent_span) if parent_span else None ) @@ -193,7 +159,7 @@ async def transform_async(self, interaction_log: "InteractionLog"): span_data, parent_context, spans, - interaction_log.id, # trace_id + base_time_ns, ) def _create_span( @@ -201,19 +167,91 @@ def _create_span( span_data, parent_context, spans, - trace_id, + base_time_ns, ): - with self.tracer.start_as_current_span( + """Create OTel span from a span. + + This is a pure API bridge - all semantic attributes are extracted + by the formatting function. We only handle: + 1. Timestamp conversion (relative to absolute) + 2. Span kind mapping (string to enum) + 3. 
API calls to create spans and events + """ + # convert relative times to absolute timestamps + # the span times are relative offsets from the start of the trace + # base_time_ns represents the start time of the trace + # we simply add the relative offsets to get absolute times + relative_start_ns = int(span_data.start_time * 1_000_000_000) + relative_end_ns = int(span_data.end_time * 1_000_000_000) + + start_time_ns = base_time_ns + relative_start_ns + end_time_ns = base_time_ns + relative_end_ns + + attributes = extract_span_attributes(span_data) + + from opentelemetry.trace import SpanKind as OTelSpanKind + + span_kind_map = { + "server": OTelSpanKind.SERVER, + "client": OTelSpanKind.CLIENT, + "internal": OTelSpanKind.INTERNAL, + } + + span_kind_str = attributes.get("span.kind", "internal") + otel_span_kind = span_kind_map.get(span_kind_str, OTelSpanKind.INTERNAL) + + span = self.tracer.start_span( span_data.name, context=parent_context, - ) as span: - for key, value in span_data.metrics.items(): + start_time=start_time_ns, + kind=otel_span_kind, + ) + + if attributes: + for key, value in attributes.items(): + if key == "span.kind": + continue span.set_attribute(key, value) - span.set_attribute("span_id", span_data.span_id) - span.set_attribute("trace_id", trace_id) - span.set_attribute("start_time", span_data.start_time) - span.set_attribute("end_time", span_data.end_time) - span.set_attribute("duration", span_data.duration) + if hasattr(span_data, "events") and span_data.events: + for event in span_data.events: + relative_event_ns = int(event.timestamp * 1_000_000_000) + event_time_ns = base_time_ns + relative_event_ns + + event_attrs = event.attributes.copy() if event.attributes else {} + + if event.body and isinstance(event.body, dict): + # merge body content into attributes for OTel compatibility + # (OTel events don't have separate body, just attributes) + for body_key, body_value in event.body.items(): + if body_key not in event_attrs: + event_attrs[body_key] = body_value + + span.add_event( + name=event.name, attributes=event_attrs, timestamp=event_time_ns + ) + + spans[span_data.span_id] = span + + span.end(end_time=end_time_ns) + + +def _get_base_time_ns(interaction_log: InteractionLog) -> int: + """Get the base time in nanoseconds for tracing spans. + + Args: + interaction_log: The interaction log containing rail timing information + + Returns: + Base time in nanoseconds, either from the first activated rail or current time + """ + if ( + interaction_log.activated_rails + and interaction_log.activated_rails[0].started_at + ): + return int(interaction_log.activated_rails[0].started_at * 1_000_000_000) + else: + # This shouldn't happen in normal operation, but provide a fallback + import time - spans[span_data.span_id] = span + return time.time_ns() diff --git a/nemoguardrails/tracing/constants.py b/nemoguardrails/tracing/constants.py new file mode 100644 index 000000000..3e0bf3179 --- /dev/null +++ b/nemoguardrails/tracing/constants.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenTelemetry constants and semantic conventions for NeMo Guardrails.""" + + +class SpanKind: + """String constants for span kinds.""" + + SERVER = "server" + CLIENT = "client" + INTERNAL = "internal" + + +class SpanTypes: + """Internal span type identifiers used in span mapping. + + These are internal identifiers used to categorize spans before mapping + to actual span names. They represent the type of operation being traced. + + Note: 'llm_call' maps to various GenAI semantic convention span types + like inference (gen_ai.inference.client), embeddings, etc. + """ + + # NeMo Guardrails-specific internal types + INTERACTION = "interaction" # Entry point to guardrails + RAIL = "rail" # Rail execution + ACTION = "action" # Action execution + + # GenAI-related type (maps to official semantic conventions) + LLM_CALL = "llm_call" # maps to gen_ai.inference.client + + # NOTE: might use more specific types in the future + # could add more specific types that align with semantic conventions: + # INFERENCE = "inference" # for gen_ai.inference.client spans + # EMBEDDING = "embedding" # for gen_ai.embeddings.client spans + + +class SpanNamePatterns: + """Patterns used for identifying span types from span names.""" + + # patterns that indicate SERVER spans + INTERACTION = "interaction" + GUARDRAILS_REQUEST_PATTERN = "guardrails.request" + + # patterns that indicate CLIENT spans + GEN_AI_PREFIX = "gen_ai." + LLM = "llm" + COMPLETION = "completion" + + +class SystemConstants: + """System-level constants for NeMo Guardrails.""" + + SYSTEM_NAME = "nemo-guardrails" + UNKNOWN = "unknown" + + +class GenAIAttributes: + """GenAI semantic convention attributes following the draft specification. + + Note: These are based on the experimental OpenTelemetry GenAI semantic conventions + since they are not yet available in the stable semantic conventions package. 
+ + See: https://opentelemetry.io/docs/specs/semconv/gen-ai/ + """ + + GEN_AI_SYSTEM = "gen_ai.system" # @deprecated + + GEN_AI_PROVIDER_NAME = "gen_ai.provider.name" + GEN_AI_OPERATION_NAME = "gen_ai.operation.name" + + GEN_AI_REQUEST_MODEL = "gen_ai.request.model" + GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens" + GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature" + GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p" + GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k" + GEN_AI_REQUEST_FREQUENCY_PENALTY = "gen_ai.request.frequency_penalty" + GEN_AI_REQUEST_PRESENCE_PENALTY = "gen_ai.request.presence_penalty" + GEN_AI_REQUEST_STOP_SEQUENCES = "gen_ai.request.stop_sequences" + + GEN_AI_RESPONSE_MODEL = "gen_ai.response.model" + GEN_AI_RESPONSE_ID = "gen_ai.response.id" + GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons" + + GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + + +class CommonAttributes: + """Common OpenTelemetry attributes used across spans.""" + + SPAN_KIND = "span.kind" + + +class GuardrailsAttributes: + """NeMo Guardrails-specific attributes for spans.""" + + # rail attributes + RAIL_TYPE = "rail.type" + RAIL_NAME = "rail.name" + RAIL_STOP = "rail.stop" + RAIL_DECISIONS = "rail.decisions" + + # action attributes + ACTION_NAME = "action.name" + ACTION_HAS_LLM_CALLS = "action.has_llm_calls" + ACTION_LLM_CALLS_COUNT = "action.llm_calls_count" + ACTION_PARAM_PREFIX = "action.param." # For dynamic action parameters + + +class SpanNames: + """Standard span names following OpenTelemetry GenAI semantic conventions. + + Based on: https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/ + + IMPORTANT: Span names must be low cardinality to avoid performance issues. + Variable/high cardinality data (like specific rail types, model names, etc.) + should go in attributes instead of the span name. + """ + + # server spans (entry points); NeMo Guardrails specific + GUARDRAILS_REQUEST = "guardrails.request" # Entry point for guardrails processing + + # internal spans; NeMo Guardrails specific + GUARDRAILS_RAIL = "guardrails.rail" # Use attributes for rail type/name + GUARDRAILS_ACTION = "guardrails.action" # Use attributes for action name + + # client spans (LLM calls), following official GenAI semantic conventions + # "Span name SHOULD be `{gen_ai.operation.name} {gen_ai.request.model}`" + # since model names are high cardinality, we'll build these dynamically + # these are fallback operation names when model is unknown + GEN_AI_COMPLETION = "completion" + GEN_AI_CHAT = "chat" + GEN_AI_EMBEDDING = "embedding" + + +class OperationNames: + """Standard operation names for GenAI semantic conventions. + + Note: This only defines standard LLM operations. Custom actions and tasks + should be passed through as-is since they are dynamic and user-defined. + """ + + # standard LLM operations (from GenAI semantic conventions) + COMPLETION = "completion" + CHAT = "chat" + EMBEDDING = "embedding" + + # default operation for guardrails interactions + GUARDRAILS = "guardrails" + + +class EventNames: + """Standard event names for OpenTelemetry GenAI semantic conventions. 
+
+    Based on official spec at:
+    https://github.com/open-telemetry/semantic-conventions/blob/main/model/gen-ai/events.yaml
+    """
+
+    GEN_AI_SYSTEM_MESSAGE = "gen_ai.system.message"
+    GEN_AI_USER_MESSAGE = "gen_ai.user.message"
+    GEN_AI_ASSISTANT_MESSAGE = "gen_ai.assistant.message"
+    # GEN_AI_TOOL_MESSAGE = "gen_ai.tool.message"
+
+    GEN_AI_CHOICE = "gen_ai.choice"
+
+    GEN_AI_CONTENT_PROMPT = "gen_ai.content.prompt"  # @deprecated in the OTel spec in favor of GEN_AI_USER_MESSAGE; kept because we still emit text completions
+    GEN_AI_CONTENT_COMPLETION = "gen_ai.content.completion"  # @deprecated in the OTel spec in favor of GEN_AI_ASSISTANT_MESSAGE; kept because we still emit text completions
+
+
+class GuardrailsEventNames:
+    """NeMo Guardrails-specific event names (not OTel GenAI conventions).
+
+    These events represent internal guardrails state changes, not LLM API calls.
+    They use a guardrails-specific namespace to avoid confusion with OTel GenAI semantic conventions.
+    """
+
+    UTTERANCE_USER_FINISHED = "guardrails.utterance.user.finished"
+    UTTERANCE_BOT_STARTED = "guardrails.utterance.bot.started"
+    UTTERANCE_BOT_FINISHED = "guardrails.utterance.bot.finished"
+
+    USER_MESSAGE = "guardrails.user_message"
+
+
+class GuardrailsEventTypes:
+    """NeMo Guardrails internal event type constants.
+
+    These are the type values from internal guardrails events.
+    """
+
+    UTTERANCE_USER_ACTION_FINISHED = "UtteranceUserActionFinished"
+    USER_MESSAGE = "UserMessage"
+
+    START_UTTERANCE_BOT_ACTION = "StartUtteranceBotAction"
+    UTTERANCE_BOT_ACTION_FINISHED = "UtteranceBotActionFinished"
+
+    SYSTEM_MESSAGE = "SystemMessage"
diff --git a/nemoguardrails/tracing/interaction_types.py b/nemoguardrails/tracing/interaction_types.py
new file mode 100644
index 000000000..51f77bdbd
--- /dev/null
+++ b/nemoguardrails/tracing/interaction_types.py
@@ -0,0 +1,83 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core models for the tracing system."""
+
+from typing import Any, List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+from nemoguardrails.rails.llm.options import ActivatedRail, GenerationLog
+from nemoguardrails.tracing.span_extractors import SpanExtractor, create_span_extractor
+from nemoguardrails.tracing.spans import SpanLegacy, SpanOpentelemetry
+
+
+class InteractionLog(BaseModel):
+    """Detailed log about the execution of an interaction."""
+
+    id: str = Field(description="A human readable id of the interaction.")
+
+    activated_rails: List[ActivatedRail] = Field(
+        default_factory=list, description="Details about the activated rails."
+    )
+    events: List[dict] = Field(
+        default_factory=list,
+        description="The full list of events recorded during the interaction.",
+    )
+    trace: List[Union[SpanLegacy, SpanOpentelemetry]] = Field(
+        default_factory=list, description="Detailed information about the execution."
+    )
+
+
+class InteractionOutput(BaseModel):
+    """Simple model for interaction output used in tracer."""
+
+    id: str = Field(description="A human readable id of the interaction.")
+    input: Any = Field(description="The input for the interaction.")
+    output: Optional[Any] = Field(
+        default=None, description="The output of the interaction."
+    )
+
+
+def extract_interaction_log(
+    interaction_output: InteractionOutput,
+    generation_log: GenerationLog,
+    span_format: str = "opentelemetry",
+    enable_content_capture: bool = False,
+) -> InteractionLog:
+    """Extracts an `InteractionLog` object from a `GenerationLog` object.
+
+    Args:
+        interaction_output: The interaction output
+        generation_log: The generation log
+        span_format: Span format to use ("legacy" or "opentelemetry")
+        enable_content_capture: Whether to include content in trace events
+    """
+    internal_events = generation_log.internal_events
+
+    span_extractor: SpanExtractor = create_span_extractor(
+        span_format=span_format,
+        events=internal_events,
+        enable_content_capture=enable_content_capture,
+    )
+
+    spans = span_extractor.extract_spans(generation_log.activated_rails)
+
+    return InteractionLog(
+        id=interaction_output.id,
+        activated_rails=generation_log.activated_rails,
+        events=generation_log.internal_events,
+        trace=spans,
+    )
diff --git a/nemoguardrails/tracing/span_extractors.py b/nemoguardrails/tracing/span_extractors.py
new file mode 100644
index 000000000..637f754f9
--- /dev/null
+++ b/nemoguardrails/tracing/span_extractors.py
@@ -0,0 +1,482 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Span extraction logic for different span versions."""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+
+from nemoguardrails.rails.llm.options import ActivatedRail
+from nemoguardrails.tracing.constants import (
+    EventNames,
+    GuardrailsEventNames,
+    GuardrailsEventTypes,
+    OperationNames,
+    SpanNames,
+    SpanTypes,
+    SystemConstants,
+)
+from nemoguardrails.tracing.spans import (
+    ActionSpan,
+    InteractionSpan,
+    LLMSpan,
+    RailSpan,
+    SpanEvent,
+    SpanLegacy,
+    SpanOpentelemetry,
+    TypedSpan,
+)
+from nemoguardrails.utils import new_uuid
+
+
+class SpanExtractor(ABC):
+    """Base class for span extractors."""
+
+    @abstractmethod
+    def extract_spans(
+        self, activated_rails: List[ActivatedRail]
+    ) -> List[Union[SpanLegacy, SpanOpentelemetry]]:
+        """Extract spans from activated rails."""
+        ...
+ + +class SpanExtractorV1(SpanExtractor): + """Extract v1 spans (legacy format).""" + + def extract_spans( + self, activated_rails: List[ActivatedRail] + ) -> List[Union[SpanLegacy, SpanOpentelemetry]]: + """Extract v1 spans from activated rails.""" + spans: List[SpanLegacy] = [] + if not activated_rails: + return spans + + ref_time = activated_rails[0].started_at or 0.0 + + # Create interaction span + interaction_span = SpanLegacy( + span_id=new_uuid(), + name=SpanTypes.INTERACTION, # V1 uses legacy naming + start_time=(activated_rails[0].started_at or 0.0) - ref_time, + end_time=(activated_rails[-1].finished_at or 0.0) - ref_time, + duration=(activated_rails[-1].finished_at or 0.0) + - (activated_rails[0].started_at or 0.0), + ) + + interaction_span.metrics.update( + { + "interaction_total": 1, + "interaction_seconds_avg": interaction_span.duration, + "interaction_seconds_total": interaction_span.duration, + } + ) + spans.append(interaction_span) + + # Process rails and actions + for activated_rail in activated_rails: + rail_span = SpanLegacy( + span_id=new_uuid(), + name="rail: " + activated_rail.name, + parent_id=interaction_span.span_id, + start_time=(activated_rail.started_at or 0.0) - ref_time, + end_time=(activated_rail.finished_at or 0.0) - ref_time, + duration=activated_rail.duration or 0.0, + ) + spans.append(rail_span) + + for action in activated_rail.executed_actions: + action_span = SpanLegacy( + span_id=new_uuid(), + name="action: " + action.action_name, + parent_id=rail_span.span_id, + start_time=(action.started_at or 0.0) - ref_time, + end_time=(action.finished_at or 0.0) - ref_time, + duration=action.duration or 0.0, + ) + + base_metric_name = f"action_{action.action_name}" + action_span.metrics.update( + { + f"{base_metric_name}_total": 1, + f"{base_metric_name}_seconds_avg": action.duration or 0.0, + f"{base_metric_name}_seconds_total": action.duration or 0.0, + } + ) + spans.append(action_span) + + # Process LLM calls + for llm_call in action.llm_calls: + model_name = llm_call.llm_model_name or SystemConstants.UNKNOWN + llm_span = SpanLegacy( + span_id=new_uuid(), + name="LLM: " + model_name, + parent_id=action_span.span_id, + start_time=(llm_call.started_at or 0.0) - ref_time, + end_time=(llm_call.finished_at or 0.0) - ref_time, + duration=llm_call.duration or 0.0, + ) + + base_metric_name = f"llm_call_{model_name.replace('/', '_')}" + llm_span.metrics.update( + { + f"{base_metric_name}_total": 1, + f"{base_metric_name}_seconds_avg": llm_call.duration or 0.0, + f"{base_metric_name}_seconds_total": llm_call.duration + or 0.0, + f"{base_metric_name}_prompt_tokens_total": llm_call.prompt_tokens + or 0, + f"{base_metric_name}_completion_tokens_total": llm_call.completion_tokens + or 0, + f"{base_metric_name}_tokens_total": llm_call.total_tokens + or 0, + } + ) + spans.append(llm_span) + + return spans + + +class SpanExtractorV2(SpanExtractor): + """Extract v2 spans with OpenTelemetry semantic conventions.""" + + def __init__( + self, events: Optional[List[dict]] = None, enable_content_capture: bool = False + ): + """Initialize with optional events for extracting user/bot messages. 
+ + Args: + events: Internal events from InteractionLog + enable_content_capture: Whether to include potentially sensitive content in events + """ + self.internal_events = events or [] + self.enable_content_capture = enable_content_capture + + def extract_spans( + self, activated_rails: List[ActivatedRail] + ) -> List[Union[SpanLegacy, SpanOpentelemetry, TypedSpan]]: + """Extract v2 spans from activated rails with OpenTelemetry attributes.""" + spans: List[TypedSpan] = [] + ref_time = activated_rails[0].started_at or 0.0 + + interaction_span = InteractionSpan( + span_id=new_uuid(), + name=SpanNames.GUARDRAILS_REQUEST, + start_time=(activated_rails[0].started_at or 0.0) - ref_time, + end_time=(activated_rails[-1].finished_at or 0.0) - ref_time, + duration=(activated_rails[-1].finished_at or 0.0) + - (activated_rails[0].started_at or 0.0), + operation_name=OperationNames.GUARDRAILS, + service_name=SystemConstants.SYSTEM_NAME, + ) + spans.append(interaction_span) + + for activated_rail in activated_rails: + # Create typed RailSpan + rail_span = RailSpan( + span_id=new_uuid(), + name=SpanNames.GUARDRAILS_RAIL, # Low-cardinality name + parent_id=interaction_span.span_id, + start_time=(activated_rail.started_at or 0.0) - ref_time, + end_time=(activated_rail.finished_at or 0.0) - ref_time, + duration=activated_rail.duration or 0.0, + rail_type=activated_rail.type, + rail_name=activated_rail.name, + rail_stop=( + activated_rail.stop if activated_rail.stop is not None else None + ), + rail_decisions=( + activated_rail.decisions if activated_rail.decisions else None + ), + ) + spans.append(rail_span) + + for action in activated_rail.executed_actions: + # Create typed ActionSpan + action_span = ActionSpan( + span_id=new_uuid(), + name=SpanNames.GUARDRAILS_ACTION, + parent_id=rail_span.span_id, + start_time=(action.started_at or 0.0) - ref_time, + end_time=(action.finished_at or 0.0) - ref_time, + duration=action.duration or 0.0, + action_name=action.action_name, + has_llm_calls=len(action.llm_calls) > 0, + llm_calls_count=len(action.llm_calls), + action_params={ + k: v + for k, v in (action.action_params or {}).items() + if isinstance(v, (str, int, float, bool)) + }, + error=True if hasattr(action, "error") and action.error else None, + error_type=( + type(action.error).__name__ + if hasattr(action, "error") and action.error + else None + ), + error_message=( + str(action.error) + if hasattr(action, "error") and action.error + else None + ), + ) + spans.append(action_span) + + for llm_call in action.llm_calls: + model_name = llm_call.llm_model_name or SystemConstants.UNKNOWN + + provider_name = ( + llm_call.llm_provider_name or SystemConstants.UNKNOWN + ) + + # use the specific task name as operation name (custom operation) + # this provides better observability for NeMo Guardrails specific tasks + operation_name = llm_call.task or OperationNames.COMPLETION + + # follow OpenTelemetry convention: span name = "{operation} {model}" + span_name = f"{operation_name} {model_name}" + + # extract request parameters from raw_response if available + temperature = None + max_tokens = None + top_p = None + response_id = None + finish_reasons = None + + if llm_call.raw_response: + response_id = llm_call.raw_response.get("id") + finish_reasons = self._extract_finish_reasons( + llm_call.raw_response + ) + temperature = llm_call.raw_response.get("temperature") + max_tokens = llm_call.raw_response.get("max_tokens") + top_p = llm_call.raw_response.get("top_p") + + llm_span = LLMSpan( + span_id=new_uuid(), + 
name=span_name, + parent_id=action_span.span_id, + start_time=(llm_call.started_at or 0.0) - ref_time, + end_time=(llm_call.finished_at or 0.0) - ref_time, + duration=llm_call.duration or 0.0, + provider_name=provider_name, + request_model=model_name, + response_model=model_name, + operation_name=operation_name, + usage_input_tokens=llm_call.prompt_tokens, + usage_output_tokens=llm_call.completion_tokens, + usage_total_tokens=llm_call.total_tokens, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + response_id=response_id, + response_finish_reasons=finish_reasons, + # TODO: add error to LLMCallInfo for future release + # error=( + # True + # if hasattr(llm_call, "error") and llm_call.error + # else None + # ), + # error_type=( + # type(llm_call.error).__name__ + # if hasattr(llm_call, "error") and llm_call.error + # else None + # ), + # error_message=( + # str(llm_call.error) + # if hasattr(llm_call, "error") and llm_call.error + # else None + # ), + ) + + llm_events = self._extract_llm_events(llm_call, llm_span.start_time) + llm_span.events.extend(llm_events) + + spans.append(llm_span) + + # Add conversation events to the interaction span + if self.internal_events: + interaction_events = self._extract_conversation_events(ref_time) + interaction_span.events.extend(interaction_events) + + return spans + + def _extract_llm_events(self, llm_call, start_time: float) -> List[SpanEvent]: + """Extract OpenTelemetry GenAI message events from an LLM call.""" + events = [] + + # TODO: Update to use newer gen_ai.user.message and gen_ai.assistant.message events + # Currently using deprecated gen_ai.content.prompt and gen_ai.content.completion for simplicity + if llm_call.prompt: + # per OTel spec: content should NOT be captured by default + body = {"content": llm_call.prompt} if self.enable_content_capture else {} + events.append( + SpanEvent( + name=EventNames.GEN_AI_CONTENT_PROMPT, + timestamp=start_time, + body=body, + ) + ) + + if llm_call.completion: + # per OTel spec: content should NOT be captured by default + body = ( + {"content": llm_call.completion} if self.enable_content_capture else {} + ) + events.append( + SpanEvent( + name=EventNames.GEN_AI_CONTENT_COMPLETION, + timestamp=start_time + (llm_call.duration or 0), + body=body, + ) + ) + + return events + + def _extract_conversation_events(self, ref_time: float) -> List[SpanEvent]: + """Extract guardrails-specific conversation events from internal events. + + NOTE: These are NeMo Guardrails internal events, NOT OpenTelemetry GenAI events. + We use guardrails-specific namespacing to avoid confusion with OTel GenAI semantic conventions. 
+ """ + events = [] + + for event in self.internal_events: + event_type = event.get("type", "") + body = dict() + event_timestamp = self._get_event_timestamp(event, ref_time) + + if event_type == GuardrailsEventTypes.UTTERANCE_USER_ACTION_FINISHED: + if self.enable_content_capture: + body["content"] = event.get("final_transcript", "") + body["type"] = event_type + events.append( + SpanEvent( + name=GuardrailsEventNames.UTTERANCE_USER_FINISHED, + timestamp=event_timestamp, + body=body, + ) + ) + + elif event_type == GuardrailsEventTypes.USER_MESSAGE: + if self.enable_content_capture: + body["content"] = event.get("text", "") + body["type"] = event_type + events.append( + SpanEvent( + name=GuardrailsEventNames.USER_MESSAGE, + timestamp=event_timestamp, + body=body, + ) + ) + + elif event_type == GuardrailsEventTypes.START_UTTERANCE_BOT_ACTION: + if self.enable_content_capture: + body["content"] = event.get("script", "") + body["type"] = event_type + events.append( + SpanEvent( + name=GuardrailsEventNames.UTTERANCE_BOT_STARTED, + timestamp=event_timestamp, + body=body, + ) + ) + elif event_type == GuardrailsEventTypes.UTTERANCE_BOT_ACTION_FINISHED: + if self.enable_content_capture: + body["content"] = event.get("final_script", "") + body["type"] = event_type + body["is_success"] = event.get("is_success", True) + events.append( + SpanEvent( + name=GuardrailsEventNames.UTTERANCE_BOT_FINISHED, + timestamp=event_timestamp, + body=body, + ) + ) + + return events + + def _get_event_timestamp(self, event: dict, ref_time: float) -> float: + """Extract timestamp from event or use reference time. + + Args: + event: The internal event dictionary + ref_time: Reference time to use as fallback (trace start time) + + Returns: + Timestamp in seconds relative to trace start + """ + event_created_at = event.get("event_created_at") + if event_created_at: + try: + from datetime import datetime + + dt = datetime.fromisoformat(event_created_at) + absolute_timestamp = dt.timestamp() + return absolute_timestamp - ref_time + except (ValueError, AttributeError): + pass + + # fallback: use reference time (event at start of trace) + return 0.0 + + def _extract_finish_reasons(self, raw_response: dict) -> Optional[List[str]]: + """Extract finish reasons from raw LLM response.""" + if not raw_response: + return None + + finish_reasons = [] + + if "finish_reason" in raw_response: + finish_reasons.append(raw_response["finish_reason"]) + + if not finish_reasons and raw_response: + finish_reasons = ["stop"] + + return finish_reasons if finish_reasons else None + + +from nemoguardrails.tracing.span_format import SpanFormat, validate_span_format + + +def create_span_extractor( + span_format: str = "legacy", + events: Optional[List[dict]] = None, + enable_content_capture: bool = True, +) -> SpanExtractor: + """Create a span extractor based on format and configuration. 
+ + Args: + span_format: Format of span extractor ('legacy' or 'opentelemetry') + events: Internal events for OpenTelemetry extractor + enable_content_capture: Whether to capture content in spans + + Returns: + Configured span extractor instance + + Raises: + ValueError: If span_format is not supported + """ + format_enum = validate_span_format(span_format) + + if format_enum == SpanFormat.LEGACY: + return SpanExtractorV1() # TODO: Rename to SpanExtractorLegacy + elif format_enum == SpanFormat.OPENTELEMETRY: + return SpanExtractorV2( # TODO: Rename to SpanExtractorOTel + events=events, + enable_content_capture=enable_content_capture, + ) + else: + # This should never happen due to validation, but keeps type checker happy + raise ValueError(f"Unknown span format: {span_format}") diff --git a/nemoguardrails/tracing/span_format.py b/nemoguardrails/tracing/span_format.py new file mode 100644 index 000000000..d524c127a --- /dev/null +++ b/nemoguardrails/tracing/span_format.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Span format definitions for NeMo Guardrails tracing.""" + +from enum import Enum +from typing import Literal, Union + + +class SpanFormat(str, Enum): + """Supported span formats for tracing. + + Inherits from str to allow direct string comparison and JSON serialization. + """ + + # legacy structure with metrics dictionary (simple, minimal overhead) + LEGACY = "legacy" + + # OpenTelemetry Semantic Conventions compliant format + # see https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-events/ + OPENTELEMETRY = "opentelemetry" + + @classmethod + def from_string(cls, value: str) -> "SpanFormat": + """Create SpanFormat from string value. + + Args: + value: String representation of span format + + Returns: + SpanFormat enum value + + Raises: + ValueError: If value is not a valid span format + """ + try: + return cls(value.lower()) + except ValueError: + valid_formats = [f.value for f in cls] + raise ValueError( + f"Invalid span format: '{value}'. " + f"Valid formats are: {', '.join(valid_formats)}" + ) + + def __str__(self) -> str: + """Return string value for use in configs.""" + return self.value + + +# Type alias for function signatures +SpanFormatType = Union[SpanFormat, Literal["legacy", "opentelemetry"], str] + + +def validate_span_format(value: SpanFormatType) -> SpanFormat: + """Validate and convert span format value to SpanFormat enum. 
+ + Args: + value: Span format as enum, literal, or string + + Returns: + Validated SpanFormat enum value + + Raises: + ValueError: If value is not a valid span format + """ + if isinstance(value, SpanFormat): + return value + elif isinstance(value, str): + return SpanFormat.from_string(value) + else: + raise TypeError( + f"Span format must be a string or SpanFormat enum, got {type(value)}" + ) diff --git a/nemoguardrails/tracing/span_formatting.py b/nemoguardrails/tracing/span_formatting.py new file mode 100644 index 000000000..1350171ba --- /dev/null +++ b/nemoguardrails/tracing/span_formatting.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple span formatting functions for different output formats.""" + +from typing import Any, Dict + +from nemoguardrails.tracing.spans import SpanLegacy, is_opentelemetry_span + + +def get_schema_version_for_filesystem(span) -> str: + """Return the schema version string based on the span type.""" + if isinstance(span, SpanLegacy): + return "1.0" + if is_opentelemetry_span(span): + return "2.0" + raise ValueError(f"Unknown span type: {type(span).__name__}.") + + +def format_span_for_filesystem(span) -> Dict[str, Any]: + """Format any span type for JSON filesystem storage. + + Args: + span: Either SpanLegacy or typed span (InteractionSpan, RailSpan, etc.) + + Returns: + Dictionary with all span data for JSON serialization + """ + if not isinstance(span, SpanLegacy) and not is_opentelemetry_span(span): + raise ValueError( + f"Unknown span type: {type(span).__name__}. " + f"Only SpanLegacy and typed spans are supported." + ) + + result = { + "name": span.name, + "span_id": span.span_id, + "parent_id": span.parent_id, + "start_time": span.start_time, + "end_time": span.end_time, + "duration": span.duration, + "span_type": span.__class__.__name__, + } + + if isinstance(span, SpanLegacy): + if hasattr(span, "metrics") and span.metrics: + result["metrics"] = span.metrics + + else: # is_typed_span(span) + result["span_kind"] = span.span_kind + result["attributes"] = span.to_otel_attributes() + + if hasattr(span, "events") and span.events: + result["events"] = [ + { + "name": event.name, + "timestamp": event.timestamp, + "attributes": event.attributes, + } + for event in span.events + ] + + if hasattr(span, "error") and span.error: + result["error"] = { + "occurred": span.error, + "type": getattr(span, "error_type", None), + "message": getattr(span, "error_message", None), + } + + if hasattr(span, "custom_attributes") and span.custom_attributes: + result["custom_attributes"] = span.custom_attributes + + return result + + +def extract_span_attributes(span) -> Dict[str, Any]: + """Extract OpenTelemetry attributes from any span type. 
+ + Args: + span: Either SpanLegacy or typed span + + Returns: + Dictionary of OpenTelemetry attributes + """ + if isinstance(span, SpanLegacy): + return span.metrics.copy() if hasattr(span, "metrics") and span.metrics else {} + + elif is_opentelemetry_span(span): + return span.to_otel_attributes() + + else: + raise ValueError( + f"Unknown span type: {type(span).__name__}. " + f"Only SpanLegacy and typed spans are supported." + ) diff --git a/nemoguardrails/tracing/spans.py b/nemoguardrails/tracing/spans.py new file mode 100644 index 000000000..fb89fb394 --- /dev/null +++ b/nemoguardrails/tracing/spans.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Span models for NeMo Guardrails tracing system.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from nemoguardrails.tracing.constants import ( + CommonAttributes, + GenAIAttributes, + GuardrailsAttributes, +) + + +class SpanKind(str, Enum): + SERVER = "server" + CLIENT = "client" + INTERNAL = "internal" + + +class SpanEvent(BaseModel): + """Event that can be attached to a span.""" + + name: str = Field(description="Event name (e.g., 'gen_ai.user.message')") + timestamp: float = Field(description="Timestamp when the event occurred (relative)") + attributes: Dict[str, Any] = Field( + default_factory=dict, description="Event attributes" + ) + body: Optional[Dict[str, Any]] = Field( + default=None, description="Event body for structured data" + ) + + +class SpanLegacy(BaseModel): + """Simple span model (v1) for basic tracing.""" + + span_id: str = Field(description="The id of the span.") + name: str = Field(description="A human-readable name for the span.") + parent_id: Optional[str] = Field( + default=None, description="The id of the parent span." + ) + resource_id: Optional[str] = Field( + default=None, description="The id of the resource." + ) + start_time: float = Field(description="The start time of the span.") + end_time: float = Field(description="The end time of the span.") + duration: float = Field(description="The duration of the span in seconds.") + metrics: Dict[str, Union[int, float]] = Field( + default_factory=dict, description="The metrics recorded during the span." 
+ ) + + +class BaseSpan(BaseModel, ABC): + """Base span with common fields across all span types.""" + + span_id: str = Field(description="Unique identifier for this span") + name: str = Field(description="Human-readable name for the span") + parent_id: Optional[str] = Field(default=None, description="ID of the parent span") + + start_time: float = Field( + description="Start time relative to trace start (seconds)" + ) + end_time: float = Field(description="End time relative to trace start (seconds)") + duration: float = Field(description="Duration of the span in seconds") + + span_kind: SpanKind = Field(description="OpenTelemetry span kind") + + events: List[SpanEvent] = Field( + default_factory=list, + description="Events attached to this span following OpenTelemetry conventions", + ) + + error: Optional[bool] = Field(default=None, description="Whether an error occurred") + error_type: Optional[str] = Field( + default=None, description="Type of error (e.g., exception class name)" + ) + error_message: Optional[str] = Field( + default=None, description="Error message or description" + ) + + custom_attributes: Dict[str, Any] = Field( + default_factory=dict, + description="Additional custom attributes not covered by typed fields", + ) + + @abstractmethod + def to_otel_attributes(self) -> Dict[str, Any]: + """Convert typed fields to legacy OpenTelemetry attributes dictionary. + + Returns: + Dict containing OTel semantic convention attributes. + """ + pass + + def _base_attributes(self) -> Dict[str, Any]: + """Get common attributes for all span types.""" + attributes = { + CommonAttributes.SPAN_KIND: self.span_kind, + } + + # TODO: for future release, consider adding: + # if self.error is not None: + # attributes["error"] = self.error + # if self.error_type is not None: + # attributes["error.type"] = self.error_type + # if self.error_message is not None: + # attributes["error.message"] = self.error_message + + attributes.update(self.custom_attributes) + + return attributes + + +class InteractionSpan(BaseSpan): + """Top-level span for a guardrails interaction (server span).""" + + span_kind: SpanKind = SpanKind.SERVER + + operation_name: str = Field( + default="guardrails", description="Operation name for this interaction" + ) + service_name: str = Field(default="nemo_guardrails", description="Service name") + + user_id: Optional[str] = Field(default=None, description="User identifier") + session_id: Optional[str] = Field(default=None, description="Session identifier") + request_id: Optional[str] = Field(default=None, description="Request identifier") + + def to_otel_attributes(self) -> Dict[str, Any]: + """Convert to OTel attributes.""" + attributes = self._base_attributes() + + attributes[GenAIAttributes.GEN_AI_OPERATION_NAME] = self.operation_name + attributes["service.name"] = self.service_name + + if self.user_id is not None: + attributes["user.id"] = self.user_id + if self.session_id is not None: + attributes["session.id"] = self.session_id + if self.request_id is not None: + attributes["request.id"] = self.request_id + + return attributes + + +class RailSpan(BaseSpan): + """Span for a guardrail execution (internal span).""" + + span_kind: SpanKind = SpanKind.INTERNAL + # rail-specific attributes + rail_type: str = Field(description="Type of rail (e.g., input, output, dialog)") + rail_name: str = Field(description="Name of the rail (e.g., check_jailbreak)") + rail_stop: Optional[bool] = Field( + default=None, description="Whether the rail stopped execution" + ) + rail_decisions: 
Optional[List[str]] = Field( + default=None, description="Decisions made by the rail" + ) + + def to_otel_attributes(self) -> Dict[str, Any]: + """Convert to OTel attributes.""" + attributes = self._base_attributes() + + attributes[GuardrailsAttributes.RAIL_TYPE] = self.rail_type + attributes[GuardrailsAttributes.RAIL_NAME] = self.rail_name + + if self.rail_stop is not None: + attributes[GuardrailsAttributes.RAIL_STOP] = self.rail_stop + if self.rail_decisions is not None: + attributes[GuardrailsAttributes.RAIL_DECISIONS] = self.rail_decisions + + return attributes + + +class ActionSpan(BaseSpan): + """Span for an action execution (internal span).""" + + span_kind: SpanKind = SpanKind.INTERNAL + # action-specific attributes + action_name: str = Field(description="Name of the action being executed") + action_params: Dict[str, Any] = Field( + default_factory=dict, description="Parameters passed to the action" + ) + has_llm_calls: bool = Field( + default=False, description="Whether this action made LLM calls" + ) + llm_calls_count: int = Field( + default=0, description="Number of LLM calls made by this action" + ) + + def to_otel_attributes(self) -> Dict[str, Any]: + """Convert to OTel attributes.""" + attributes = self._base_attributes() + + attributes[GuardrailsAttributes.ACTION_NAME] = self.action_name + attributes[GuardrailsAttributes.ACTION_HAS_LLM_CALLS] = self.has_llm_calls + attributes[GuardrailsAttributes.ACTION_LLM_CALLS_COUNT] = self.llm_calls_count + + # add action parameters as individual attributes + for param_name, param_value in self.action_params.items(): + if isinstance(param_value, (str, int, float, bool)): + attributes[ + f"{GuardrailsAttributes.ACTION_PARAM_PREFIX}{param_name}" + ] = param_value + + return attributes + + +class LLMSpan(BaseSpan): + """Span for an LLM API call (client span).""" + + span_kind: SpanKind = SpanKind.CLIENT + provider_name: str = Field( + description="LLM provider name (e.g., openai, anthropic)" + ) + request_model: str = Field(description="Model requested (e.g., gpt-4)") + response_model: str = Field( + description="Model that responded (usually same as request_model)" + ) + operation_name: str = Field( + description="Operation name (e.g., chat.completions, embeddings)" + ) + + usage_input_tokens: Optional[int] = Field( + default=None, description="Number of input tokens" + ) + usage_output_tokens: Optional[int] = Field( + default=None, description="Number of output tokens" + ) + usage_total_tokens: Optional[int] = Field( + default=None, description="Total number of tokens" + ) + + # Request parameters + temperature: Optional[float] = Field( + default=None, description="Temperature parameter" + ) + max_tokens: Optional[int] = Field( + default=None, description="Maximum tokens to generate" + ) + top_p: Optional[float] = Field(default=None, description="Top-p parameter") + top_k: Optional[int] = Field(default=None, description="Top-k parameter") + frequency_penalty: Optional[float] = Field( + default=None, description="Frequency penalty" + ) + presence_penalty: Optional[float] = Field( + default=None, description="Presence penalty" + ) + stop_sequences: Optional[List[str]] = Field( + default=None, description="Stop sequences" + ) + + response_id: Optional[str] = Field(default=None, description="Response identifier") + response_finish_reasons: Optional[List[str]] = Field( + default=None, description="Finish reasons for each choice" + ) + + def to_otel_attributes(self) -> Dict[str, Any]: + """Convert to OTel attributes.""" + attributes = 
self._base_attributes() + + attributes[GenAIAttributes.GEN_AI_PROVIDER_NAME] = self.provider_name + attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL] = self.request_model + attributes[GenAIAttributes.GEN_AI_RESPONSE_MODEL] = self.response_model + attributes[GenAIAttributes.GEN_AI_OPERATION_NAME] = self.operation_name + + if self.usage_input_tokens is not None: + attributes[ + GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS + ] = self.usage_input_tokens + if self.usage_output_tokens is not None: + attributes[ + GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS + ] = self.usage_output_tokens + if self.usage_total_tokens is not None: + attributes[ + GenAIAttributes.GEN_AI_USAGE_TOTAL_TOKENS + ] = self.usage_total_tokens + + if self.temperature is not None: + attributes[GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE] = self.temperature + if self.max_tokens is not None: + attributes[GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS] = self.max_tokens + if self.top_p is not None: + attributes[GenAIAttributes.GEN_AI_REQUEST_TOP_P] = self.top_p + if self.top_k is not None: + attributes[GenAIAttributes.GEN_AI_REQUEST_TOP_K] = self.top_k + if self.frequency_penalty is not None: + attributes[ + GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY + ] = self.frequency_penalty + if self.presence_penalty is not None: + attributes[ + GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY + ] = self.presence_penalty + if self.stop_sequences is not None: + attributes[ + GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES + ] = self.stop_sequences + + if self.response_id is not None: + attributes[GenAIAttributes.GEN_AI_RESPONSE_ID] = self.response_id + if self.response_finish_reasons is not None: + attributes[ + GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS + ] = self.response_finish_reasons + + return attributes + + +TypedSpan = Union[InteractionSpan, RailSpan, ActionSpan, LLMSpan] + +SpanOpentelemetry = TypedSpan + + +def is_opentelemetry_span(span: Any) -> bool: + """Check if an object is a typed span (V2). 
+ + Args: + span: Object to check + + Returns: + True if the object is a typed span, False otherwise + """ + # Python 3.9 compatibility: cannot use isinstance with Union types + return isinstance(span, (InteractionSpan, RailSpan, ActionSpan, LLMSpan)) diff --git a/nemoguardrails/tracing/tracer.py b/nemoguardrails/tracing/tracer.py index 5ad59d5dd..b00c822cf 100644 --- a/nemoguardrails/tracing/tracer.py +++ b/nemoguardrails/tracing/tracer.py @@ -18,12 +18,15 @@ from contextlib import AsyncExitStack from typing import List, Optional -from nemoguardrails.eval.eval import _extract_interaction_log -from nemoguardrails.eval.models import InteractionLog, InteractionOutput from nemoguardrails.rails.llm.config import TracingConfig from nemoguardrails.rails.llm.options import GenerationLog, GenerationResponse from nemoguardrails.tracing.adapters.base import InteractionLogAdapter from nemoguardrails.tracing.adapters.registry import LogAdapterRegistry +from nemoguardrails.tracing.interaction_types import ( + InteractionLog, + InteractionOutput, + extract_interaction_log, +) def new_uuid() -> str: @@ -36,6 +39,8 @@ def __init__( input, response: GenerationResponse, adapters: Optional[List[InteractionLogAdapter]] = None, + span_format: str = "opentelemetry", + enable_content_capture: bool = False, ): self._interaction_output = InteractionOutput( id=new_uuid(), input=input[-1]["content"], output=response.response @@ -46,6 +51,8 @@ def __init__( raise RuntimeError("Generation log is missing.") self.adapters = adapters or [] + self._span_format = span_format + self._enable_content_capture = enable_content_capture def generate_interaction_log( self, @@ -59,7 +66,12 @@ def generate_interaction_log( if generation_log is None: generation_log = self._generation_log - interaction_log = _extract_interaction_log(interaction_output, generation_log) + interaction_log = extract_interaction_log( + interaction_output, + generation_log, + span_format=self._span_format, + enable_content_capture=self._enable_content_capture, + ) return interaction_log def add_adapter(self, adapter: InteractionLogAdapter): diff --git a/tests/test_tracing_adapters_filesystem.py b/tests/test_tracing_adapters_filesystem.py deleted file mode 100644 index df4a470c9..000000000 --- a/tests/test_tracing_adapters_filesystem.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
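For review, a minimal sketch of how the typed span models introduced above (nemoguardrails.tracing.spans) are intended to be used; the values are illustrative only, and only the LLMSpan fields and gen_ai.* attribute keys that appear elsewhere in this diff are assumed:

    from nemoguardrails.tracing.spans import LLMSpan

    # Times are relative to the trace start, in seconds, as documented on BaseSpan.
    span = LLMSpan(
        name="chat gpt-4",
        span_id="llm-1",
        start_time=0.0,
        end_time=1.2,
        duration=1.2,
        provider_name="openai",
        request_model="gpt-4",
        response_model="gpt-4",
        operation_name="chat.completions",
        usage_input_tokens=50,
        usage_output_tokens=100,
    )

    # Flatten the typed fields into OTel GenAI semantic-convention attributes,
    # e.g. {"gen_ai.provider.name": "openai", "gen_ai.request.model": "gpt-4", ...}.
    attrs = span.to_otel_attributes()
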
- -import asyncio -import importlib -import json -import os -import tempfile -import unittest -from unittest.mock import MagicMock - -from nemoguardrails.eval.models import Span -from nemoguardrails.tracing import InteractionLog -from nemoguardrails.tracing.adapters.filesystem import FileSystemAdapter - - -class TestFileSystemAdapter(unittest.TestCase): - def setUp(self): - # creating a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.filepath = os.path.join(self.temp_dir.name, "trace.jsonl") - - def tearDown(self): - self.temp_dir.cleanup() - - def test_initialization_default_path(self): - adapter = FileSystemAdapter() - self.assertEqual(adapter.filepath, "./.traces/trace.jsonl") - - def test_initialization_custom_path(self): - adapter = FileSystemAdapter(filepath=self.filepath) - self.assertEqual(adapter.filepath, self.filepath) - self.assertTrue(os.path.exists(os.path.dirname(self.filepath))) - - def test_transform(self): - adapter = FileSystemAdapter(filepath=self.filepath) - - # Mock the InteractionLog - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={}, - ) - ], - ) - - adapter.transform(interaction_log) - - with open(self.filepath, "r") as f: - content = f.read() - log_dict = json.loads(content.strip()) - self.assertEqual(log_dict["trace_id"], "test_id") - self.assertEqual(len(log_dict["spans"]), 1) - self.assertEqual(log_dict["spans"][0]["name"], "test_span") - - @unittest.skipIf( - importlib.util.find_spec("aiofiles") is None, "aiofiles is not installed" - ) - def test_transform_async(self): - async def run_test(): - adapter = FileSystemAdapter(filepath=self.filepath) - - # Mock the InteractionLog - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={}, - ) - ], - ) - - await adapter.transform_async(interaction_log) - - with open(self.filepath, "r") as f: - content = f.read() - log_dict = json.loads(content.strip()) - self.assertEqual(log_dict["trace_id"], "test_id") - self.assertEqual(len(log_dict["spans"]), 1) - self.assertEqual(log_dict["spans"][0]["name"], "test_span") - - asyncio.run(run_test()) diff --git a/tests/test_tracing_adapters_opentelemetry.py b/tests/test_tracing_adapters_opentelemetry.py deleted file mode 100644 index ee1a5a667..000000000 --- a/tests/test_tracing_adapters_opentelemetry.py +++ /dev/null @@ -1,366 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -import unittest -import warnings -from importlib.metadata import version -from unittest.mock import MagicMock, patch - -# TODO: check to see if we can add it as a dependency -# but now we try to import opentelemetry and set a flag if it's not available -try: - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.trace import NoOpTracerProvider - - from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter - - OPENTELEMETRY_AVAILABLE = True -except ImportError: - OPENTELEMETRY_AVAILABLE = False - -from nemoguardrails.eval.models import Span -from nemoguardrails.tracing import InteractionLog - - -@unittest.skipIf(not OPENTELEMETRY_AVAILABLE, "opentelemetry is not available") -class TestOpenTelemetryAdapter(unittest.TestCase): - def setUp(self): - # Set up a mock tracer provider for testing - self.mock_tracer_provider = MagicMock(spec=TracerProvider) - self.mock_tracer = MagicMock() - self.mock_tracer_provider.get_tracer.return_value = self.mock_tracer - - # Patch the global tracer provider - patcher_get_tracer_provider = patch("opentelemetry.trace.get_tracer_provider") - self.mock_get_tracer_provider = patcher_get_tracer_provider.start() - self.mock_get_tracer_provider.return_value = self.mock_tracer_provider - self.addCleanup(patcher_get_tracer_provider.stop) - - # Patch get_tracer to return our mock - patcher_get_tracer = patch("opentelemetry.trace.get_tracer") - self.mock_get_tracer = patcher_get_tracer.start() - self.mock_get_tracer.return_value = self.mock_tracer - self.addCleanup(patcher_get_tracer.stop) - - # Get the actual version for testing - self.actual_version = version("nemoguardrails") - - # Create the adapter - it should now use the global tracer - self.adapter = OpenTelemetryAdapter() - - def test_initialization(self): - """Test that the adapter initializes correctly using the global tracer.""" - - self.mock_get_tracer.assert_called_once_with( - "nemo_guardrails", - instrumenting_library_version=self.actual_version, - schema_url="https://opentelemetry.io/schemas/1.26.0", - ) - # Verify that the adapter has the mock tracer - self.assertEqual(self.adapter.tracer, self.mock_tracer) - - def test_transform(self): - """Test that transform creates spans correctly.""" - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={"key": 123}, - ) - ], - ) - - self.adapter.transform(interaction_log) - - # Verify that start_as_current_span was called - self.mock_tracer.start_as_current_span.assert_called_once_with( - "test_span", - context=None, - ) - - # We retrieve the mock span instance here - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) - - # Verify span attributes were set - span_instance.set_attribute.assert_any_call("key", 123) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) - - def test_transform_span_attributes_various_types(self): - """Test that different attribute types are handled correctly.""" - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - 
span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={ - "int_key": 42, - "float_key": 3.14, - "str_key": 123, # Changed to a numeric value - "bool_key": 1, # Changed to a numeric value - }, - ) - ], - ) - - self.adapter.transform(interaction_log) - - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) - - span_instance.set_attribute.assert_any_call("int_key", 42) - span_instance.set_attribute.assert_any_call("float_key", 3.14) - span_instance.set_attribute.assert_any_call("str_key", 123) - span_instance.set_attribute.assert_any_call("bool_key", 1) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) - - def test_transform_with_empty_trace(self): - """Test transform with empty trace.""" - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[], - ) - - self.adapter.transform(interaction_log) - - self.mock_tracer.start_as_current_span.assert_not_called() - - def test_transform_with_tracer_failure(self): - """Test transform when tracer fails.""" - self.mock_tracer.start_as_current_span.side_effect = Exception("Tracer failure") - - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={"key": 123}, - ) - ], - ) - - with self.assertRaises(Exception) as context: - self.adapter.transform(interaction_log) - - self.assertIn("Tracer failure", str(context.exception)) - - def test_transform_async(self): - """Test async transform functionality.""" - - async def run_test(): - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={"key": 123}, - ) - ], - ) - - await self.adapter.transform_async(interaction_log) - - self.mock_tracer.start_as_current_span.assert_called_once_with( - "test_span", - context=None, - ) - - # We retrieve the mock span instance here - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) - - span_instance.set_attribute.assert_any_call("key", 123) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) - - asyncio.run(run_test()) - - def test_transform_async_with_empty_trace(self): - """Test async transform with empty trace.""" - - async def run_test(): - interaction_log = InteractionLog( - id="test_id", - activated_rails=[], - events=[], - trace=[], - ) - - await self.adapter.transform_async(interaction_log) - - self.mock_tracer.start_as_current_span.assert_not_called() - - asyncio.run(run_test()) - - def test_transform_async_with_tracer_failure(self): - """Test async transform when tracer fails.""" - self.mock_tracer.start_as_current_span.side_effect = Exception("Tracer failure") - - async def run_test(): - interaction_log = InteractionLog( - 
id="test_id", - activated_rails=[], - events=[], - trace=[ - Span( - name="test_span", - span_id="span_1", - parent_id=None, - start_time=0.0, - end_time=1.0, - duration=1.0, - metrics={"key": 123}, - ) - ], - ) - - with self.assertRaises(Exception) as context: - await self.adapter.transform_async(interaction_log) - - self.assertIn("Tracer failure", str(context.exception)) - - asyncio.run(run_test()) - - def test_backward_compatibility_with_old_config(self): - """Test that old configuration parameters are still accepted.""" - # This should not fail even if old parameters are passed - adapter = OpenTelemetryAdapter( - service_name="test_service", - exporter="console", # this should be ignored gracefully - resource_attributes={"test": "value"}, # this should be ignored gracefully - ) - - # Should still create the adapter successfully - self.assertIsInstance(adapter, OpenTelemetryAdapter) - self.assertEqual(adapter.tracer, self.mock_tracer) - - def test_deprecation_warning_for_old_parameters(self): - """Test that deprecation warnings are raised for old configuration parameters.""" - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # adapter with deprecated parameters - _adapter = OpenTelemetryAdapter( - service_name="test_service", - exporter="console", - resource_attributes={"test": "value"}, - span_processor=MagicMock(), - ) - - # deprecation warning is issued - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("deprecated", str(w[0].message)) - self.assertIn("exporter", str(w[0].message)) - self.assertIn("resource_attributes", str(w[0].message)) - self.assertIn("span_processor", str(w[0].message)) - - def test_no_op_tracer_provider_warning(self): - """Test that a warning is issued when NoOpTracerProvider is detected.""" - - with patch("opentelemetry.trace.get_tracer_provider") as mock_get_provider: - mock_get_provider.return_value = NoOpTracerProvider() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - _adapter = OpenTelemetryAdapter() - - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, UserWarning)) - self.assertIn( - "No OpenTelemetry TracerProvider configured", str(w[0].message) - ) - self.assertIn("Traces will not be exported", str(w[0].message)) - - def test_no_warnings_with_proper_configuration(self): - """Test that no warnings are issued when properly configured.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # adapter without deprecated parameters - _adapter = OpenTelemetryAdapter(service_name="test_service") - - # no warnings is issued - self.assertEqual(len(w), 0) - - def test_register_otel_exporter_deprecation(self): - """Test that register_otel_exporter shows deprecation warning.""" - from nemoguardrails.tracing.adapters.opentelemetry import register_otel_exporter - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - mock_exporter_cls = MagicMock() - - register_otel_exporter("test-exporter", mock_exporter_cls) - - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("register_otel_exporter is deprecated", str(w[0].message)) - self.assertIn("0.16.0", str(w[0].message)) - - from nemoguardrails.tracing.adapters.opentelemetry import ( - _exporter_name_cls_map, - ) - - self.assertEqual(_exporter_name_cls_map["test-exporter"], mock_exporter_cls) diff --git 
a/tests/tracing/adapters/test_filesystem.py b/tests/tracing/adapters/test_filesystem.py new file mode 100644 index 000000000..b0c2d9659 --- /dev/null +++ b/tests/tracing/adapters/test_filesystem.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import importlib +import json +import os +import tempfile +import unittest + +from nemoguardrails.tracing import InteractionLog, SpanLegacy +from nemoguardrails.tracing.adapters.filesystem import FileSystemAdapter +from nemoguardrails.tracing.spans import ( + ActionSpan, + InteractionSpan, + LLMSpan, + RailSpan, + SpanEvent, +) + + +class TestFileSystemAdapter(unittest.TestCase): + def setUp(self): + # creating a temporary directory + self.temp_dir = tempfile.TemporaryDirectory() + self.filepath = os.path.join(self.temp_dir.name, "trace.jsonl") + + def tearDown(self): + self.temp_dir.cleanup() + + def test_initialization_default_path(self): + adapter = FileSystemAdapter() + self.assertEqual(adapter.filepath, "./.traces/trace.jsonl") + + def test_initialization_custom_path(self): + adapter = FileSystemAdapter(filepath=self.filepath) + self.assertEqual(adapter.filepath, self.filepath) + self.assertTrue(os.path.exists(os.path.dirname(self.filepath))) + + def test_transform(self): + adapter = FileSystemAdapter(filepath=self.filepath) + + # Mock the InteractionLog + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={}, + ) + ], + ) + + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + self.assertEqual(log_dict["trace_id"], "test_id") + self.assertEqual(len(log_dict["spans"]), 1) + self.assertEqual(log_dict["spans"][0]["name"], "test_span") + + @unittest.skipIf( + importlib.util.find_spec("aiofiles") is None, "aiofiles is not installed" + ) + def test_transform_async(self): + async def run_test(): + adapter = FileSystemAdapter(filepath=self.filepath) + + # Mock the InteractionLog + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={}, + ) + ], + ) + + await adapter.transform_async(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + self.assertEqual(log_dict["trace_id"], "test_id") + self.assertEqual(len(log_dict["spans"]), 1) + self.assertEqual(log_dict["spans"][0]["name"], "test_span") + + asyncio.run(run_test()) + + def test_schema_version(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + 
events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={}, + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + self.assertEqual(log_dict["schema_version"], "1.0") + + def test_span_legacy_with_metrics(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_trace", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="llm_call", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.5, + duration=1.5, + metrics={ + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + }, + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "SpanLegacy") + self.assertIn("metrics", span) + self.assertEqual(span["metrics"]["input_tokens"], 10) + self.assertEqual(span["metrics"]["output_tokens"], 20) + self.assertEqual(span["metrics"]["total_tokens"], 30) + + def test_interaction_span_with_events(self): + adapter = FileSystemAdapter(filepath=self.filepath) + events = [ + SpanEvent( + name="gen_ai.content.prompt", + timestamp=0.1, + attributes={"gen_ai.prompt": "Hello, how are you?"}, + ), + SpanEvent( + name="gen_ai.content.completion", + timestamp=1.9, + attributes={"gen_ai.completion": "I'm doing well, thank you!"}, + ), + ] + interaction_log = InteractionLog( + id="test_trace", + activated_rails=[], + events=[], + trace=[ + InteractionSpan( + name="interaction", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=2.0, + duration=2.0, + span_kind="server", + request_model="gpt-4", + events=events, + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "InteractionSpan") + self.assertEqual(span["span_kind"], "server") + self.assertIn("events", span) + self.assertEqual(len(span["events"]), 2) + self.assertEqual(span["events"][0]["name"], "gen_ai.content.prompt") + self.assertEqual(span["events"][0]["timestamp"], 0.1) + self.assertIn("attributes", span) + self.assertIn("gen_ai.operation.name", span["attributes"]) + + def test_rail_span_with_attributes(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_trace", + activated_rails=[], + events=[], + trace=[ + RailSpan( + name="check_jailbreak", + span_id="span_1", + parent_id="parent_span", + start_time=0.5, + end_time=1.0, + duration=0.5, + span_kind="internal", + rail_type="input", + rail_name="check_jailbreak", + rail_stop=False, + rail_decisions=["allow"], + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "RailSpan") + self.assertEqual(span["span_kind"], "internal") + self.assertEqual(span["parent_id"], "parent_span") + self.assertIn("attributes", span) + self.assertEqual(span["attributes"]["rail.type"], "input") + self.assertEqual(span["attributes"]["rail.name"], "check_jailbreak") + self.assertEqual(span["attributes"]["rail.stop"], False) + self.assertEqual(span["attributes"]["rail.decisions"], ["allow"]) + + 
def test_action_span_with_error(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_trace", + activated_rails=[], + events=[], + trace=[ + ActionSpan( + name="execute_action", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=0.5, + duration=0.5, + span_kind="internal", + action_name="fetch_data", + action_params={"url": "https://api.example.com"}, + error=True, + error_type="ConnectionError", + error_message="Failed to connect to API", + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "ActionSpan") + self.assertIn("error", span) + self.assertEqual(span["error"]["occurred"], True) + self.assertEqual(span["error"]["type"], "ConnectionError") + self.assertEqual(span["error"]["message"], "Failed to connect to API") + self.assertIn("attributes", span) + self.assertEqual(span["attributes"]["action.name"], "fetch_data") + + def test_llm_span_with_custom_attributes(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_trace", + activated_rails=[], + events=[], + trace=[ + LLMSpan( + name="llm_api_call", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="client", + provider_name="openai", + operation_name="chat.completions", + request_model="gpt-4", + temperature=0.7, + response_model="gpt-4-0613", + usage_input_tokens=50, + usage_output_tokens=100, + custom_attributes={"custom_key": "custom_value"}, + ) + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "LLMSpan") + self.assertEqual(span["span_kind"], "client") + self.assertIn("attributes", span) + self.assertEqual(span["attributes"]["gen_ai.request.model"], "gpt-4") + self.assertEqual(span["attributes"]["gen_ai.request.temperature"], 0.7) + self.assertEqual(span["attributes"]["gen_ai.response.model"], "gpt-4-0613") + self.assertEqual(span["attributes"]["gen_ai.usage.input_tokens"], 50) + self.assertEqual(span["attributes"]["gen_ai.usage.output_tokens"], 100) + self.assertIn("custom_attributes", span) + self.assertEqual(span["custom_attributes"]["custom_key"], "custom_value") + + def test_mixed_span_types(self): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_mixed", + activated_rails=[], + events=[], + trace=[ + InteractionSpan( + name="interaction", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=3.0, + duration=3.0, + span_kind="server", + request_model="gpt-4", + ), + RailSpan( + name="check_jailbreak", + span_id="span_2", + parent_id="span_1", + start_time=0.5, + end_time=1.0, + duration=0.5, + span_kind="internal", + rail_type="input", + rail_name="check_jailbreak", + rail_stop=False, + ), + SpanLegacy( + name="legacy_span", + span_id="span_3", + parent_id="span_1", + start_time=1.5, + end_time=2.5, + duration=1.0, + metrics={"tokens": 25}, + ), + ], + ) + adapter.transform(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + self.assertEqual(len(log_dict["spans"]), 3) + + self.assertEqual(log_dict["spans"][0]["span_type"], "InteractionSpan") + self.assertIn("span_kind", log_dict["spans"][0]) + 
self.assertIn("attributes", log_dict["spans"][0]) + + self.assertEqual(log_dict["spans"][1]["span_type"], "RailSpan") + self.assertEqual(log_dict["spans"][1]["parent_id"], "span_1") + + self.assertEqual(log_dict["spans"][2]["span_type"], "SpanLegacy") + self.assertIn("metrics", log_dict["spans"][2]) + self.assertNotIn("span_kind", log_dict["spans"][2]) + + @unittest.skipIf( + importlib.util.find_spec("aiofiles") is None, "aiofiles is not installed" + ) + def test_transform_async_with_otel_spans(self): + async def run_test(): + adapter = FileSystemAdapter(filepath=self.filepath) + interaction_log = InteractionLog( + id="test_async_otel", + activated_rails=[], + events=[], + trace=[ + InteractionSpan( + name="interaction", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=2.0, + duration=2.0, + span_kind="server", + request_model="gpt-4", + events=[ + SpanEvent( + name="test_event", + timestamp=1.0, + attributes={"key": "value"}, + ) + ], + ) + ], + ) + + await adapter.transform_async(interaction_log) + + with open(self.filepath, "r") as f: + content = f.read() + log_dict = json.loads(content.strip()) + self.assertEqual(log_dict["schema_version"], "2.0") + self.assertEqual(log_dict["trace_id"], "test_async_otel") + span = log_dict["spans"][0] + self.assertEqual(span["span_type"], "InteractionSpan") + self.assertIn("events", span) + self.assertEqual(len(span["events"]), 1) + + asyncio.run(run_test()) diff --git a/tests/tracing/adapters/test_opentelemetry.py b/tests/tracing/adapters/test_opentelemetry.py new file mode 100644 index 000000000..f6c1405dc --- /dev/null +++ b/tests/tracing/adapters/test_opentelemetry.py @@ -0,0 +1,464 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
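The adapter tests below pin down the timing contract: relative span times (seconds from trace start) must become unique absolute nanosecond timestamps on the emitted OTel spans. A rough sketch of the conversion the assertions encode (the helper name is hypothetical; the adapter's internal implementation may differ):

    import time
    from typing import Optional

    def to_absolute_ns(relative_seconds: float, base_time_ns: Optional[int] = None) -> int:
        # Anchor the trace at an absolute wall-clock time in nanoseconds, then
        # offset each span by its own relative time so every span gets a distinct timestamp.
        if base_time_ns is None:
            base_time_ns = time.time_ns()
        return base_time_ns + int(relative_seconds * 1_000_000_000)

    # e.g. with base 8000000000_000_000_000 and end_time=0.5 this yields
    # base + 500_000_000, matching the regression test's expected_end_times.
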
+ +import asyncio +import unittest +import warnings +from importlib.metadata import version +from unittest.mock import MagicMock, patch + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.trace import NoOpTracerProvider + +from nemoguardrails.tracing import ( + InteractionLog, + SpanEvent, + SpanLegacy, + SpanOpentelemetry, +) +from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter + + +class TestOpenTelemetryAdapter(unittest.TestCase): + def setUp(self): + # Set up a mock tracer provider for testing + self.mock_tracer_provider = MagicMock(spec=TracerProvider) + self.mock_tracer = MagicMock() + self.mock_tracer_provider.get_tracer.return_value = self.mock_tracer + + # Patch the global tracer provider + patcher_get_tracer_provider = patch("opentelemetry.trace.get_tracer_provider") + self.mock_get_tracer_provider = patcher_get_tracer_provider.start() + self.mock_get_tracer_provider.return_value = self.mock_tracer_provider + self.addCleanup(patcher_get_tracer_provider.stop) + + # Patch get_tracer to return our mock + patcher_get_tracer = patch("opentelemetry.trace.get_tracer") + self.mock_get_tracer = patcher_get_tracer.start() + self.mock_get_tracer.return_value = self.mock_tracer + self.addCleanup(patcher_get_tracer.stop) + + # Get the actual version for testing + self.actual_version = version("nemoguardrails") + + # Create the adapter - it should now use the global tracer + self.adapter = OpenTelemetryAdapter() + + def test_initialization(self): + """Test that the adapter initializes correctly using the global tracer.""" + + self.mock_get_tracer.assert_called_once_with( + "nemo_guardrails", + instrumenting_library_version=self.actual_version, + schema_url="https://opentelemetry.io/schemas/1.26.0", + ) + # Verify that the adapter has the mock tracer + self.assertEqual(self.adapter.tracer, self.mock_tracer) + + def test_transform(self): + """Test that transform creates spans correctly with proper timing.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.5, # historical timestamp + end_time=1234567891.5, # historical timestamp + duration=1.0, + metrics={"key": 123}, + ) + ], + ) + + self.adapter.transform(interaction_log) + + # Verify that start_span was called with proper timing (not start_as_current_span) + call_args = self.mock_tracer.start_span.call_args + self.assertEqual(call_args[0][0], "test_span") # name + self.assertEqual(call_args[1]["context"], None) # no parent context + # Verify start_time is a reasonable absolute timestamp in nanoseconds + start_time_ns = call_args[1]["start_time"] + self.assertIsInstance(start_time_ns, int) + self.assertGreater( + start_time_ns, 1e15 + ) # Should be realistic Unix timestamp in ns + + # V1 span metrics are set directly without prefix + mock_span.set_attribute.assert_any_call("key", 123) + # The adapter no longer sets intrinsic IDs as attributes + # (span_id, trace_id, duration are intrinsic to OTel spans) + + # Verify span was ended with correct end time + end_call_args = mock_span.end.call_args + end_time_ns = end_call_args[1]["end_time"] + self.assertIsInstance(end_time_ns, int) + self.assertGreater(end_time_ns, start_time_ns) # End should be after start + # Verify duration is approximately correct (allowing for conversion precision) + duration_ns = end_time_ns - 
start_time_ns + expected_duration_ns = int(1.0 * 1_000_000_000) # 1 second + self.assertAlmostEqual( + duration_ns, expected_duration_ns, delta=1000000 + ) # 1ms tolerance + + def test_transform_span_attributes_various_types(self): + """Test that different attribute types are handled correctly.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.0, + end_time=1234567891.0, + duration=1.0, + metrics={ + "int_key": 42, + "float_key": 3.14, + "str_key": 123, # Changed to a numeric value + "bool_key": 1, # Changed to a numeric value + }, + ) + ], + ) + + self.adapter.transform(interaction_log) + + mock_span.set_attribute.assert_any_call("int_key", 42) + mock_span.set_attribute.assert_any_call("float_key", 3.14) + mock_span.set_attribute.assert_any_call("str_key", 123) + mock_span.set_attribute.assert_any_call("bool_key", 1) + # The adapter no longer sets intrinsic IDs as attributes + # (span_id, trace_id, duration are intrinsic to OTel spans) + # Verify span was ended + mock_span.end.assert_called_once() + end_call_args = mock_span.end.call_args + self.assertIn("end_time", end_call_args[1]) + self.assertIsInstance(end_call_args[1]["end_time"], int) + + def test_transform_with_empty_trace(self): + """Test transform with empty trace.""" + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[], + ) + + self.adapter.transform(interaction_log) + + self.mock_tracer.start_span.assert_not_called() + + def test_transform_with_tracer_failure(self): + """Test transform when tracer fails.""" + self.mock_tracer.start_span.side_effect = Exception("Tracer failure") + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.0, + end_time=1234567891.0, + duration=1.0, + metrics={"key": 123}, + ) + ], + ) + + with self.assertRaises(Exception) as context: + self.adapter.transform(interaction_log) + + self.assertIn("Tracer failure", str(context.exception)) + + def test_transform_with_parent_child_relationships(self): + """Test that parent-child relationships are preserved with correct timing.""" + parent_mock_span = MagicMock() + child_mock_span = MagicMock() + self.mock_tracer.start_span.side_effect = [parent_mock_span, child_mock_span] + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="parent_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.0, + end_time=1234567892.0, + duration=2.0, + metrics={"parent_key": 1}, + ), + SpanLegacy( + name="child_span", + span_id="span_2", + parent_id="span_1", + start_time=1234567890.5, # child starts after parent + end_time=1234567891.5, # child ends before parent + duration=1.0, + metrics={"child_key": 2}, + ), + ], + ) + + with patch( + "opentelemetry.trace.set_span_in_context" + ) as mock_set_span_in_context: + mock_set_span_in_context.return_value = "parent_context" + + self.adapter.transform(interaction_log) + + # verify parent span created first with no context + self.assertEqual(self.mock_tracer.start_span.call_count, 2) + first_call = self.mock_tracer.start_span.call_args_list[0] + self.assertEqual(first_call[0][0], "parent_span") # name + self.assertEqual(first_call[1]["context"], 
None) # no parent context + # Verify start_time is a reasonable absolute timestamp + start_time_ns = first_call[1]["start_time"] + self.assertIsInstance(start_time_ns, int) + self.assertGreater( + start_time_ns, 1e15 + ) # Should be realistic Unix timestamp in ns + + # verify child span created with parent context + second_call = self.mock_tracer.start_span.call_args_list[1] + self.assertEqual(second_call[0][0], "child_span") # name + self.assertEqual( + second_call[1]["context"], "parent_context" + ) # parent context + # Verify child start_time is also a reasonable absolute timestamp + child_start_time_ns = second_call[1]["start_time"] + self.assertIsInstance(child_start_time_ns, int) + self.assertGreater( + child_start_time_ns, 1e15 + ) # Should be realistic Unix timestamp in ns + + # verify parent context was set correctly + mock_set_span_in_context.assert_called_once_with(parent_mock_span) + + # verify both spans ended with reasonable times + parent_mock_span.end.assert_called_once() + child_mock_span.end.assert_called_once() + parent_end_time = parent_mock_span.end.call_args[1]["end_time"] + child_end_time = child_mock_span.end.call_args[1]["end_time"] + self.assertIsInstance(parent_end_time, int) + self.assertIsInstance(child_end_time, int) + self.assertGreater(parent_end_time, 1e15) + self.assertGreater(child_end_time, 1e15) + + def test_transform_async(self): + """Test async transform functionality.""" + + async def run_test(): + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.5, + end_time=1234567891.5, + duration=1.0, + metrics={"key": 123}, + ) + ], + ) + + await self.adapter.transform_async(interaction_log) + + call_args = self.mock_tracer.start_span.call_args + self.assertEqual(call_args[0][0], "test_span") + self.assertEqual(call_args[1]["context"], None) + # Verify start_time is reasonable + self.assertIsInstance(call_args[1]["start_time"], int) + self.assertGreater(call_args[1]["start_time"], 1e15) + + mock_span.set_attribute.assert_any_call("key", 123) + # The adapter no longer sets intrinsic IDs as attributes + # (span_id, trace_id, duration are intrinsic to OTel spans) + mock_span.end.assert_called_once() + self.assertIn("end_time", mock_span.end.call_args[1]) + self.assertIsInstance(mock_span.end.call_args[1]["end_time"], int) + + asyncio.run(run_test()) + + def test_transform_async_with_empty_trace(self): + """Test async transform with empty trace.""" + + async def run_test(): + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[], + ) + + await self.adapter.transform_async(interaction_log) + + self.mock_tracer.start_span.assert_not_called() + + asyncio.run(run_test()) + + def test_transform_async_with_tracer_failure(self): + """Test async transform when tracer fails.""" + self.mock_tracer.start_span.side_effect = Exception("Tracer failure") + + async def run_test(): + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + SpanLegacy( + name="test_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.0, + end_time=1234567891.0, + duration=1.0, + metrics={"key": 123}, + ) + ], + ) + + with self.assertRaises(Exception) as context: + await self.adapter.transform_async(interaction_log) + + self.assertIn("Tracer failure", str(context.exception)) + + 
asyncio.run(run_test()) + + def test_no_op_tracer_provider_warning(self): + """Test that a warning is issued when NoOpTracerProvider is detected.""" + + with patch("opentelemetry.trace.get_tracer_provider") as mock_get_provider: + mock_get_provider.return_value = NoOpTracerProvider() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + _adapter = OpenTelemetryAdapter() + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, UserWarning)) + self.assertIn( + "No OpenTelemetry TracerProvider configured", str(w[0].message) + ) + self.assertIn("Traces will not be exported", str(w[0].message)) + + def test_no_warnings_with_proper_configuration(self): + """Test that no warnings are issued when properly configured.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # adapter without deprecated parameters + _adapter = OpenTelemetryAdapter(service_name="test_service") + + # no warnings is issued + self.assertEqual(len(w), 0) + + def test_v1_spans_unique_timestamps_regression(self): + """Regression test: V1 spans should have unique timestamps. + + This test ensures the timestamp bug is fixed for V1 spans. + With the bug, all spans would have the same end_time_ns. + """ + created_spans = [] + + def track_span(*args, **kwargs): + span = MagicMock() + created_spans.append(span) + return span + + self.mock_tracer.start_span.side_effect = track_span + + # Create multiple V1 spans with different end times + spans = [] + for i in range(5): + spans.append( + SpanLegacy( + name=f"v1_span_{i}", + span_id=str(i), + start_time=float(i * 0.1), # 0, 0.1, 0.2, 0.3, 0.4 + end_time=float(0.5 + i * 0.2), # 0.5, 0.7, 0.9, 1.1, 1.3 + duration=float(0.5 + i * 0.2 - i * 0.1), + metrics={"index": i}, + ) + ) + + interaction_log = InteractionLog( + id="v1_regression_test", + activated_rails=[], + events=[], + trace=spans, + ) + + # Use fixed time for predictable results + import time + + with patch("time.time_ns", return_value=8000000000_000_000_000): + self.adapter.transform(interaction_log) + + # Extract all end times + end_times = [] + for span_mock in created_spans: + end_time = span_mock.end.call_args[1]["end_time"] + end_times.append(end_time) + + # CRITICAL: All end times MUST be different + unique_end_times = set(end_times) + self.assertEqual( + len(unique_end_times), + 5, + f"REGRESSION DETECTED: All V1 span end times should be unique! " + f"Got {len(unique_end_times)} unique values from {end_times}. " + f"The timestamp calculation bug has regressed.", + ) + + # Verify expected values + base_ns = 8000000000_000_000_000 + expected_end_times = [ + base_ns + int(0.5 * 1_000_000_000), + base_ns + int(0.7 * 1_000_000_000), + base_ns + int(0.9 * 1_000_000_000), + base_ns + int(1.1 * 1_000_000_000), + base_ns + int(1.3 * 1_000_000_000), + ] + + self.assertEqual(end_times, expected_end_times) diff --git a/tests/tracing/adapters/test_opentelemetry_v2.py b/tests/tracing/adapters/test_opentelemetry_v2.py new file mode 100644 index 000000000..fae39b129 --- /dev/null +++ b/tests/tracing/adapters/test_opentelemetry_v2.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +from nemoguardrails.tracing import ( + InteractionLog, + SpanEvent, + SpanLegacy, + SpanOpentelemetry, +) +from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter +from nemoguardrails.tracing.spans import InteractionSpan, LLMSpan + + +class TestOpenTelemetryAdapterV2(unittest.TestCase): + """Test OpenTelemetryAdapter handling of v2 spans.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the tracer + self.mock_tracer = MagicMock() + self.mock_tracer_provider = MagicMock() + self.mock_tracer_provider.get_tracer.return_value = self.mock_tracer + + # Patch trace.get_tracer_provider + patcher = patch("opentelemetry.trace.get_tracer_provider") + self.mock_get_tracer_provider = patcher.start() + self.mock_get_tracer_provider.return_value = self.mock_tracer_provider + self.addCleanup(patcher.stop) + + self.adapter = OpenTelemetryAdapter() + + def test_v1_span_compatibility(self): + """Test that v1 spans still work correctly.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + v1_span = SpanLegacy( + name="test_v1", + span_id="v1_123", + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={"metric1": 42}, + ) + + interaction_log = InteractionLog( + id="test_v1_log", activated_rails=[], events=[], trace=[v1_span] + ) + + self.adapter.transform(interaction_log) + + # Verify span was created + self.mock_tracer.start_span.assert_called_once() + + # Verify metrics were set as attributes without prefix + mock_span.set_attribute.assert_any_call("metric1", 42) + + # Should not try to add events + mock_span.add_event.assert_not_called() + + def test_v2_span_attributes(self): + """Test that v2 span attributes are properly handled.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + from nemoguardrails.tracing.spans import LLMSpan + + v2_span = LLMSpan( + name="LLM: gpt-4", + span_id="v2_123", + start_time=0.0, + end_time=2.0, + duration=2.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + usage_total_tokens=150, + custom_attributes={ + "rail.decisions": ["continue", "allow"], # List attribute in custom + }, + ) + + interaction_log = InteractionLog( + id="test_v2_log", activated_rails=[], events=[], trace=[v2_span] + ) + + self.adapter.transform(interaction_log) + + # Verify OpenTelemetry attributes were set + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "openai") + mock_span.set_attribute.assert_any_call("gen_ai.request.model", "gpt-4") + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + + # Verify list was passed directly + # Note: OTel Python SDK automatically converts lists to strings + mock_span.set_attribute.assert_any_call("rail.decisions", ["continue", "allow"]) + + def test_v2_span_events(self): + """Test that v2 span events are properly added.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + events = [ + SpanEvent( + name="gen_ai.content.prompt", + timestamp=0.5, + 
body={"content": "What is AI?"}, + ), + SpanEvent( + name="gen_ai.content.completion", + timestamp=1.5, + body={"content": "AI stands for Artificial Intelligence..."}, + ), + SpanEvent( + name="gen_ai.choice", + timestamp=1.6, + body={"finish_reason": "stop", "index": 0}, + ), + ] + + v2_span = LLMSpan( + name="LLM: gpt-4", + span_id="v2_events", + start_time=0.0, + end_time=2.0, + duration=2.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + events=events, + ) + + interaction_log = InteractionLog( + id="test_events", activated_rails=[], events=[], trace=[v2_span] + ) + + self.adapter.transform(interaction_log) + + # Verify events were added + self.assertEqual(mock_span.add_event.call_count, 3) + + # Check first event (prompt) + call_args = mock_span.add_event.call_args_list[0] + self.assertEqual(call_args[1]["name"], "gen_ai.content.prompt") + # In new implementation, body content is merged directly into attributes + self.assertIn("content", call_args[1]["attributes"]) + self.assertEqual(call_args[1]["attributes"]["content"], "What is AI?") + + # Check choice event has finish reason + call_args = mock_span.add_event.call_args_list[2] + self.assertEqual(call_args[1]["name"], "gen_ai.choice") + # In new implementation, body fields are merged directly into attributes + self.assertIn("finish_reason", call_args[1]["attributes"]) + + def test_v2_span_metrics(self): + """Test that v2 span token usage is properly recorded as attributes.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + # In the new implementation, token usage is in attributes, not otel_metrics + v2_span = LLMSpan( + name="completion gpt-4", # Following new naming convention + span_id="v2_metrics", + start_time=0.0, + end_time=2.0, + duration=2.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="completion", + usage_input_tokens=50, + usage_output_tokens=100, + usage_total_tokens=150, + ) + + interaction_log = InteractionLog( + id="test_metrics", activated_rails=[], events=[], trace=[v2_span] + ) + + self.adapter.transform(interaction_log) + + # Verify token usage is recorded as standard attributes per OpenTelemetry GenAI conventions + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "openai") + mock_span.set_attribute.assert_any_call("gen_ai.request.model", "gpt-4") + + def test_mixed_v1_v2_spans(self): + """Test handling of mixed v1 and v2 spans in the same trace.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + v1_span = SpanLegacy( + name="action: check_input", + span_id="v1_span", + start_time=0.0, + end_time=0.5, + duration=0.5, + metrics={"action_total": 1}, # Will be set as action_total (no prefix) + ) + + v2_span = LLMSpan( + name="LLM: gpt-4", + span_id="v2_span", + parent_id="v1_span", + start_time=0.1, + end_time=0.4, + duration=0.3, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + events=[ + SpanEvent( + name="gen_ai.content.prompt", + timestamp=0.1, + body={"content": "test"}, + ) + ], + ) + + interaction_log = InteractionLog( + id="test_mixed", activated_rails=[], events=[], trace=[v1_span, v2_span] + ) + + 
self.adapter.transform(interaction_log) + + # Verify both spans were created + self.assertEqual(self.mock_tracer.start_span.call_count, 2) + + # Verify v2 span had events added (v1 should not) + # Only the second span should have events + event_calls = [call for call in mock_span.add_event.call_args_list] + self.assertEqual(len(event_calls), 1) # Only v2 span has events + + def test_event_content_passthrough(self): + """Test that event content is passed through as-is by the adapter.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + from nemoguardrails.tracing.spans import InteractionSpan + + long_content = "x" * 2000 + + v2_span = InteractionSpan( + name="test", + span_id="truncate_test", + start_time=0.0, + end_time=1.0, + duration=1.0, + events=[ + SpanEvent( + name="gen_ai.content.prompt", + timestamp=0.5, + body={"content": long_content}, + ) + ], + ) + + interaction_log = InteractionLog( + id="test_truncate", activated_rails=[], events=[], trace=[v2_span] + ) + + self.adapter.transform(interaction_log) + + # Verify content was passed through as-is + # The adapter is now a thin bridge and doesn't truncate + # Truncation should be done by the extractor if needed + call_args = mock_span.add_event.call_args_list[0] + content = call_args[1]["attributes"]["content"] + self.assertEqual(len(content), 2000) # Full content passed through + self.assertEqual(content, "x" * 2000) + + def test_unique_span_timestamps_regression_fix(self): + """Test that each span gets unique timestamps - regression test for timestamp bug. + + This test would FAIL with the old buggy logic where all end_time_ns were identical. + It PASSES with the correct logic where each span has unique timestamps. + """ + created_spans = [] + + def track_span(*args, **kwargs): + span = MagicMock() + created_spans.append(span) + return span + + self.mock_tracer.start_span.side_effect = track_span + + # Create multiple V2 spans with different timings + from nemoguardrails.tracing.spans import ActionSpan, RailSpan + + spans = [ + InteractionSpan( + name="span_1", + span_id="1", + start_time=0.0, # Starts at trace beginning + end_time=1.0, # Ends after 1 second + duration=1.0, + custom_attributes={"type": "first"}, + ), + RailSpan( + name="span_2", + span_id="2", + start_time=0.5, # Starts 0.5s after trace start + end_time=2.0, # Ends after 2 seconds + duration=1.5, + rail_type="input", + rail_name="test_rail", + custom_attributes={"type": "second"}, + ), + ActionSpan( + name="span_3", + span_id="3", + start_time=1.0, # Starts 1s after trace start + end_time=1.5, # Ends after 1.5 seconds + duration=0.5, + action_name="test_action", + custom_attributes={"type": "third"}, + ), + ] + + interaction_log = InteractionLog( + id="test_timestamps", + activated_rails=[], + events=[], + trace=spans, + ) + + # Use a fixed base time for predictable results + import time + + with unittest.mock.patch("time.time_ns", return_value=1700000000_000_000_000): + self.adapter.transform(interaction_log) + + # Verify that each span was created + self.assertEqual(len(created_spans), 3) + + # Extract the end times for each span + end_times = [] + for span_mock in created_spans: + end_call = span_mock.end.call_args + end_times.append(end_call[1]["end_time"]) + + # CRITICAL TEST: All end times should be DIFFERENT + # With the bug, all end_times would be identical (base_time_ns) + unique_end_times = set(end_times) + self.assertEqual( + len(unique_end_times), + 3, + f"End times should be unique but got: {end_times}. 
" + f"This indicates the timestamp calculation bug has regressed!", + ) + + # Verify correct absolute timestamps + base_ns = 1700000000_000_000_000 + expected_end_times = [ + base_ns + 1_000_000_000, # span_1 ends at 1s + base_ns + 2_000_000_000, # span_2 ends at 2s + base_ns + 1_500_000_000, # span_3 ends at 1.5s + ] + + self.assertEqual(end_times, expected_end_times) + + def test_multiple_interactions_different_base_times(self): + """Test that multiple interactions get different base times.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + span1 = InteractionSpan( + name="span1", + span_id="1", + start_time=0.0, + end_time=1.0, + duration=1.0, + custom_attributes={"interaction": "first"}, + ) + + span2 = InteractionSpan( + name="span2", + span_id="2", + start_time=0.0, + end_time=1.0, + duration=1.0, + custom_attributes={"interaction": "second"}, + ) + + log1 = InteractionLog(id="log1", activated_rails=[], events=[], trace=[span1]) + log2 = InteractionLog(id="log2", activated_rails=[], events=[], trace=[span2]) + + # First interaction + import time + + with unittest.mock.patch("time.time_ns", return_value=1000000000_000_000_000): + self.adapter.transform(log1) + + first_start = self.mock_tracer.start_span.call_args[1]["start_time"] + + # Reset mock + self.mock_tracer.start_span.reset_mock() + + # Second interaction (100ms later) + with unittest.mock.patch("time.time_ns", return_value=1000000100_000_000_000): + self.adapter.transform(log2) + + second_start = self.mock_tracer.start_span.call_args[1]["start_time"] + + # The two interactions should have different base times + self.assertNotEqual(first_start, second_start) + self.assertEqual( + second_start - first_start, 100_000_000_000 + ) # 100ms difference + + def test_uses_actual_interaction_start_time_from_rails(self): + """Test that adapter uses the actual start time from activated rails, not current time.""" + import time + + from nemoguardrails.rails.llm.options import ActivatedRail + + one_hour_ago = time.time() - 3600 + + rail = ActivatedRail( + type="input", + name="test_rail", + started_at=one_hour_ago, + finished_at=one_hour_ago + 2.0, + duration=2.0, + ) + + span = InteractionSpan( + name="test_span", + span_id="test_123", + start_time=0.0, + end_time=1.0, + duration=1.0, + operation_name="test", + service_name="test_service", + ) + + interaction_log = InteractionLog( + id="test_actual_time", activated_rails=[rail], events=[], trace=[span] + ) + + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + self.adapter.transform(interaction_log) + + call_args = self.mock_tracer.start_span.call_args + actual_start_time_ns = call_args[1]["start_time"] + + expected_start_time_ns = int(one_hour_ago * 1_000_000_000) + self.assertEqual( + actual_start_time_ns, + expected_start_time_ns, + "Should use the actual interaction start time from rails, not current time", + ) + + end_call = mock_span.end.call_args + actual_end_time_ns = end_call[1]["end_time"] + expected_end_time_ns = expected_start_time_ns + 1_000_000_000 + + self.assertEqual( + actual_end_time_ns, + expected_end_time_ns, + "End time should be calculated relative to the actual interaction start", + ) + + def test_fallback_when_no_rail_timestamp(self): + """Test that adapter falls back to current time when rails have no timestamp.""" + span = InteractionSpan( + name="test_span", + span_id="test_no_rails", + start_time=0.0, + end_time=1.0, + duration=1.0, + operation_name="test", + service_name="test_service", + ) + 
+ interaction_log = InteractionLog( + id="test_no_rails", activated_rails=[], events=[], trace=[span] + ) + + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + + with patch("time.time_ns", return_value=9999999999_000_000_000): + self.adapter.transform(interaction_log) + + call_args = self.mock_tracer.start_span.call_args + actual_start_time_ns = call_args[1]["start_time"] + + self.assertEqual( + actual_start_time_ns, + 9999999999_000_000_000, + "Should fall back to current time when no rail timestamps available", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tracing/spans/test_span_extractors.py b/tests/tracing/spans/test_span_extractors.py new file mode 100644 index 000000000..9c9c85c05 --- /dev/null +++ b/tests/tracing/spans/test_span_extractors.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.rails.llm.options import ActivatedRail, ExecutedAction +from nemoguardrails.tracing import ( + SpanExtractorV1, + SpanExtractorV2, + SpanLegacy, + create_span_extractor, +) +from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span + + +class TestSpanExtractors: + """Test span extraction for legacy and OpenTelemetry formats.""" + + @pytest.fixture + def test_data(self): + """Set up test data for span extraction.""" + llm_call = LLMCallInfo( + task="generate_user_intent", + prompt="What is the weather?", + completion="I cannot provide weather information.", + llm_model_name="gpt-4", + llm_provider_name="openai", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + started_at=time.time(), + finished_at=time.time() + 1.0, + duration=1.0, + ) + + action = ExecutedAction( + action_name="generate_user_intent", + action_params={"temperature": 0.7}, + llm_calls=[llm_call], + started_at=time.time(), + finished_at=time.time() + 1.5, + duration=1.5, + ) + + rail = ActivatedRail( + type="input", + name="check_jailbreak", + decisions=["continue"], + executed_actions=[action], + stop=False, + started_at=time.time(), + finished_at=time.time() + 2.0, + duration=2.0, + ) + + return [rail] + + def test_span_extractor_legacy_format(self, test_data): + """Test legacy format span extractor produces legacy spans.""" + extractor = SpanExtractorV1() + spans = extractor.extract_spans(test_data) + + assert len(spans) > 0 + + # All spans should be legacy format + for span in spans: + assert isinstance(span, SpanLegacy) + assert not hasattr(span, "attributes") + + span_names = [s.name for s in spans] + assert "interaction" in span_names + assert "rail: check_jailbreak" in span_names + assert "action: generate_user_intent" in span_names + assert "LLM: gpt-4" in span_names + + def test_span_extractor_opentelemetry_attributes(self, test_data): + """Test OpenTelemetry span extractor 
adds semantic convention attributes.""" + extractor = SpanExtractorV2() + spans = extractor.extract_spans(test_data) + + # All spans should be typed spans + for span in spans: + assert is_opentelemetry_span(span) + + # LLM spans follow OpenTelemetry convention: "{operation} {model}" + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert isinstance(llm_span, LLMSpan) + + assert llm_span.provider_name == "openai" + assert llm_span.request_model == "gpt-4" + assert llm_span.usage_input_tokens == 10 + + attributes = llm_span.to_otel_attributes() + assert "gen_ai.provider.name" in attributes + assert attributes["gen_ai.provider.name"] == "openai" + assert attributes["gen_ai.request.model"] == "gpt-4" + assert "gen_ai.usage.input_tokens" in attributes + assert attributes["gen_ai.usage.input_tokens"] == 10 + + def test_span_extractor_opentelemetry_events(self, test_data): + """Test OpenTelemetry span extractor adds events.""" + extractor = SpanExtractorV2(enable_content_capture=True) + spans = extractor.extract_spans(test_data) + + # LLM spans follow OpenTelemetry convention + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert len(llm_span.events) > 0 + + event_names = [e.name for e in llm_span.events] + # Currently uses deprecated content events (TODO: update to newer format) + assert "gen_ai.content.prompt" in event_names + assert "gen_ai.content.completion" in event_names + + # Check event content (only present when content capture is enabled) + user_message_event = next( + e for e in llm_span.events if e.name == "gen_ai.content.prompt" + ) + assert user_message_event.body["content"] == "What is the weather?" + + def test_span_extractor_opentelemetry_metrics(self, test_data): + """Test OpenTelemetry span extractor adds metrics as attributes.""" + extractor = SpanExtractorV2() + spans = extractor.extract_spans(test_data) + + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert isinstance(llm_span, LLMSpan) + + assert llm_span.usage_input_tokens == 10 + assert llm_span.usage_output_tokens == 20 + assert llm_span.usage_total_tokens == 30 + + attributes = llm_span.to_otel_attributes() + assert "gen_ai.usage.input_tokens" in attributes + assert "gen_ai.usage.output_tokens" in attributes + assert "gen_ai.usage.total_tokens" in attributes + + assert attributes["gen_ai.usage.input_tokens"] == 10 + assert attributes["gen_ai.usage.output_tokens"] == 20 + assert attributes["gen_ai.usage.total_tokens"] == 30 + assert attributes["gen_ai.provider.name"] == "openai" + + def test_span_extractor_conversation_events(self, test_data): + """Test OpenTelemetry span extractor extracts conversation events from internal events.""" + internal_events = [ + {"type": "UtteranceUserActionFinished", "final_transcript": "Hello bot"}, + {"type": "StartUtteranceBotAction", "script": "Hello! 
How can I help?"}, + {"type": "SystemMessage", "content": "You are a helpful assistant"}, + ] + + extractor = SpanExtractorV2(events=internal_events) + spans = extractor.extract_spans(test_data) + + interaction_span = next(s for s in spans if s.name == "guardrails.request") + assert len(interaction_span.events) > 0 + + event_names = [e.name for e in interaction_span.events] + assert "guardrails.utterance.user.finished" in event_names + assert "guardrails.utterance.bot.started" in event_names + + user_event = next( + e + for e in interaction_span.events + if e.name == "guardrails.utterance.user.finished" + ) + assert "type" in user_event.body + # Content not included by default (privacy) + assert "final_transcript" not in user_event.body + + +class TestSpanFormatConfiguration: + """Test span format configuration and factory.""" + + def test_create_span_extractor_legacy(self): + """Test creating legacy format span extractor.""" + extractor = create_span_extractor(span_format="legacy") + assert isinstance(extractor, SpanExtractorV1) + + def test_create_span_extractor_opentelemetry(self): + """Test creating OpenTelemetry format span extractor.""" + extractor = create_span_extractor(span_format="opentelemetry") + assert isinstance(extractor, SpanExtractorV2) + + def test_create_invalid_format_raises_error(self): + """Test invalid span format raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + create_span_extractor(span_format="invalid") + assert "Invalid span format" in str(exc_info.value) + + def test_opentelemetry_extractor_with_events(self): + """Test OpenTelemetry extractor can be created with events.""" + events = [{"type": "UserMessage", "text": "test"}] + extractor = create_span_extractor( + span_format="opentelemetry", events=events, enable_content_capture=False + ) + + assert isinstance(extractor, SpanExtractorV2) + assert extractor.internal_events == events + + def test_legacy_extractor_ignores_extra_params(self): + """Test legacy extractor ignores OpenTelemetry-specific parameters.""" + # Legacy extractor should ignore events and enable_content_capture + extractor = create_span_extractor( + span_format="legacy", events=[{"type": "test"}], enable_content_capture=True + ) + + assert isinstance(extractor, SpanExtractorV1) + # V1 extractor doesn't have these attributes + assert not hasattr(extractor, "internal_events") + assert not hasattr(extractor, "enable_content_capture") + + @pytest.mark.parametrize( + "format_str,expected_class", + [ + ("legacy", SpanExtractorV1), + ("LEGACY", SpanExtractorV1), + ("opentelemetry", SpanExtractorV2), + ("OPENTELEMETRY", SpanExtractorV2), + ("OpenTelemetry", SpanExtractorV2), + ], + ) + def test_case_insensitive_format(self, format_str, expected_class): + """Test that span format is case-insensitive.""" + extractor = create_span_extractor(span_format=format_str) + assert isinstance(extractor, expected_class) diff --git a/tests/tracing/spans/test_span_format_enum.py b/tests/tracing/spans/test_span_format_enum.py new file mode 100644 index 000000000..174bbd9fb --- /dev/null +++ b/tests/tracing/spans/test_span_format_enum.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any + +import pytest + +from nemoguardrails.tracing.span_format import ( + SpanFormat, + SpanFormatType, + validate_span_format, +) + + +class TestSpanFormat: + """Test cases for SpanFormat enum.""" + + def test_enum_values(self): + """Test that enum has expected values.""" + assert SpanFormat.LEGACY.value == "legacy" + assert SpanFormat.OPENTELEMETRY.value == "opentelemetry" + + def test_string_inheritance(self): + """Test that SpanFormat inherits from str.""" + assert isinstance(SpanFormat.LEGACY, str) + assert isinstance(SpanFormat.OPENTELEMETRY, str) + + def test_string_comparison(self): + """Test direct string comparison works.""" + assert SpanFormat.LEGACY == "legacy" + assert SpanFormat.OPENTELEMETRY == "opentelemetry" + assert SpanFormat.LEGACY != "opentelemetry" + + def test_json_serialization(self): + """Test that enum values can be JSON serialized.""" + data = {"format": SpanFormat.LEGACY} + json_str = json.dumps(data) + assert '"format": "legacy"' in json_str + + parsed = json.loads(json_str) + assert parsed["format"] == "legacy" + + def test_str_method(self): + """Test __str__ method returns value.""" + assert str(SpanFormat.LEGACY) == "legacy" + assert str(SpanFormat.OPENTELEMETRY) == "opentelemetry" + + def test_from_string_valid_values(self): + """Test from_string with valid values.""" + assert SpanFormat.from_string("legacy") == SpanFormat.LEGACY + assert SpanFormat.from_string("opentelemetry") == SpanFormat.OPENTELEMETRY + + assert SpanFormat.from_string("LEGACY") == SpanFormat.LEGACY + assert SpanFormat.from_string("OpenTelemetry") == SpanFormat.OPENTELEMETRY + assert SpanFormat.from_string("OPENTELEMETRY") == SpanFormat.OPENTELEMETRY + + def test_from_string_invalid_value(self): + """Test from_string with invalid value raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + SpanFormat.from_string("invalid") + + error_msg = str(exc_info.value) + assert "Invalid span format: 'invalid'" in error_msg + assert "Valid formats are: legacy, opentelemetry" in error_msg + + def test_from_string_empty_value(self): + """Test from_string with empty string raises ValueError.""" + with pytest.raises(ValueError): + SpanFormat.from_string("") + + def test_from_string_none_value(self): + """Test from_string with None raises appropriate error.""" + with pytest.raises(AttributeError): + SpanFormat.from_string(None) + + +class TestValidateSpanFormat: + """Test cases for validate_span_format function.""" + + def test_validate_span_format_enum(self): + """Test validation with SpanFormat enum.""" + result = validate_span_format(SpanFormat.LEGACY) + assert result == SpanFormat.LEGACY + assert isinstance(result, SpanFormat) + + result = validate_span_format(SpanFormat.OPENTELEMETRY) + assert result == SpanFormat.OPENTELEMETRY + assert isinstance(result, SpanFormat) + + def test_validate_span_format_string(self): + """Test validation with string values.""" + result = validate_span_format("legacy") + assert result == SpanFormat.LEGACY + assert isinstance(result, SpanFormat) + + result = validate_span_format("opentelemetry") + assert result 
== SpanFormat.OPENTELEMETRY + assert isinstance(result, SpanFormat) + + result = validate_span_format("LEGACY") + assert result == SpanFormat.LEGACY + + def test_validate_span_format_invalid_string(self): + """Test validation with invalid string raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_span_format("invalid") + + error_msg = str(exc_info.value) + assert "Invalid span format: 'invalid'" in error_msg + + def test_validate_span_format_invalid_type(self): + """Test validation with invalid type raises TypeError.""" + with pytest.raises(TypeError) as exc_info: + validate_span_format(123) + + error_msg = str(exc_info.value) + assert "Span format must be a string or SpanFormat enum" in error_msg + assert "got " in error_msg + + def test_validate_span_format_none(self): + """Test validation with None raises TypeError.""" + with pytest.raises(TypeError): + validate_span_format(None) + + def test_validate_span_format_list(self): + """Test validation with list raises TypeError.""" + with pytest.raises(TypeError): + validate_span_format(["legacy"]) + + def test_validate_span_format_dict(self): + """Test validation with dict raises TypeError.""" + with pytest.raises(TypeError): + validate_span_format({"format": "legacy"}) + + +class TestSpanFormatType: + """Test cases for SpanFormatType type alias.""" + + def test_type_alias_accepts_enum(self): + """Test that type alias accepts SpanFormat enum.""" + + def test_function(format_type: SpanFormatType) -> SpanFormat: + return validate_span_format(format_type) + + result = test_function(SpanFormat.LEGACY) + assert result == SpanFormat.LEGACY + + def test_type_alias_accepts_string(self): + """Test that type alias accepts string values.""" + + def test_function(format_type: SpanFormatType) -> SpanFormat: + return validate_span_format(format_type) + + result = test_function("legacy") + assert result == SpanFormat.LEGACY + + result = test_function("opentelemetry") + assert result == SpanFormat.OPENTELEMETRY + + +class TestSpanFormatIntegration: + """Integration tests for span format functionality.""" + + def test_config_usage_pattern(self): + """Test typical configuration usage pattern.""" + config_value = "opentelemetry" + format_enum = validate_span_format(config_value) + + if format_enum == SpanFormat.OPENTELEMETRY: + assert True # Expected path + else: + pytest.fail("Unexpected format") + + def test_function_parameter_pattern(self): + """Test typical function parameter usage pattern.""" + + def process_spans(span_format: SpanFormatType = SpanFormat.LEGACY): + validated_format = validate_span_format(span_format) + return validated_format + + result = process_spans() + assert result == SpanFormat.LEGACY + + result = process_spans("opentelemetry") + assert result == SpanFormat.OPENTELEMETRY + + result = process_spans(SpanFormat.OPENTELEMETRY) + assert result == SpanFormat.OPENTELEMETRY + + def test_all_enum_values_have_tests(self): + """Ensure all enum values are tested.""" + tested_values = {"legacy", "opentelemetry"} + actual_values = {format_enum.value for format_enum in SpanFormat} + assert ( + tested_values == actual_values + ), f"Missing tests for: {actual_values - tested_values}" diff --git a/tests/tracing/spans/test_span_models_and_extractors.py b/tests/tracing/spans/test_span_models_and_extractors.py new file mode 100644 index 000000000..ed6bebec3 --- /dev/null +++ b/tests/tracing/spans/test_span_models_and_extractors.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.rails.llm.options import ActivatedRail, ExecutedAction +from nemoguardrails.tracing import ( + SpanEvent, + SpanExtractorV1, + SpanExtractorV2, + SpanLegacy, + SpanOpentelemetry, + create_span_extractor, +) +from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span + + +class TestSpanModels: + def test_span_v1_creation(self): + span = SpanLegacy( + span_id="test-123", + name="test span", + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={"test_metric": 42}, + ) + + assert span.span_id == "test-123" + assert span.name == "test span" + assert span.duration == 1.0 + assert span.metrics["test_metric"] == 42 + + assert not hasattr(span, "attributes") + assert not hasattr(span, "events") + assert not hasattr(span, "otel_metrics") + + def test_span_v2_creation(self): + """Test creating a v2 span - typed spans with explicit fields.""" + from nemoguardrails.tracing.spans import LLMSpan + + event = SpanEvent( + name="gen_ai.content.prompt", timestamp=0.5, body={"content": "test prompt"} + ) + + # V2 spans are typed with explicit fields + span = LLMSpan( + span_id="test-456", + name="generate_user_intent gpt-4", + start_time=0.0, + end_time=2.0, + duration=2.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + usage_input_tokens=10, + usage_output_tokens=20, + usage_total_tokens=30, + events=[event], + ) + + assert span.span_id == "test-456" + assert span.provider_name == "openai" + assert span.request_model == "gpt-4" + assert span.usage_input_tokens == 10 + assert len(span.events) == 1 + assert span.events[0].name == "gen_ai.content.prompt" + + # Check that to_otel_attributes works + attributes = span.to_otel_attributes() + assert attributes["gen_ai.provider.name"] == "openai" + assert attributes["gen_ai.request.model"] == "gpt-4" + + assert not isinstance(span, SpanLegacy) + # Python 3.9 compatibility: cannot use isinstance with Union types + # SpanOpentelemetry is TypedSpan which is a Union, so check the actual type + assert isinstance(span, LLMSpan) + + # Note: V1 and V2 spans are now fundamentally different types + # V1 is a simple span model, V2 is typed spans with explicit fields + # No conversion between them is needed or supported + + +class TestSpanExtractors: + @pytest.fixture + def test_data(self): + llm_call = LLMCallInfo( + task="generate_user_intent", + prompt="What is the weather?", + completion="I cannot provide weather information.", + llm_model_name="gpt-4", + llm_provider_name="openai", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + started_at=time.time(), + finished_at=time.time() + 1.0, + duration=1.0, + ) + + action = ExecutedAction( + action_name="generate_user_intent", + action_params={"temperature": 0.7}, + llm_calls=[llm_call], + started_at=time.time(), + 
finished_at=time.time() + 1.5, + duration=1.5, + ) + + rail = ActivatedRail( + type="input", + name="check_jailbreak", + decisions=["continue"], + executed_actions=[action], + stop=False, + started_at=time.time(), + finished_at=time.time() + 2.0, + duration=2.0, + ) + + activated_rails = [rail] + return { + "activated_rails": activated_rails, + "llm_call": llm_call, + "action": action, + "rail": rail, + } + + def test_span_extractor_v1(self, test_data): + extractor = SpanExtractorV1() + spans = extractor.extract_spans(test_data["activated_rails"]) + + assert len(spans) > 0 + + for span in spans: + assert isinstance(span, SpanLegacy) + assert not hasattr(span, "attributes") + + span_names = [s.name for s in spans] + assert "interaction" in span_names + assert "rail: check_jailbreak" in span_names + assert "action: generate_user_intent" in span_names + assert "LLM: gpt-4" in span_names + + def test_span_extractor_v2_attributes(self, test_data): + extractor = SpanExtractorV2() + spans = extractor.extract_spans(test_data["activated_rails"]) + + for span in spans: + # Now we expect typed spans + assert is_opentelemetry_span(span) + + # In V2, LLM spans follow OpenTelemetry convention: "{operation} {model}" + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert isinstance(llm_span, LLMSpan) + + # For typed spans, check the fields directly + assert llm_span.provider_name == "openai" + assert llm_span.request_model == "gpt-4" + assert llm_span.usage_input_tokens == 10 + + # Also verify attributes conversion works + attributes = llm_span.to_otel_attributes() + assert "gen_ai.provider.name" in attributes + assert attributes["gen_ai.provider.name"] == "openai" + assert attributes["gen_ai.request.model"] == "gpt-4" + assert "gen_ai.usage.input_tokens" in attributes + assert attributes["gen_ai.usage.input_tokens"] == 10 + + def test_span_extractor_v2_events(self, test_data): + extractor = SpanExtractorV2(enable_content_capture=True) + spans = extractor.extract_spans(test_data["activated_rails"]) + + # In V2, LLM spans follow OpenTelemetry convention: "{operation} {model}" + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert len(llm_span.events) > 0 + + event_names = [e.name for e in llm_span.events] + # V2 currently uses deprecated content events for simplicity (TODO: update to newer format) + assert "gen_ai.content.prompt" in event_names + assert "gen_ai.content.completion" in event_names + + # Check user message event content (only present when content capture is enabled) + user_message_event = next( + e for e in llm_span.events if e.name == "gen_ai.content.prompt" + ) + assert user_message_event.body["content"] == "What is the weather?" 
+ + def test_span_extractor_v2_metrics(self, test_data): + extractor = SpanExtractorV2() + spans = extractor.extract_spans(test_data["activated_rails"]) + + # In V2, LLM spans follow OpenTelemetry convention: "{operation} {model}" + llm_span = next(s for s in spans if s.name == "generate_user_intent gpt-4") + assert isinstance(llm_span, LLMSpan) + + # Check typed fields + assert llm_span.usage_input_tokens == 10 + assert llm_span.usage_output_tokens == 20 + assert llm_span.usage_total_tokens == 30 + assert llm_span.provider_name == "openai" + + # Verify attributes conversion + attributes = llm_span.to_otel_attributes() + assert attributes["gen_ai.usage.total_tokens"] == 30 + assert attributes["gen_ai.provider.name"] == "openai" + + def test_span_extractor_v2_conversation_events(self, test_data): + internal_events = [ + {"type": "UtteranceUserActionFinished", "final_transcript": "Hello bot"}, + {"type": "StartUtteranceBotAction", "script": "Hello! How can I help?"}, + {"type": "SystemMessage", "content": "You are a helpful assistant"}, + ] + + # Test with content excluded by default (privacy compliant) + extractor = SpanExtractorV2(events=internal_events) + spans = extractor.extract_spans(test_data["activated_rails"]) + + interaction_span = next(s for s in spans if s.name == "guardrails.request") + assert len(interaction_span.events) > 0 + + event_names = [e.name for e in interaction_span.events] + # These are guardrails internal events, not OTel GenAI events + assert "guardrails.utterance.user.finished" in event_names + assert "guardrails.utterance.bot.started" in event_names + + user_event = next( + e + for e in interaction_span.events + if e.name == "guardrails.utterance.user.finished" + ) + # By default, content is NOT included (privacy compliant) + assert "type" in user_event.body + assert "final_transcript" not in user_event.body + + +class TestSpanVersionConfiguration: + def test_create_span_extractor_legacy(self): + extractor = create_span_extractor(span_format="legacy") + assert isinstance(extractor, SpanExtractorV1) + + def test_create_span_extractor_opentelemetry(self): + extractor = create_span_extractor(span_format="opentelemetry") + assert isinstance(extractor, SpanExtractorV2) + + def test_create_invalid_format(self): + with pytest.raises(ValueError, match="Invalid span format"): + create_span_extractor(span_format="invalid") + + def test_opentelemetry_extractor_with_events(self): + events = [{"type": "UserMessage", "text": "test"}] + extractor = create_span_extractor( + span_format="opentelemetry", events=events, enable_content_capture=False + ) + + assert isinstance(extractor, SpanExtractorV2) + assert extractor.internal_events == events diff --git a/tests/tracing/spans/test_span_v2_integration.py b/tests/tracing/spans/test_span_v2_integration.py new file mode 100644 index 000000000..e82becc91 --- /dev/null +++ b/tests/tracing/spans/test_span_v2_integration.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.tracing import SpanOpentelemetry, create_span_extractor +from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span +from tests.utils import FakeLLM + + +@pytest.fixture +def v2_config(): + return RailsConfig.from_content( + yaml_content=""" +models: + - type: main + engine: openai + model: gpt-4 + +tracing: + enabled: true + span_format: opentelemetry + adapters: [] +""" + ) + + +@pytest.fixture +def v1_config(): + return RailsConfig.from_content( + yaml_content=""" +models: + - type: main + engine: openai + model: gpt-4 + +tracing: + enabled: true + span_format: legacy + adapters: [] +""" + ) + + +@pytest.fixture +def default_config(): + return RailsConfig.from_content( + yaml_content=""" +models: + - type: main + engine: openai + model: gpt-4 + +tracing: + enabled: true + adapters: [] +""" + ) + + +def test_span_v2_configuration(v2_config): + assert v2_config.tracing.span_format == "opentelemetry" + + llm = FakeLLM(responses=["Hello! I'm here to help."]) + _rails = LLMRails(config=v2_config, llm=llm) + + extractor = create_span_extractor(span_format="opentelemetry") + assert extractor.__class__.__name__ == "SpanExtractorV2" + + +@pytest.mark.asyncio +async def test_v2_spans_generated_with_events(v2_config): + llm = FakeLLM(responses=[" express greeting", "Hello! How can I help you today?"]) + + rails = LLMRails(config=v2_config, llm=llm) + + options = GenerationOptions( + log={"activated_rails": True, "internal_events": True, "llm_calls": True} + ) + + response = await rails.generate_async( + messages=[{"role": "user", "content": "Hello!"}], options=options + ) + + assert response.response is not None + assert response.log is not None + + from nemoguardrails.tracing.interaction_types import ( + InteractionOutput, + extract_interaction_log, + ) + + interaction_output = InteractionOutput( + id="test", input="Hello!", output=response.response + ) + + interaction_log = extract_interaction_log(interaction_output, response.log) + + assert len(interaction_log.trace) > 0 + + for span in interaction_log.trace: + assert is_opentelemetry_span(span) + + interaction_span = next( + (s for s in interaction_log.trace if s.name == "guardrails.request"), None + ) + assert interaction_span is not None + + llm_spans = [s for s in interaction_log.trace if isinstance(s, LLMSpan)] + assert len(llm_spans) > 0 + + for llm_span in llm_spans: + assert hasattr(llm_span, "provider_name") + assert hasattr(llm_span, "request_model") + + attrs = llm_span.to_otel_attributes() + assert "gen_ai.provider.name" in attrs + assert "gen_ai.request.model" in attrs + + assert hasattr(llm_span, "events") + assert len(llm_span.events) > 0 + + +def test_v1_backward_compatibility(v1_config): + assert v1_config.tracing.span_format == "legacy" + + llm = FakeLLM(responses=["Hello!"]) + _rails = LLMRails(config=v1_config, llm=llm) + + extractor = create_span_extractor(span_format="legacy") + assert extractor.__class__.__name__ == "SpanExtractorV1" + + +def test_default_span_format(default_config): + assert default_config.tracing.span_format == "opentelemetry" + + +def test_span_format_configuration_direct(): + extractor_legacy = create_span_extractor(span_format="legacy") + assert extractor_legacy.__class__.__name__ == "SpanExtractorV1" + + extractor_otel = 
create_span_extractor(span_format="opentelemetry") + assert extractor_otel.__class__.__name__ == "SpanExtractorV2" + + with pytest.raises(ValueError) as exc_info: + create_span_extractor(span_format="invalid") + assert "Invalid span format" in str(exc_info.value) diff --git a/tests/tracing/spans/test_span_v2_otel_semantics.py b/tests/tracing/spans/test_span_v2_otel_semantics.py new file mode 100644 index 000000000..41a1fb781 --- /dev/null +++ b/tests/tracing/spans/test_span_v2_otel_semantics.py @@ -0,0 +1,604 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SpanOpentelemetry with complete OpenTelemetry semantic convention attributes.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from nemoguardrails.rails.llm.options import ActivatedRail, ExecutedAction, LLMCallInfo +from nemoguardrails.tracing.constants import ( + CommonAttributes, + EventNames, + GenAIAttributes, + GuardrailsAttributes, + OperationNames, + SpanKind, + SpanNames, +) +from nemoguardrails.tracing.span_extractors import SpanExtractorV2 +from nemoguardrails.tracing.spans import ActionSpan, InteractionSpan, LLMSpan, RailSpan + + +class TestSpanOpentelemetryOTelAttributes: + """Test that SpanV2 contains complete OTel semantic convention attributes.""" + + def test_interaction_span_has_complete_attributes(self): + """Test that interaction span has all required OTel attributes.""" + rail = ActivatedRail( + type="input", + name="check_jailbreak", + started_at=1.0, + finished_at=2.0, + duration=1.0, + executed_actions=[], + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + + interaction_span = next(s for s in spans if s.parent_id is None) + assert isinstance(interaction_span, InteractionSpan) + + attrs = interaction_span.to_otel_attributes() + assert attrs[CommonAttributes.SPAN_KIND] == SpanKind.SERVER + assert attrs[GenAIAttributes.GEN_AI_OPERATION_NAME] == OperationNames.GUARDRAILS + assert "service.name" in attrs + assert interaction_span.name == SpanNames.GUARDRAILS_REQUEST + + assert GenAIAttributes.GEN_AI_PROVIDER_NAME not in attrs + assert GenAIAttributes.GEN_AI_SYSTEM not in attrs + + def test_rail_span_has_complete_attributes(self): + """Test that rail spans have all required attributes.""" + rail = ActivatedRail( + type="input", + name="check_jailbreak", + started_at=1.0, + finished_at=2.0, + duration=1.0, + stop=True, + decisions=["blocked"], + executed_actions=[], + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + + rail_span = next(s for s in spans if s.name == SpanNames.GUARDRAILS_RAIL) + assert isinstance(rail_span, RailSpan) + + attrs = rail_span.to_otel_attributes() + assert attrs[CommonAttributes.SPAN_KIND] == SpanKind.INTERNAL + assert attrs[GuardrailsAttributes.RAIL_TYPE] == "input" + assert attrs[GuardrailsAttributes.RAIL_NAME] == "check_jailbreak" + assert 
attrs[GuardrailsAttributes.RAIL_STOP] is True + assert attrs[GuardrailsAttributes.RAIL_DECISIONS] == ["blocked"] + + def test_llm_span_has_complete_attributes(self): + """Test that LLM spans have all required OTel GenAI attributes.""" + llm_call = LLMCallInfo( + task="generate", + llm_model_name="gpt-4", + llm_provider_name="openai", + prompt="Hello, world!", + completion="Hi there!", + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + started_at=1.5, + finished_at=1.8, + duration=0.3, + raw_response={ + "id": "chatcmpl-123", + "choices": [{"finish_reason": "stop"}], + "temperature": 0.7, + "max_tokens": 100, + "top_p": 0.9, + }, + ) + + action = ExecutedAction( + action_name="generate_user_intent", + started_at=1.0, + finished_at=2.0, + duration=1.0, + llm_calls=[llm_call], + ) + + rail = ActivatedRail( + type="dialog", + name="generate_next_step", + started_at=1.0, + finished_at=2.0, + duration=1.0, + executed_actions=[action], + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + + llm_span = next(s for s in spans if "gpt-4" in s.name) + assert isinstance(llm_span, LLMSpan) + + attrs = llm_span.to_otel_attributes() + assert attrs[CommonAttributes.SPAN_KIND] == SpanKind.CLIENT + assert attrs[GenAIAttributes.GEN_AI_PROVIDER_NAME] == "openai" + assert attrs[GenAIAttributes.GEN_AI_REQUEST_MODEL] == "gpt-4" + assert attrs[GenAIAttributes.GEN_AI_RESPONSE_MODEL] == "gpt-4" + assert attrs[GenAIAttributes.GEN_AI_OPERATION_NAME] == "generate" + assert attrs[GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS] == 10 + assert attrs[GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS] == 5 + assert attrs[GenAIAttributes.GEN_AI_USAGE_TOTAL_TOKENS] == 15 + assert attrs[GenAIAttributes.GEN_AI_RESPONSE_ID] == "chatcmpl-123" + assert attrs[GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS] == ["stop"] + assert attrs[GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE] == 0.7 + assert attrs[GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS] == 100 + assert attrs[GenAIAttributes.GEN_AI_REQUEST_TOP_P] == 0.9 + + assert GenAIAttributes.GEN_AI_SYSTEM not in attrs + + def test_llm_span_events_are_complete(self): + """Test that LLM span events follow OTel GenAI conventions.""" + llm_call = LLMCallInfo( + task="chat", + llm_model_name="claude-3", + prompt="What is the weather?", + completion="I cannot access real-time weather data.", + started_at=1.5, + finished_at=1.8, + duration=0.3, + raw_response={"stop_reason": "end_turn"}, + ) + + action = ExecutedAction( + action_name="llm_generate", + started_at=1.0, + finished_at=2.0, + duration=1.0, + llm_calls=[llm_call], + ) + + rail = ActivatedRail( + type="dialog", + name="chat", + started_at=1.0, + finished_at=2.0, + duration=1.0, + executed_actions=[action], + ) + + extractor = SpanExtractorV2(enable_content_capture=True) + spans = extractor.extract_spans([rail]) + + llm_span = next(s for s in spans if "claude" in s.name) + assert isinstance(llm_span, LLMSpan) + + assert len(llm_span.events) >= 2 # at least user and assistant messages + + user_event = next( + e for e in llm_span.events if e.name == EventNames.GEN_AI_CONTENT_PROMPT + ) + assert user_event.body["content"] == "What is the weather?" + + assistant_event = next( + e for e in llm_span.events if e.name == EventNames.GEN_AI_CONTENT_COMPLETION + ) + assert ( + assistant_event.body["content"] == "I cannot access real-time weather data." 
+ ) + + finish_events = [e for e in llm_span.events if e.name == "gen_ai.choice.finish"] + if finish_events: + finish_event = finish_events[0] + assert "finish_reason" in finish_event.body + assert "index" in finish_event.body + + def test_action_span_with_error_attributes(self): + """Test that action spans include error information when present.""" + # TODO: Figure out how errors are properly attached to actions + action = ExecutedAction( + action_name="failed_action", + started_at=1.0, + finished_at=2.0, + duration=1.0, + llm_calls=[], + ) + # skip setting error for now since ExecutedAction doesn't have that field + # action.error = ValueError("Something went wrong") + + rail = ActivatedRail( + type="input", + name="check_input", + started_at=1.0, + finished_at=2.0, + duration=1.0, + executed_actions=[action], + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + + action_span = next(s for s in spans if s.name == SpanNames.GUARDRAILS_ACTION) + assert isinstance(action_span, ActionSpan) + + attrs = action_span.to_otel_attributes() + # since we didn't set an error, these shouldn't be present + assert "error" not in attrs or attrs["error"] is None + assert "error.type" not in attrs + assert "error.message" not in attrs + + def test_span_names_are_low_cardinality(self): + """Test that span names follow low-cardinality convention.""" + rails = [ + ActivatedRail( + type="input", + name=f"rail_{i}", + started_at=float(i), + finished_at=float(i + 1), + duration=1.0, + executed_actions=[ + ExecutedAction( + action_name=f"action_{i}", + started_at=float(i), + finished_at=float(i + 1), + duration=1.0, + llm_calls=[ + LLMCallInfo( + task=f"task_{i}", + llm_model_name=f"model_{i}", + started_at=float(i), + finished_at=float(i + 1), + duration=1.0, + ) + ], + ) + ], + ) + for i in range(3) + ] + + extractor = SpanExtractorV2() + all_spans = [] + for rail in rails: + spans = extractor.extract_spans([rail]) + all_spans.extend(spans) + + expected_patterns = { + SpanNames.GUARDRAILS_REQUEST, + SpanNames.GUARDRAILS_RAIL, + SpanNames.GUARDRAILS_ACTION, + } + + for span in all_spans: + if not any(f"model_{i}" in span.name for i in range(3)): + assert span.name in expected_patterns + + rail_spans = [s for s in all_spans if s.name == SpanNames.GUARDRAILS_RAIL] + rail_names = { + s.to_otel_attributes()[GuardrailsAttributes.RAIL_NAME] for s in rail_spans + } + assert len(rail_names) == 3 + + def test_no_semantic_logic_in_adapter(self): + """Verify adapter is just an API bridge by checking it doesn't modify attributes.""" + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + from nemoguardrails.tracing import InteractionLog + from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter + + # create a mock exporter to capture spans + class MockExporter: + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return 0 + + def shutdown(self): + pass + + # setup OTel + exporter = MockExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + + # create adapter + adapter = OpenTelemetryAdapter() + + # create a simple rail + rail = ActivatedRail( + type="input", + name="test_rail", + started_at=1.0, + finished_at=2.0, + duration=1.0, + executed_actions=[], + ) + + # extract spans with V2 extractor + extractor = SpanExtractorV2() + spans = 
extractor.extract_spans([rail]) + + # create interaction log + interaction_log = InteractionLog( + id="test-trace-123", + activated_rails=[rail], + trace=spans, + ) + + # transform through adapter + adapter.transform(interaction_log) + + assert len(exporter.spans) > 0 + + for otel_span in exporter.spans: + attrs = dict(otel_span.attributes) + + if otel_span.name == SpanNames.GUARDRAILS_REQUEST: + assert GenAIAttributes.GEN_AI_OPERATION_NAME in attrs + assert GenAIAttributes.GEN_AI_PROVIDER_NAME not in attrs + assert GenAIAttributes.GEN_AI_SYSTEM not in attrs + + +class TestOpenTelemetryAdapterAsTheBridge: + """Test that OpenTelemetryAdapter is a pure API bridge.""" + + def test_adapter_handles_span_kind_mapping(self): + """Test that adapter correctly maps span.kind string to OTel enum.""" + from opentelemetry.trace import SpanKind as OTelSpanKind + + from nemoguardrails.tracing import InteractionLog + from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter + + # mock provider to capture span creation + created_spans = [] + + class MockTracer: + def start_span(self, name, context=None, start_time=None, kind=None): + created_spans.append({"name": name, "kind": kind}) + return MagicMock() + + provider = MagicMock() + provider.get_tracer = MagicMock(return_value=MockTracer()) + + with patch("opentelemetry.trace.get_tracer_provider", return_value=provider): + adapter = OpenTelemetryAdapter() + + spans = [ + InteractionSpan( + span_id="1", + name="server_span", + start_time=0.0, + end_time=1.0, + duration=1.0, + ), + LLMSpan( + span_id="2", + name="client_span", + start_time=0.0, + end_time=1.0, + duration=1.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + ), + RailSpan( + span_id="3", + name="internal_span", + start_time=0.0, + end_time=1.0, + duration=1.0, + rail_type="input", + rail_name="test_rail", + ), + ] + + interaction_log = InteractionLog( + id="test-123", + activated_rails=[], + trace=spans, + ) + + adapter.transform(interaction_log) + + assert created_spans[0]["kind"] == OTelSpanKind.SERVER + assert created_spans[1]["kind"] == OTelSpanKind.CLIENT + assert created_spans[2]["kind"] == OTelSpanKind.INTERNAL + + +class TestContentPrivacy: + """Test that sensitive content is handled according to OTel GenAI conventions.""" + + def test_content_not_included_by_default(self): + """Test that content is NOT included by default per OTel spec.""" + events = [ + {"type": "UserMessage", "text": "My SSN is 123-45-6789"}, + { + "type": "UtteranceBotActionFinished", + "final_script": "I cannot process SSN", + }, + ] + extractor = SpanExtractorV2(events=events, enable_content_capture=False) + + activated_rail = ActivatedRail( + type="action", + name="generate", + started_at=0.0, + finished_at=1.0, + duration=1.0, + executed_actions=[ + ExecutedAction( + action_name="generate", + started_at=0.0, + finished_at=1.0, + duration=1.0, + llm_calls=[ + LLMCallInfo( + task="general", + prompt="User sensitive prompt", + completion="Bot response with PII", + duration=0.5, + total_tokens=100, + prompt_tokens=50, + completion_tokens=50, + raw_response={"model": "gpt-3.5-turbo"}, + ) + ], + ) + ], + ) + + spans = extractor.extract_spans([activated_rail]) + + llm_span = next((s for s in spans if isinstance(s, LLMSpan)), None) + assert llm_span is not None + + for event in llm_span.events: + if event.name in ["gen_ai.content.prompt", "gen_ai.content.completion"]: + assert event.body == {} + assert "content" not in event.body 
+ + def test_content_included_when_explicitly_enabled(self): + """Test that content IS included when explicitly enabled.""" + # Create extractor with enable_content_capture=True + events = [ + {"type": "UserMessage", "text": "Hello bot"}, + {"type": "UtteranceBotActionFinished", "final_script": "Hello user"}, + ] + extractor = SpanExtractorV2(events=events, enable_content_capture=True) + + activated_rail = ActivatedRail( + type="action", + name="generate", + started_at=0.0, + finished_at=1.0, + duration=1.0, + executed_actions=[ + ExecutedAction( + action_name="generate", + started_at=0.0, + finished_at=1.0, + duration=1.0, + llm_calls=[ + LLMCallInfo( + task="general", + prompt="Test prompt", + completion="Test response", + duration=0.5, + total_tokens=100, + prompt_tokens=50, + completion_tokens=50, + raw_response={"model": "gpt-3.5-turbo"}, + ) + ], + ) + ], + ) + + spans = extractor.extract_spans([activated_rail]) + + llm_span = next((s for s in spans if isinstance(s, LLMSpan)), None) + assert llm_span is not None + + prompt_event = next( + (e for e in llm_span.events if e.name == "gen_ai.content.prompt"), None + ) + assert prompt_event is not None + assert prompt_event.body.get("content") == "Test prompt" + + completion_event = next( + (e for e in llm_span.events if e.name == "gen_ai.content.completion"), None + ) + assert completion_event is not None + assert completion_event.body.get("content") == "Test response" + + def test_conversation_events_respect_privacy_setting(self): + """Test that guardrails internal events respect the privacy setting.""" + events = [ + {"type": "UserMessage", "text": "Private message"}, + { + "type": "UtteranceBotActionFinished", + "final_script": "Private response", + "is_success": True, + }, + ] + + extractor_no_content = SpanExtractorV2( + events=events, enable_content_capture=False + ) + activated_rail = ActivatedRail( + type="dialog", name="main", started_at=0.0, finished_at=1.0, duration=1.0 + ) + + spans = extractor_no_content.extract_spans([activated_rail]) + interaction_span = spans[0] # First span is the interaction span + + user_event = next( + (e for e in interaction_span.events if e.name == "guardrails.user_message"), + None, + ) + assert user_event is not None + assert user_event.body["type"] == "UserMessage" + assert "content" not in user_event.body + + bot_event = next( + ( + e + for e in interaction_span.events + if e.name == "guardrails.utterance.bot.finished" + ), + None, + ) + assert bot_event is not None + assert bot_event.body["type"] == "UtteranceBotActionFinished" + assert bot_event.body["is_success"] == True + assert "content" not in bot_event.body # Content excluded + + extractor_with_content = SpanExtractorV2( + events=events, enable_content_capture=True + ) + spans = extractor_with_content.extract_spans([activated_rail]) + interaction_span = spans[0] + + user_event = next( + (e for e in interaction_span.events if e.name == "guardrails.user_message"), + None, + ) + assert user_event is not None + assert user_event.body.get("content") == "Private message" + + bot_event = next( + ( + e + for e in interaction_span.events + if e.name == "guardrails.utterance.bot.finished" + ), + None, + ) + assert bot_event is not None + assert bot_event.body.get("content") == "Private response" + assert bot_event.body.get("type") == "UtteranceBotActionFinished" + assert bot_event.body.get("is_success") == True + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/tracing/spans/test_spans.py 
b/tests/tracing/spans/test_spans.py new file mode 100644 index 000000000..2cf218bc0 --- /dev/null +++ b/tests/tracing/spans/test_spans.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from nemoguardrails.tracing import SpanEvent, SpanLegacy +from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span + + +class TestSpanModels: + """Test the span models for legacy and OpenTelemetry formats.""" + + def test_span_legacy_creation(self): + """Test creating a legacy format span.""" + span = SpanLegacy( + span_id="test-123", + name="test span", + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={"test_metric": 42}, + ) + + assert span.span_id == "test-123" + assert span.name == "test span" + assert span.duration == 1.0 + assert span.metrics["test_metric"] == 42 + + # Legacy spans don't have OpenTelemetry attributes + assert not hasattr(span, "attributes") + assert not hasattr(span, "events") + assert not hasattr(span, "otel_metrics") + + def test_span_opentelemetry_creation(self): + """Test creating an OpenTelemetry format span - typed spans with explicit fields.""" + event = SpanEvent( + name="gen_ai.content.prompt", timestamp=0.5, body={"content": "test prompt"} + ) + + # OpenTelemetry spans are typed with explicit fields + span = LLMSpan( + span_id="test-456", + name="generate_user_intent gpt-4", + start_time=0.0, + end_time=2.0, + duration=2.0, + provider_name="openai", + request_model="gpt-4", + response_model="gpt-4", + operation_name="chat.completions", + usage_input_tokens=10, + usage_output_tokens=20, + usage_total_tokens=30, + events=[event], + ) + + assert span.span_id == "test-456" + assert span.provider_name == "openai" + assert span.request_model == "gpt-4" + assert span.usage_input_tokens == 10 + assert len(span.events) == 1 + assert span.events[0].name == "gen_ai.content.prompt" + + attributes = span.to_otel_attributes() + assert attributes["gen_ai.provider.name"] == "openai" + assert attributes["gen_ai.request.model"] == "gpt-4" + + def test_span_legacy_model_is_simple(self): + """Test that Legacy span model is a simple span without OpenTelemetry features.""" + legacy_span = SpanLegacy( + span_id="legacy-123", + name="test", + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={"metric": 1}, + ) + + assert isinstance(legacy_span, SpanLegacy) + assert legacy_span.span_id == "legacy-123" + assert legacy_span.metrics["metric"] == 1 + + # Legacy spans don't have OpenTelemetry attributes or events + assert not hasattr(legacy_span, "attributes") + assert not hasattr(legacy_span, "events") diff --git a/tests/tracing/test_span_formatting.py b/tests/tracing/test_span_formatting.py new file mode 100644 index 000000000..2e8cbff1d --- /dev/null +++ b/tests/tracing/test_span_formatting.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails.tracing.span_formatting import ( + extract_span_attributes, + format_span_for_filesystem, +) +from nemoguardrails.tracing.spans import ( + ActionSpan, + InteractionSpan, + LLMSpan, + RailSpan, + SpanEvent, + SpanLegacy, +) + + +class TestFormatSpanForFilesystem: + def test_format_legacy_span_with_metrics(self): + span = SpanLegacy( + name="llm_call", + span_id="span_1", + parent_id="parent_1", + start_time=0.5, + end_time=1.5, + duration=1.0, + metrics={"input_tokens": 10, "output_tokens": 20}, + ) + + result = format_span_for_filesystem(span) + + assert result["name"] == "llm_call" + assert result["span_id"] == "span_1" + assert result["parent_id"] == "parent_1" + assert result["start_time"] == 0.5 + assert result["end_time"] == 1.5 + assert result["duration"] == 1.0 + assert result["span_type"] == "SpanLegacy" + assert result["metrics"] == {"input_tokens": 10, "output_tokens": 20} + assert "span_kind" not in result + assert "attributes" not in result + + def test_format_legacy_span_without_metrics(self): + span = SpanLegacy( + name="test", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={}, + ) + + result = format_span_for_filesystem(span) + + assert result["span_type"] == "SpanLegacy" + assert "metrics" not in result + + def test_format_interaction_span(self): + span = InteractionSpan( + name="interaction", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=2.0, + duration=2.0, + span_kind="server", + request_model="gpt-4", + ) + + result = format_span_for_filesystem(span) + + assert result["span_type"] == "InteractionSpan" + assert result["span_kind"] == "server" + assert "attributes" in result + assert result["attributes"]["gen_ai.operation.name"] == "guardrails" + + def test_format_span_with_events(self): + events = [ + SpanEvent( + name="test_event", + timestamp=0.5, + attributes={"key": "value"}, + ) + ] + span = InteractionSpan( + name="interaction", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="server", + events=events, + ) + + result = format_span_for_filesystem(span) + + assert "events" in result + assert len(result["events"]) == 1 + assert result["events"][0]["name"] == "test_event" + assert result["events"][0]["timestamp"] == 0.5 + assert result["events"][0]["attributes"] == {"key": "value"} + + def test_format_span_with_error(self): + span = ActionSpan( + name="action", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="internal", + action_name="fetch", + error=True, + error_type="ConnectionError", + error_message="Failed", + ) + + result = format_span_for_filesystem(span) + + assert "error" in result + assert result["error"]["occurred"] is True + assert result["error"]["type"] == "ConnectionError" + assert result["error"]["message"] == "Failed" + + def 
test_format_span_with_custom_attributes(self): + span = LLMSpan( + name="llm", + span_id="span_1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="client", + provider_name="openai", + operation_name="chat.completions", + request_model="gpt-4", + response_model="gpt-4", + custom_attributes={"custom": "value"}, + ) + + result = format_span_for_filesystem(span) + + assert "custom_attributes" in result + assert result["custom_attributes"] == {"custom": "value"} + + def test_format_unknown_span_type_raises(self): + class UnknownSpan: + def __init__(self): + self.name = "unknown" + + with pytest.raises(ValueError) as exc_info: + format_span_for_filesystem(UnknownSpan()) + + assert "Unknown span type: UnknownSpan" in str(exc_info.value) + assert "Only SpanLegacy and typed spans are supported" in str(exc_info.value) + + +class TestExtractSpanAttributes: + def test_extract_from_legacy_span_with_metrics(self): + span = SpanLegacy( + name="test", + span_id="1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={"tokens": 100, "latency": 0.5}, + ) + + attrs = extract_span_attributes(span) + + assert attrs == {"tokens": 100, "latency": 0.5} + assert attrs is not span.metrics + + def test_extract_from_legacy_span_without_metrics(self): + span = SpanLegacy( + name="test", + span_id="1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + metrics={}, + ) + + attrs = extract_span_attributes(span) + + assert attrs == {} + + def test_extract_from_interaction_span(self): + span = InteractionSpan( + name="interaction", + span_id="1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="server", + request_model="gpt-4", + ) + + attrs = extract_span_attributes(span) + + assert "span.kind" in attrs + assert attrs["span.kind"] == "server" + assert "gen_ai.operation.name" in attrs + + def test_extract_from_rail_span(self): + span = RailSpan( + name="check", + span_id="1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="internal", + rail_type="input", + rail_name="check_jailbreak", + rail_stop=False, + ) + + attrs = extract_span_attributes(span) + + assert attrs["rail.type"] == "input" + assert attrs["rail.name"] == "check_jailbreak" + assert attrs["rail.stop"] is False + + def test_extract_from_llm_span(self): + span = LLMSpan( + name="llm", + span_id="1", + parent_id=None, + start_time=0.0, + end_time=1.0, + duration=1.0, + span_kind="client", + provider_name="openai", + operation_name="chat.completions", + request_model="gpt-4", + response_model="gpt-4", + temperature=0.7, + usage_input_tokens=50, + usage_output_tokens=100, + ) + + attrs = extract_span_attributes(span) + + assert attrs["gen_ai.request.model"] == "gpt-4" + assert attrs["gen_ai.request.temperature"] == 0.7 + assert attrs["gen_ai.usage.input_tokens"] == 50 + assert attrs["gen_ai.usage.output_tokens"] == 100 + + def test_extract_unknown_span_type_raises(self): + class UnknownSpan: + pass + + with pytest.raises(ValueError) as exc_info: + extract_span_attributes(UnknownSpan()) + + assert "Unknown span type: UnknownSpan" in str(exc_info.value) diff --git a/tests/test_tracing.py b/tests/tracing/test_tracing.py similarity index 100% rename from tests/test_tracing.py rename to tests/tracing/test_tracing.py