diff --git a/nemoguardrails/tracing/adapters/opentelemetry.py b/nemoguardrails/tracing/adapters/opentelemetry.py index 6044b3cfe..f30db6410 100644 --- a/nemoguardrails/tracing/adapters/opentelemetry.py +++ b/nemoguardrails/tracing/adapters/opentelemetry.py @@ -55,7 +55,7 @@ import warnings from importlib.metadata import version -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Type if TYPE_CHECKING: from nemoguardrails.tracing import InteractionLog @@ -203,17 +203,22 @@ def _create_span( spans, trace_id, ): - with self.tracer.start_as_current_span( + start_time_ns = int(span_data.start_time * 1_000_000_000) + end_time_ns = int(span_data.end_time * 1_000_000_000) + + span = self.tracer.start_span( span_data.name, context=parent_context, - ) as span: - for key, value in span_data.metrics.items(): - span.set_attribute(key, value) + start_time=start_time_ns, + ) + + for key, value in span_data.metrics.items(): + span.set_attribute(key, value) + + span.set_attribute("span_id", span_data.span_id) + span.set_attribute("trace_id", trace_id) + span.set_attribute("duration", span_data.duration) - span.set_attribute("span_id", span_data.span_id) - span.set_attribute("trace_id", trace_id) - span.set_attribute("start_time", span_data.start_time) - span.set_attribute("end_time", span_data.end_time) - span.set_attribute("duration", span_data.duration) + spans[span_data.span_id] = span - spans[span_data.span_id] = span + span.end(end_time=end_time_ns) diff --git a/tests/test_opentelemetry_timing_behavior.py b/tests/test_opentelemetry_timing_behavior.py new file mode 100644 index 000000000..eb958578a --- /dev/null +++ b/tests/test_opentelemetry_timing_behavior.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import List + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +from nemoguardrails.eval.models import Span +from nemoguardrails.tracing import InteractionLog +from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter + + +class InMemorySpanExporter: + """Simple in-memory span exporter to capture spans for testing.""" + + def __init__(self): + self.spans: List = [] + + def export(self, spans): + self.spans.extend(spans) + return 0 # Success + + def shutdown(self): + pass + + +class TestOpenTelemetryTimingBehavior: + """ + Test that verifies OpenTelemetry spans are created with correct timestamps. + + This test focuses on the ACTUAL timing behavior, not implementation details. + It will fail with the old broken code (retrospective timing) and pass with + the fixed code (historical timing). + """ + + def setup_method(self): + self.exporter = InMemorySpanExporter() + self.tracer_provider = TracerProvider() + self.tracer_provider.add_span_processor(SimpleSpanProcessor(self.exporter)) + + trace._TRACER_PROVIDER = None + trace.set_tracer_provider(self.tracer_provider) + + self.adapter = OpenTelemetryAdapter() + + def teardown_method(self): + # Clean up - reset to no-op + trace._TRACER_PROVIDER = None + trace.set_tracer_provider(trace.NoOpTracerProvider()) + + def test_spans_use_historical_timestamps_not_current_time(self): + """ + Test that spans are created with historical timestamps from span_data, + not with the current time when transform() is called. + + This test will: + - FAIL with old broken code (uses current time) + - PASS with fixed code (uses historical time) + """ + historical_start = 1234567890.5 # January 1, 2009 + historical_end = 1234567892.0 # 1.5 seconds later + + interaction_log = InteractionLog( + id="timing_test", + activated_rails=[], + events=[], + trace=[ + Span( + name="historical_operation", + span_id="span_1", + parent_id=None, + start_time=historical_start, + end_time=historical_end, + duration=1.5, + metrics={"test_metric": 42}, + ) + ], + ) + + current_time_before = time.time() + + self.adapter.transform(interaction_log) + + current_time_after = time.time() + + assert len(self.exporter.spans) == 1 + captured_span = self.exporter.spans[0] + + actual_start_time = captured_span.start_time / 1_000_000_000 + actual_end_time = captured_span.end_time / 1_000_000_000 + + assert ( + abs(actual_start_time - historical_start) < 0.001 + ), f"Span start time ({actual_start_time}) should match historical time ({historical_start})" + + assert ( + abs(actual_end_time - historical_end) < 0.001 + ), f"Span end time ({actual_end_time}) should match historical time ({historical_end})" + + time_diff_start = abs(actual_start_time - current_time_before) + time_diff_end = abs(actual_end_time - current_time_after) + + assert time_diff_start > 1000000, ( + f"Span start time should be very different from current time. " + f"Difference: {time_diff_start} seconds. This suggests the old bug is present." + ) + + assert time_diff_end > 1000000, ( + f"Span end time should be very different from current time. " + f"Difference: {time_diff_end} seconds. This suggests the old bug is present." + ) + actual_duration = actual_end_time - actual_start_time + expected_duration = historical_end - historical_start + assert ( + abs(actual_duration - expected_duration) < 0.001 + ), f"Span duration should be {expected_duration}s, got {actual_duration}s" + + assert captured_span.name == "historical_operation" + assert captured_span.attributes.get("test_metric") == 42 + assert captured_span.attributes.get("span_id") == "span_1" + assert captured_span.attributes.get("trace_id") == "timing_test" diff --git a/tests/test_tracing_adapters_opentelemetry.py b/tests/test_tracing_adapters_opentelemetry.py index ee1a5a667..9a9a4e68b 100644 --- a/tests/test_tracing_adapters_opentelemetry.py +++ b/tests/test_tracing_adapters_opentelemetry.py @@ -21,21 +21,14 @@ # TODO: check to see if we can add it as a dependency # but now we try to import opentelemetry and set a flag if it's not available -try: - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.trace import NoOpTracerProvider - - from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter - - OPENTELEMETRY_AVAILABLE = True -except ImportError: - OPENTELEMETRY_AVAILABLE = False +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.trace import NoOpTracerProvider from nemoguardrails.eval.models import Span from nemoguardrails.tracing import InteractionLog +from nemoguardrails.tracing.adapters.opentelemetry import OpenTelemetryAdapter -@unittest.skipIf(not OPENTELEMETRY_AVAILABLE, "opentelemetry is not available") class TestOpenTelemetryAdapter(unittest.TestCase): def setUp(self): # Set up a mock tracer provider for testing @@ -73,7 +66,10 @@ def test_initialization(self): self.assertEqual(self.adapter.tracer, self.mock_tracer) def test_transform(self): - """Test that transform creates spans correctly.""" + """Test that transform creates spans correctly with proper timing.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + interaction_log = InteractionLog( id="test_id", activated_rails=[], @@ -83,8 +79,8 @@ def test_transform(self): name="test_span", span_id="span_1", parent_id=None, - start_time=0.0, - end_time=1.0, + start_time=1234567890.5, # historical timestamp + end_time=1234567891.5, # historical timestamp duration=1.0, metrics={"key": 123}, ) @@ -93,27 +89,28 @@ def test_transform(self): self.adapter.transform(interaction_log) - # Verify that start_as_current_span was called - self.mock_tracer.start_as_current_span.assert_called_once_with( + # Verify that start_span was called with proper timing (not start_as_current_span) + self.mock_tracer.start_span.assert_called_once_with( "test_span", context=None, + start_time=1234567890500000000, # Converted to nanoseconds ) - # We retrieve the mock span instance here - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) + mock_span.set_attribute.assert_any_call("key", 123) + mock_span.set_attribute.assert_any_call("span_id", "span_1") + mock_span.set_attribute.assert_any_call("trace_id", "test_id") + mock_span.set_attribute.assert_any_call("duration", 1.0) - # Verify span attributes were set - span_instance.set_attribute.assert_any_call("key", 123) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) + # Verify span was ended with correct end time + mock_span.end.assert_called_once_with( + end_time=1234567891500000000 + ) # Converted to nanoseconds def test_transform_span_attributes_various_types(self): """Test that different attribute types are handled correctly.""" + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + interaction_log = InteractionLog( id="test_id", activated_rails=[], @@ -123,8 +120,8 @@ def test_transform_span_attributes_various_types(self): name="test_span", span_id="span_1", parent_id=None, - start_time=0.0, - end_time=1.0, + start_time=1234567890.0, + end_time=1234567891.0, duration=1.0, metrics={ "int_key": 42, @@ -138,19 +135,14 @@ def test_transform_span_attributes_various_types(self): self.adapter.transform(interaction_log) - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) - - span_instance.set_attribute.assert_any_call("int_key", 42) - span_instance.set_attribute.assert_any_call("float_key", 3.14) - span_instance.set_attribute.assert_any_call("str_key", 123) - span_instance.set_attribute.assert_any_call("bool_key", 1) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) + mock_span.set_attribute.assert_any_call("int_key", 42) + mock_span.set_attribute.assert_any_call("float_key", 3.14) + mock_span.set_attribute.assert_any_call("str_key", 123) + mock_span.set_attribute.assert_any_call("bool_key", 1) + mock_span.set_attribute.assert_any_call("span_id", "span_1") + mock_span.set_attribute.assert_any_call("trace_id", "test_id") + mock_span.set_attribute.assert_any_call("duration", 1.0) + mock_span.end.assert_called_once_with(end_time=1234567891000000000) def test_transform_with_empty_trace(self): """Test transform with empty trace.""" @@ -163,11 +155,11 @@ def test_transform_with_empty_trace(self): self.adapter.transform(interaction_log) - self.mock_tracer.start_as_current_span.assert_not_called() + self.mock_tracer.start_span.assert_not_called() def test_transform_with_tracer_failure(self): """Test transform when tracer fails.""" - self.mock_tracer.start_as_current_span.side_effect = Exception("Tracer failure") + self.mock_tracer.start_span.side_effect = Exception("Tracer failure") interaction_log = InteractionLog( id="test_id", @@ -178,8 +170,8 @@ def test_transform_with_tracer_failure(self): name="test_span", span_id="span_1", parent_id=None, - start_time=0.0, - end_time=1.0, + start_time=1234567890.0, + end_time=1234567891.0, duration=1.0, metrics={"key": 123}, ) @@ -191,10 +183,78 @@ def test_transform_with_tracer_failure(self): self.assertIn("Tracer failure", str(context.exception)) + def test_transform_with_parent_child_relationships(self): + """Test that parent-child relationships are preserved with correct timing.""" + parent_mock_span = MagicMock() + child_mock_span = MagicMock() + self.mock_tracer.start_span.side_effect = [parent_mock_span, child_mock_span] + + interaction_log = InteractionLog( + id="test_id", + activated_rails=[], + events=[], + trace=[ + Span( + name="parent_span", + span_id="span_1", + parent_id=None, + start_time=1234567890.0, + end_time=1234567892.0, + duration=2.0, + metrics={"parent_key": 1}, + ), + Span( + name="child_span", + span_id="span_2", + parent_id="span_1", + start_time=1234567890.5, # child starts after parent + end_time=1234567891.5, # child ends before parent + duration=1.0, + metrics={"child_key": 2}, + ), + ], + ) + + with patch( + "opentelemetry.trace.set_span_in_context" + ) as mock_set_span_in_context: + mock_set_span_in_context.return_value = "parent_context" + + self.adapter.transform(interaction_log) + + # verify parent span created first with no context + self.assertEqual(self.mock_tracer.start_span.call_count, 2) + first_call = self.mock_tracer.start_span.call_args_list[0] + self.assertEqual(first_call[0][0], "parent_span") # name + self.assertEqual(first_call[1]["context"], None) # no parent context + self.assertEqual( + first_call[1]["start_time"], 1234567890000000000 + ) # nanoseconds + + # verify child span created with parent context + second_call = self.mock_tracer.start_span.call_args_list[1] + self.assertEqual(second_call[0][0], "child_span") # name + self.assertEqual( + second_call[1]["context"], "parent_context" + ) # parent context + self.assertEqual( + second_call[1]["start_time"], 1234567890500000000 + ) # nanoseconds + + # verify parent context was set correctly + mock_set_span_in_context.assert_called_once_with(parent_mock_span) + + # verify both spans ended with correct times + parent_mock_span.end.assert_called_once_with(end_time=1234567892000000000) + child_mock_span.end.assert_called_once_with(end_time=1234567891500000000) + def test_transform_async(self): """Test async transform functionality.""" async def run_test(): + mock_span = MagicMock() + self.mock_tracer.start_span.return_value = mock_span + interaction_log = InteractionLog( id="test_id", activated_rails=[], @@ -204,8 +264,8 @@ async def run_test(): name="test_span", span_id="span_1", parent_id=None, - start_time=0.0, - end_time=1.0, + start_time=1234567890.5, + end_time=1234567891.5, duration=1.0, metrics={"key": 123}, ) @@ -214,22 +274,17 @@ async def run_test(): await self.adapter.transform_async(interaction_log) - self.mock_tracer.start_as_current_span.assert_called_once_with( + self.mock_tracer.start_span.assert_called_once_with( "test_span", context=None, + start_time=1234567890500000000, ) - # We retrieve the mock span instance here - span_instance = ( - self.mock_tracer.start_as_current_span.return_value.__enter__.return_value - ) - - span_instance.set_attribute.assert_any_call("key", 123) - span_instance.set_attribute.assert_any_call("span_id", "span_1") - span_instance.set_attribute.assert_any_call("trace_id", "test_id") - span_instance.set_attribute.assert_any_call("start_time", 0.0) - span_instance.set_attribute.assert_any_call("end_time", 1.0) - span_instance.set_attribute.assert_any_call("duration", 1.0) + mock_span.set_attribute.assert_any_call("key", 123) + mock_span.set_attribute.assert_any_call("span_id", "span_1") + mock_span.set_attribute.assert_any_call("trace_id", "test_id") + mock_span.set_attribute.assert_any_call("duration", 1.0) + mock_span.end.assert_called_once_with(end_time=1234567891500000000) asyncio.run(run_test()) @@ -246,13 +301,13 @@ async def run_test(): await self.adapter.transform_async(interaction_log) - self.mock_tracer.start_as_current_span.assert_not_called() + self.mock_tracer.start_span.assert_not_called() asyncio.run(run_test()) def test_transform_async_with_tracer_failure(self): """Test async transform when tracer fails.""" - self.mock_tracer.start_as_current_span.side_effect = Exception("Tracer failure") + self.mock_tracer.start_span.side_effect = Exception("Tracer failure") async def run_test(): interaction_log = InteractionLog( @@ -264,8 +319,8 @@ async def run_test(): name="test_span", span_id="span_1", parent_id=None, - start_time=0.0, - end_time=1.0, + start_time=1234567890.0, + end_time=1234567891.0, duration=1.0, metrics={"key": 123}, )