diff --git a/src/ragas/metrics/_tool_call_accuracy.py b/src/ragas/metrics/_tool_call_accuracy.py
index d04cf5c9d..1cc9c4daf 100644
--- a/src/ragas/metrics/_tool_call_accuracy.py
+++ b/src/ragas/metrics/_tool_call_accuracy.py
@@ -15,6 +15,27 @@

 @dataclass
 class ToolCallAccuracy(MultiTurnMetric):
+    """
+    Tool Call Accuracy metric measures how accurately an LLM agent makes tool calls
+    compared to reference tool calls.
+
+    The metric evaluates two aspects:
+    1. Sequence alignment: Whether predicted and reference tool calls match exactly in order
+    2. Argument accuracy: How well tool call arguments match between predicted and reference
+
+    Score calculation:
+    - If sequences don't align exactly: score = 0
+    - If sequences align: score = (average argument accuracy) * sequence_alignment_factor
+    - Length mismatches result in warnings and proportional penalty
+
+    Edge cases:
+    - No predicted tool calls: returns 0.0
+    - Length mismatch: compares only the overlapping portion and applies coverage penalty
+    - Missing arguments: contributes 0 to the argument score for that tool call
+
+    The final score is always between 0.0 and 1.0.
+    """
+
     name: str = "tool_call_accuracy"
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {
@@ -55,15 +76,7 @@ async def _get_arg_score(
     def is_sequence_aligned(
         self, pred_sequence: t.List[str], ref_sequence: t.List[str]
     ) -> bool:
-        if len(pred_sequence) != len(ref_sequence):
-            return False
-        ref_index = 0  # Index to track position in reference sequence
-        for pred in pred_sequence:
-            if ref_index < len(ref_sequence) and pred == ref_sequence[ref_index]:
-                ref_index += 1
-            if ref_index == len(ref_sequence):
-                return True
-        return False
+        return pred_sequence == ref_sequence

     async def _multi_turn_ascore(
         self, sample: MultiTurnSample, callbacks: Callbacks
@@ -77,30 +90,53 @@ async def _multi_turn_ascore(
             if isinstance(item, AIMessage) and item.tool_calls is not None:
                 pred_tool_calls.extend(item.tool_calls)

+        reference_tool_calls = sample.reference_tool_calls
+
+        # Handle edge cases
+        if not pred_tool_calls and not reference_tool_calls:
+            # Both empty - perfect match
+            return 1.0
+        elif not pred_tool_calls:
+            warnings.warn("No tool calls found in the user input")
+            return 0.0
+        elif not reference_tool_calls:
+            # Reference is empty but we have predictions - this is typically an error in test data
+            warnings.warn("Reference tool calls are empty but predictions exist")
+            return 0.0
+
+        # Check for length mismatch and warn user
+        if len(pred_tool_calls) != len(reference_tool_calls):
+            warnings.warn(
+                f"Length mismatch: predicted tool calls ({len(pred_tool_calls)}) "
+                f"vs reference tool calls ({len(reference_tool_calls)}). "
+                f"Only the first {min(len(pred_tool_calls), len(reference_tool_calls))} "
+                f"tool calls will be compared."
+            )
+
         tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
-        tool_call_ref_sequence = [
-            tool_call.name for tool_call in sample.reference_tool_calls
-        ]
+        tool_call_ref_sequence = [tool_call.name for tool_call in reference_tool_calls]

         sequence_aligned = int(
             self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence)
         )

-        if pred_tool_calls:
-            score = 0.0
-            reference_tool_calls = sample.reference_tool_calls
-            for ref_tool_call in reference_tool_calls:
-                for pred_tool_call in pred_tool_calls:
-                    if ref_tool_call.name == pred_tool_call.name:
-                        arg_score = await self._get_arg_score(
-                            pred_tool_call.args, ref_tool_call.args, callbacks
-                        )
-                        score += arg_score
-
-            score /= len(reference_tool_calls)
-        else:
-            warnings.warn("No tool calls found in the user input")
-            return 0.0
+        # Calculate score based on paired tool calls
+        score = 0.0
+        compared_count = 0
+
+        for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
+            compared_count += 1
+            if ref_tool_call.name == pred_tool_call.name:
+                arg_score = await self._get_arg_score(
+                    pred_tool_call.args, ref_tool_call.args, callbacks
+                )
+                score += arg_score
+
+        score /= len(reference_tool_calls)
+
+        if compared_count < len(reference_tool_calls):
+            coverage_penalty = compared_count / len(reference_tool_calls)
+            score *= coverage_penalty

         return score * sequence_aligned
diff --git a/tests/unit/test_async_evaluation.py b/tests/unit/test_async_evaluation.py
index 20ed71bc9..f8b40bd4b 100644
--- a/tests/unit/test_async_evaluation.py
+++ b/tests/unit/test_async_evaluation.py
@@ -140,40 +140,48 @@ class TestNestAsyncioNotAppliedInAevaluate:
     @pytest.mark.asyncio
     async def test_aevaluate_no_nest_asyncio_applied(self):
         """Test that aevaluate doesn't call apply_nest_asyncio."""
-        # Mock all the dependencies to avoid actual API calls
-        with patch("ragas.evaluation.EvaluationDataset"):
-            with patch("ragas.evaluation.validate_required_columns"):
-                with patch("ragas.evaluation.validate_supported_metrics"):
-                    with patch("ragas.evaluation.Executor") as mock_executor_class:
-                        with patch("ragas.evaluation.new_group"):
-                            with patch(
-                                "ragas.async_utils.apply_nest_asyncio"
-                            ) as mock_apply:
-                                # Mock executor
-                                mock_executor = MagicMock()
-                                mock_executor.aresults = AsyncMock(return_value=[0.8])
-                                mock_executor_class.return_value = mock_executor
-
-                                # Mock dataset
-                                mock_dataset_instance = MagicMock()
-                                mock_dataset_instance.get_sample_type.return_value = (
-                                    MagicMock()
-                                )
-                                mock_dataset_instance.__iter__ = lambda x: iter([])
-
-                                from ragas import aevaluate
+        with warnings.catch_warnings():
+            # Suppress RuntimeWarning about unawaited coroutines in tests
+            warnings.filterwarnings(
+                "ignore",
+                category=RuntimeWarning,
+                message=".*coroutine.*was never awaited",
+            )

-                                try:
-                                    await aevaluate(
-                                        dataset=mock_dataset_instance,
-                                        metrics=[],
-                                        show_progress=False,
+            # Mock all the dependencies to avoid actual API calls
+            with patch("ragas.evaluation.EvaluationDataset"):
+                with patch("ragas.evaluation.validate_required_columns"):
+                    with patch("ragas.evaluation.validate_supported_metrics"):
+                        with patch("ragas.evaluation.Executor") as mock_executor_class:
+                            with patch("ragas.evaluation.new_group"):
+                                with patch(
+                                    "ragas.async_utils.apply_nest_asyncio"
+                                ) as mock_apply:
+                                    # Mock executor
+                                    mock_executor = MagicMock()
+                                    mock_executor.aresults = AsyncMock(
+                                        return_value=[0.8]
                                     )
-                                except Exception:
-                                    pass
-
-                                # aevaluate should never call apply_nest_asyncio
-                                mock_apply.assert_not_called()
+                                    mock_executor_class.return_value = mock_executor
+
+                                    # Mock dataset
+                                    mock_dataset_instance = MagicMock()
+                                    mock_dataset_instance.get_sample_type.return_value = MagicMock()
+                                    mock_dataset_instance.__iter__ = lambda x: iter([])
+
+                                    from ragas import aevaluate
+
+                                    try:
+                                        await aevaluate(
+                                            dataset=mock_dataset_instance,
+                                            metrics=[],
+                                            show_progress=False,
+                                        )
+                                    except Exception:
+                                        pass
+
+                                    # aevaluate should never call apply_nest_asyncio
+                                    mock_apply.assert_not_called()


 class TestAsyncIntegration:
diff --git a/tests/unit/test_tool_call_accuracy.py b/tests/unit/test_tool_call_accuracy.py
new file mode 100644
index 000000000..64c8be643
--- /dev/null
+++ b/tests/unit/test_tool_call_accuracy.py
@@ -0,0 +1,290 @@
+"""Tests for ToolCallAccuracy metric."""
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from ragas.dataset_schema import MultiTurnSample
+from ragas.messages import AIMessage, ToolCall
+from ragas.metrics import ToolCallAccuracy
+
+
+@pytest.fixture
+def tool_call_accuracy():
+    """Fixture providing ToolCallAccuracy instance."""
+    return ToolCallAccuracy()
+
+
+@pytest.fixture
+def mock_callbacks():
+    """Fixture providing mock callbacks."""
+    return AsyncMock()
+
+
+class TestToolCallAccuracy:
+    """Test cases for ToolCallAccuracy metric."""
+
+    def test_is_sequence_aligned_perfect_match(self, tool_call_accuracy):
+        """Test sequence alignment with perfect match."""
+        pred_seq = ["func1", "func2", "func3"]
+        ref_seq = ["func1", "func2", "func3"]
+        assert tool_call_accuracy.is_sequence_aligned(pred_seq, ref_seq) is True
+
+    def test_is_sequence_aligned_different_order(self, tool_call_accuracy):
+        """Test sequence alignment with different order."""
+        pred_seq = ["func1", "func3", "func2"]
+        ref_seq = ["func1", "func2", "func3"]
+        assert tool_call_accuracy.is_sequence_aligned(pred_seq, ref_seq) is False
+
+    def test_is_sequence_aligned_different_length(self, tool_call_accuracy):
+        """Test sequence alignment with different lengths."""
+        pred_seq = ["func1", "func2"]
+        ref_seq = ["func1", "func2", "func3"]
+        assert tool_call_accuracy.is_sequence_aligned(pred_seq, ref_seq) is False
+
+    def test_is_sequence_aligned_empty_sequences(self, tool_call_accuracy):
+        """Test sequence alignment with empty sequences."""
+        assert tool_call_accuracy.is_sequence_aligned([], []) is True
+
+    @pytest.mark.asyncio
+    async def test_perfect_match_scenario(self, tool_call_accuracy, mock_callbacks):
+        """Test perfect match scenario with identical tool calls."""
+        # Create reference tool calls
+        ref_tool_calls = [
+            ToolCall(name="search", args={"query": "python"}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        # Create predicted tool calls
+        pred_tool_calls = [
+            ToolCall(name="search", args={"query": "python"}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        # Create sample
+        sample = MultiTurnSample(
+            user_input=[
+                AIMessage(content="I'll search for you", tool_calls=pred_tool_calls)
+            ],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        # Mock the arg comparison to return 1.0 for perfect matches
+        tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock(
+            return_value=1.0
+        )
+
+        score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks)
+        assert score == 1.0
+
+    @pytest.mark.asyncio
+    async def test_no_predicted_tool_calls(self, tool_call_accuracy, mock_callbacks):
+        """Test case with no predicted tool calls."""
+        ref_tool_calls = [ToolCall(name="search", args={"query": "python"})]
+
+        sample = MultiTurnSample(
+            user_input=[AIMessage(content="No tool calls here")],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        with pytest.warns(UserWarning, match="No tool calls found"):
+            score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks)
+        assert score == 0.0
+
+    @pytest.mark.asyncio
+    async def test_sequence_misalignment(self, tool_call_accuracy, mock_callbacks):
+        """Test case where sequences don't align."""
+        ref_tool_calls = [
+            ToolCall(name="search", args={"query": "python"}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        # Different order - should result in score 0 due to sequence misalignment
+        pred_tool_calls = [
+            ToolCall(name="filter", args={"type": "recent"}),
+            ToolCall(name="search", args={"query": "python"}),
+        ]
+
+        sample = MultiTurnSample(
+            user_input=[AIMessage(content="Searching...", tool_calls=pred_tool_calls)],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock(
+            return_value=1.0
+        )
+
+        score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks)
+        assert score == 0.0
+
+    @pytest.mark.asyncio
+    async def test_length_mismatch_more_predicted(
+        self, tool_call_accuracy, mock_callbacks
+    ):
+        """Test case with more predicted tool calls than reference."""
+        ref_tool_calls = [ToolCall(name="search", args={"query": "python"})]
+
+        pred_tool_calls = [
+            ToolCall(name="search", args={"query": "python"}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        sample = MultiTurnSample(
+            user_input=[AIMessage(content="Searching...", tool_calls=pred_tool_calls)],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock(
+            return_value=1.0
+        )
+
+        with pytest.warns(UserWarning, match="Length mismatch"):
+            score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks)
+
+        # Should be 0 because sequences don't align (different lengths)
+        assert score == 0.0
+
+    @pytest.mark.asyncio
+    async def test_length_mismatch_fewer_predicted(
+        self, tool_call_accuracy, mock_callbacks
+    ):
+        """Test case with fewer predicted tool calls than reference."""
+        ref_tool_calls = [
+            ToolCall(name="search", args={"query": "python"}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        pred_tool_calls = [ToolCall(name="search", args={"query": "python"})]
+
+        sample = MultiTurnSample(
+            user_input=[AIMessage(content="Searching...", tool_calls=pred_tool_calls)],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock(
+            return_value=1.0
+        )
+
+        with pytest.warns(UserWarning, match="Length mismatch"):
+            score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks)
+
+        # Should be 0 because sequences don't align (different lengths)
+        assert score == 0.0
+
+    @pytest.mark.asyncio
+    async def test_partial_argument_match(self, tool_call_accuracy, mock_callbacks):
+        """Test case with partial argument matches."""
+        ref_tool_calls = [
+            ToolCall(name="search", args={"query": "python", "limit": 10}),
+            ToolCall(name="filter", args={"type": "recent"}),
+        ]
+
+        pred_tool_calls = [
+            ToolCall(
+                name="search", args={"query": "python", "limit": 5}
+            ),  # Wrong limit
+            ToolCall(name="filter", args={"type": "recent"}),  # Perfect match
+        ]
+
+        sample = MultiTurnSample(
+            user_input=[AIMessage(content="Searching...", tool_calls=pred_tool_calls)],
+            reference_tool_calls=ref_tool_calls,
+        )
+
+        # Mock to return scores based on the argument comparison
+        # For the "search" tool call: we need to call for each argument
the "search" tool call: we need to call for each argument + # For "python" vs "python": 1.0, for 5 vs 10: 0.0 -> average = 0.5 + # For the "filter" tool call: "recent" vs "recent": 1.0 -> average = 1.0 + tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock( + side_effect=[1.0, 0.0, 1.0] # query match, limit mismatch, type match + ) + + score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks) + assert score == 0.75 # (0.5 + 1.0) / 2 + + @pytest.mark.asyncio + async def test_wrong_tool_names(self, tool_call_accuracy, mock_callbacks): + """Test case with wrong tool names.""" + ref_tool_calls = [ToolCall(name="search", args={"query": "python"})] + + pred_tool_calls = [ToolCall(name="wrong_tool", args={"query": "python"})] + + sample = MultiTurnSample( + user_input=[AIMessage(content="Searching...", tool_calls=pred_tool_calls)], + reference_tool_calls=ref_tool_calls, + ) + + score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks) + assert score == 0.0 # Wrong tool name should result in 0 + + @pytest.mark.asyncio + async def test_multiple_ai_messages(self, tool_call_accuracy, mock_callbacks): + """Test case with multiple AI messages containing tool calls.""" + ref_tool_calls = [ + ToolCall(name="search", args={"query": "python"}), + ToolCall(name="filter", args={"type": "recent"}), + ] + + # Tool calls spread across multiple messages + sample = MultiTurnSample( + user_input=[ + AIMessage( + content="First", + tool_calls=[ToolCall(name="search", args={"query": "python"})], + ), + AIMessage( + content="Second", + tool_calls=[ToolCall(name="filter", args={"type": "recent"})], + ), + ], + reference_tool_calls=ref_tool_calls, + ) + + tool_call_accuracy.arg_comparison_metric.single_turn_ascore = AsyncMock( + return_value=1.0 + ) + + score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks) + assert score == 1.0 + + @pytest.mark.asyncio + async def test_empty_reference_tool_calls(self, tool_call_accuracy, mock_callbacks): + """Test case with empty reference tool calls and no predictions.""" + sample = MultiTurnSample( + user_input=[AIMessage(content="No tools needed")], + reference_tool_calls=[], + ) + + score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks) + assert score == 1.0 # Both empty should be perfect match + + @pytest.mark.asyncio + async def test_empty_reference_with_predictions( + self, tool_call_accuracy, mock_callbacks + ): + """Test case with empty reference but predictions exist.""" + sample = MultiTurnSample( + user_input=[ + AIMessage( + content="Calling tool", + tool_calls=[ToolCall(name="unexpected", args={})], + ) + ], + reference_tool_calls=[], + ) + + with pytest.warns(UserWarning, match="Reference tool calls are empty"): + score = await tool_call_accuracy._multi_turn_ascore(sample, mock_callbacks) + assert score == 0.0 + + def test_metric_name(self, tool_call_accuracy): + """Test that metric has correct name.""" + assert tool_call_accuracy.name == "tool_call_accuracy" + + def test_required_columns(self, tool_call_accuracy): + """Test that metric has correct required columns.""" + from ragas.metrics.base import MetricType + + required = tool_call_accuracy._required_columns[MetricType.MULTI_TURN] + assert "user_input" in required + assert "reference_tool_calls" in required