diff --git a/docs/concepts/metrics/available_metrics/agents.md b/docs/concepts/metrics/available_metrics/agents.md
index 156475a5a..3aee5c7ce 100644
--- a/docs/concepts/metrics/available_metrics/agents.md
+++ b/docs/concepts/metrics/available_metrics/agents.md
@@ -122,6 +122,107 @@ from ragas.metrics._tool_call_accuracy import ToolCallAccuracy
 metric = ToolCallAccuracy()
 metric.arg_comparison_metric = NonLLMStringSimilarity()
 ```
+## Tool Call F1
+
+`ToolCallF1` is a metric that returns an F1-score based on the precision and recall of the tool calls made by an agent, compared against a set of expected calls (`reference_tool_calls`). While `ToolCallAccuracy` provides a binary score based on an exact order and content match, `ToolCallF1` complements it by offering a softer evaluation that is useful during onboarding and iteration. It helps quantify how close the agent came to the expected behavior, even when it over- or under-calls tools.
+
+### Formula
+
+`ToolCallF1` is based on classic information-retrieval metrics. It uses unordered matching: the order in which the tools are called does not affect the result; only the presence and correctness of tool names and parameters are considered.
+
+$$
+\text{Precision} = \frac{\text{tool calls that match both name and parameters}}{\text{tool calls that match both name and parameters} + \text{extra tool calls that were not expected}}
+$$
+
+$$
+\text{Recall} = \frac{\text{tool calls that match both name and parameters}}{\text{tool calls that match both name and parameters} + \text{expected tool calls that were not made}}
+$$
+
+$$
+\text{F1} = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
+$$
+
+### How is it different from Topic Adherence?
+
+While both `ToolCallF1` and `TopicAdherenceScore` use precision, recall, and F1-score, they evaluate different aspects:
+
+| Metric                | Evaluates                                | Based on                     |
+| --------------------- | ---------------------------------------- | ---------------------------- |
+| `ToolCallF1`          | Correctness of tool executions           | Structured tool call objects |
+| `TopicAdherenceScore` | Whether the conversation stays on-topic  | Comparison of domain topics  |
+
+Use `ToolCallF1` when you want to track whether the agent correctly **executed tools**. Use `TopicAdherenceScore` when evaluating whether the **content or intention** stays within allowed topics.
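+
+As a quick worked illustration of the formula (the counts here are hypothetical): suppose the agent makes three tool calls, two of which match the expected calls in both name and parameters, one of which is extra, and no expected call is missing. Then
+
+$$
+\text{Precision} = \frac{2}{2 + 1} = \tfrac{2}{3}, \qquad \text{Recall} = \frac{2}{2 + 0} = 1.0, \qquad \text{F1} = \frac{2 \cdot \tfrac{2}{3} \cdot 1.0}{\tfrac{2}{3} + 1.0} = 0.8
+$$
+
+This is exactly the situation in the "Extra Tool Called" example below.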
+
+### Example: Matching Expected Tool Calls
+
+```python
+from ragas.metrics import ToolCallF1
+from ragas.dataset_schema import MultiTurnSample
+from ragas.messages import HumanMessage, AIMessage, ToolMessage, ToolCall
+
+sample = [
+    HumanMessage(content="What's the weather like in Paris today?"),
+    AIMessage(content="Let me check that for you.", tool_calls=[
+        ToolCall(name="weather_check", args={"location": "Paris"})
+    ]),
+    HumanMessage(content="And the UV index?"),
+    AIMessage(content="Sure, here's the UV index for Paris.", tool_calls=[
+        ToolCall(name="uv_index_lookup", args={"location": "Paris"})
+    ])
+]
+
+sample = MultiTurnSample(
+    user_input=sample,
+    reference_tool_calls=[
+        ToolCall(name="weather_check", args={"location": "Paris"}),
+        ToolCall(name="uv_index_lookup", args={"location": "Paris"})
+    ]
+)
+
+scorer = ToolCallF1()
+await scorer.multi_turn_ascore(sample)
+```
+
+Output
+
+```
+1.0
+```
+
+### Example: Extra Tool Called
+
+```python
+sample = [
+    HumanMessage(content="What's the weather like in Paris today?"),
+    AIMessage(content="Let me check that for you.", tool_calls=[
+        ToolCall(name="weather_check", args={"location": "Paris"})
+    ]),
+    HumanMessage(content="And the UV index?"),
+    AIMessage(content="Sure, here's the UV index for Paris.", tool_calls=[
+        ToolCall(name="uv_index_lookup", args={"location": "Paris"}),
+        ToolCall(name="air_quality", args={"location": "Paris"})  # extra call
+    ])
+]
+
+sample = MultiTurnSample(
+    user_input=sample,
+    reference_tool_calls=[
+        ToolCall(name="uv_index_lookup", args={"location": "Paris"}),
+        ToolCall(name="weather_check", args={"location": "Paris"})
+    ]
+)
+
+await scorer.multi_turn_ascore(sample)
+```
+
+Output
+
+```
+0.8
+```
+
+In this case, the agent makes both expected calls correctly but adds an extra `air_quality` call, so precision drops to 2/3 while recall stays at 1.0. The F1-score reflects this partial correctness instead of failing the example outright.
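+
+### Example: Missing Tool Call
+
+As a further illustrative sketch (reusing the imports and `scorer` defined above; the assistant reply text is made up), consider the under-calling case, where an expected tool call is never made:
+
+```python
+sample = MultiTurnSample(
+    user_input=[
+        HumanMessage(content="What's the weather like in Paris today?"),
+        AIMessage(content="Let me check that for you.", tool_calls=[
+            ToolCall(name="weather_check", args={"location": "Paris"})
+        ]),
+        HumanMessage(content="And the UV index?"),
+        # no tool call is made here, although uv_index_lookup is expected
+        AIMessage(content="I'm not able to look up the UV index right now.")
+    ],
+    reference_tool_calls=[
+        ToolCall(name="weather_check", args={"location": "Paris"}),
+        ToolCall(name="uv_index_lookup", args={"location": "Paris"})
+    ]
+)
+
+await scorer.multi_turn_ascore(sample)
+```
+
+Here precision is 1.0 (every call that was made was expected) but recall drops to 0.5 (one expected call is missing), so the F1-score comes out around 0.67.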
+
 
 ## Agent Goal accuracy
 
diff --git a/ragas/src/ragas/metrics/__init__.py b/ragas/src/ragas/metrics/__init__.py
index 381203031..1e4c0692d 100644
--- a/ragas/src/ragas/metrics/__init__.py
+++ b/ragas/src/ragas/metrics/__init__.py
@@ -63,6 +63,7 @@
 )
 from ragas.metrics._summarization import SummarizationScore, summarization_score
 from ragas.metrics._tool_call_accuracy import ToolCallAccuracy
+from ragas.metrics._tool_call_f1 import ToolCallF1
 from ragas.metrics._topic_adherence import TopicAdherenceScore
 from ragas.metrics.base import (
     Metric,
@@ -126,6 +127,7 @@
     "LLMSQLEquivalence",
     "AgentGoalAccuracyWithoutReference",
     "AgentGoalAccuracyWithReference",
+    "ToolCallF1",
     "ToolCallAccuracy",
     "ResponseRelevancy",
     "SemanticSimilarity",
diff --git a/ragas/src/ragas/metrics/_tool_call_f1.py b/ragas/src/ragas/metrics/_tool_call_f1.py
new file mode 100644
index 000000000..97cd21387
--- /dev/null
+++ b/ragas/src/ragas/metrics/_tool_call_f1.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import typing as t
+from dataclasses import dataclass, field
+
+from ragas.metrics.base import MultiTurnMetric, MetricType
+from ragas.dataset_schema import MultiTurnSample
+from ragas.messages import AIMessage
+
+if t.TYPE_CHECKING:
+    from langchain_core.callbacks.base import Callbacks
+
+
+@dataclass
+class ToolCallF1(MultiTurnMetric):
+    name: str = "tool_call_f1"
+    batch_size: int = 1
+    is_multi_turn: bool = True
+    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
+        default_factory=lambda: {
+            MetricType.MULTI_TURN: {
+                "reference_tool_calls",
+                "user_input",
+            }
+        }
+    )
+
+    def init(self, run_config):
+        pass
+
+    async def _multi_turn_ascore(
+        self, sample: MultiTurnSample, callbacks: t.Optional[Callbacks] = None
+    ) -> float:
+        expected: set[tuple[str, frozenset]] = set()
+        if sample.reference_tool_calls:
+            for call in sample.reference_tool_calls:
+                expected.add((call.name, frozenset(call.args.items())))
+
+        actual: set[tuple[str, frozenset]] = set()
+        for msg in sample.user_input:
+            if isinstance(msg, AIMessage) and msg.tool_calls is not None:
+                for call in msg.tool_calls:
+                    actual.add((call.name, frozenset(call.args.items())))
+
+        tp = len(actual & expected)
+        fp = len(actual - expected)
+        fn = len(expected - actual)
+
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+        f1 = (
+            2 * precision * recall / (precision + recall)
+            if (precision + recall) > 0
+            else 0.0
+        )
+
+        return round(f1, 4)
+
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        return await self._multi_turn_ascore(MultiTurnSample(**row), callbacks)
diff --git a/ragas/tests/unit/test_tool_call_f1.py b/ragas/tests/unit/test_tool_call_f1.py
new file mode 100644
index 000000000..e9575b7c0
--- /dev/null
+++ b/ragas/tests/unit/test_tool_call_f1.py
@@ -0,0 +1,61 @@
+import pytest
+from ragas.messages import ToolCall, AIMessage, HumanMessage
+from ragas import MultiTurnSample
+from ragas.metrics import ToolCallF1
+
+metric = ToolCallF1()
+
+
+def make_sample(expected, predicted):
+    return MultiTurnSample(
+        user_input=[
+            HumanMessage(content="What is the weather in Paris?"),
+            AIMessage(
+                content="Let me check the weather forecast", tool_calls=predicted
+            ),
+        ],
+        reference_tool_calls=expected,
+        reference="Expected correct weather tool call",
+    )
+
+
+@pytest.mark.asyncio
+async def test_tool_call_f1_full_match():
+    expected = [ToolCall(name="WeatherForecast", args={"location": "Paris"})]
+    predicted = [ToolCall(name="WeatherForecast", args={"location": "Paris"})]
+    sample = make_sample(expected, predicted)
+    score = await metric._multi_turn_ascore(sample)
+    assert score == 1.0
+
+
+@pytest.mark.asyncio
+async def test_tool_call_f1_partial_match():
+    expected = [
+        ToolCall(name="WeatherForecast", args={"location": "Paris"}),
+        ToolCall(name="UVIndex", args={"location": "Paris"}),
+    ]
+    predicted = [ToolCall(name="WeatherForecast", args={"location": "Paris"})]
+    sample = make_sample(expected, predicted)
+    score = await metric._multi_turn_ascore(sample)
+    assert round(score, 2) == 0.67
+
+
+@pytest.mark.asyncio
+async def test_tool_call_f1_no_match():
+    expected = [ToolCall(name="WeatherForecast", args={"location": "Paris"})]
+    predicted = [ToolCall(name="AirQuality", args={"location": "Paris"})]
+    sample = make_sample(expected, predicted)
+    score = await metric._multi_turn_ascore(sample)
+    assert score == 0.0
+
+
+@pytest.mark.asyncio
+async def test_tool_call_f1_extra_call():
+    expected = [ToolCall(name="WeatherForecast", args={"location": "Paris"})]
+    predicted = [
+        ToolCall(name="WeatherForecast", args={"location": "Paris"}),
+        ToolCall(name="AirQuality", args={"location": "Paris"}),
+    ]
+    sample = make_sample(expected, predicted)
+    score = await metric._multi_turn_ascore(sample)
+    assert round(score, 2) == 0.67