
Commit 5df03ee

anistark and sdivye92 authored
Fix/tool call accuracy (#2300)
## Issue Link / Problem Description

- Continuation of #2092

Co-authored-by: sdivye92 <[email protected]>
1 parent a07e16e commit 5df03ee

File tree

3 files changed: +393 −59 lines changed


src/ragas/metrics/_tool_call_accuracy.py

Lines changed: 63 additions & 27 deletions
@@ -15,6 +15,27 @@
 
 @dataclass
 class ToolCallAccuracy(MultiTurnMetric):
+    """
+    Tool Call Accuracy metric measures how accurately an LLM agent makes tool calls
+    compared to reference tool calls.
+
+    The metric evaluates two aspects:
+    1. Sequence alignment: Whether predicted and reference tool calls match exactly in order
+    2. Argument accuracy: How well tool call arguments match between predicted and reference
+
+    Score calculation:
+    - If sequences don't align exactly: score = 0
+    - If sequences align: score = (average argument accuracy) * sequence_alignment_factor
+    - Length mismatches result in warnings and proportional penalty
+
+    Edge cases:
+    - No predicted tool calls: returns 0.0
+    - Length mismatch: compares only the overlapping portion and applies coverage penalty
+    - Missing arguments: contributes 0 to the argument score for that tool call
+
+    The final score is always between 0.0 and 1.0.
+    """
+
     name: str = "tool_call_accuracy"
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {
@@ -55,15 +76,7 @@ async def _get_arg_score(
     def is_sequence_aligned(
         self, pred_sequence: t.List[str], ref_sequence: t.List[str]
     ) -> bool:
-        if len(pred_sequence) != len(ref_sequence):
-            return False
-        ref_index = 0  # Index to track position in reference sequence
-        for pred in pred_sequence:
-            if ref_index < len(ref_sequence) and pred == ref_sequence[ref_index]:
-                ref_index += 1
-        if ref_index == len(ref_sequence):
-            return True
-        return False
+        return pred_sequence == ref_sequence
 
     async def _multi_turn_ascore(
         self, sample: MultiTurnSample, callbacks: Callbacks
@@ -77,30 +90,53 @@ async def _multi_turn_ascore(
             if isinstance(item, AIMessage) and item.tool_calls is not None:
                 pred_tool_calls.extend(item.tool_calls)
 
+        reference_tool_calls = sample.reference_tool_calls
+
+        # Handle edge cases
+        if not pred_tool_calls and not reference_tool_calls:
+            # Both empty - perfect match
+            return 1.0
+        elif not pred_tool_calls:
+            warnings.warn("No tool calls found in the user input")
+            return 0.0
+        elif not reference_tool_calls:
+            # Reference is empty but we have predictions - this is typically an error in test data
+            warnings.warn("Reference tool calls are empty but predictions exist")
+            return 0.0
+
+        # Check for length mismatch and warn user
+        if len(pred_tool_calls) != len(reference_tool_calls):
+            warnings.warn(
+                f"Length mismatch: predicted tool calls ({len(pred_tool_calls)}) "
+                f"vs reference tool calls ({len(reference_tool_calls)}). "
+                f"Only the first {min(len(pred_tool_calls), len(reference_tool_calls))} "
+                f"tool calls will be compared."
+            )
+
         tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
-        tool_call_ref_sequence = [
-            tool_call.name for tool_call in sample.reference_tool_calls
-        ]
+        tool_call_ref_sequence = [tool_call.name for tool_call in reference_tool_calls]
 
         sequence_aligned = int(
             self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence)
         )
 
-        if pred_tool_calls:
-            score = 0.0
-            reference_tool_calls = sample.reference_tool_calls
-            for ref_tool_call in reference_tool_calls:
-                for pred_tool_call in pred_tool_calls:
-                    if ref_tool_call.name == pred_tool_call.name:
-                        arg_score = await self._get_arg_score(
-                            pred_tool_call.args, ref_tool_call.args, callbacks
-                        )
-                        score += arg_score
-
-            score /= len(reference_tool_calls)
-        else:
-            warnings.warn("No tool calls found in the user input")
-            return 0.0
+        # Calculate score based on paired tool calls
+        score = 0.0
+        compared_count = 0
+
+        for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
+            compared_count += 1
+            if ref_tool_call.name == pred_tool_call.name:
+                arg_score = await self._get_arg_score(
+                    pred_tool_call.args, ref_tool_call.args, callbacks
+                )
+                score += arg_score
+
+        score /= len(reference_tool_calls)
+
+        if compared_count < len(reference_tool_calls):
+            coverage_penalty = compared_count / len(reference_tool_calls)
+            score *= coverage_penalty
 
         return score * sequence_aligned
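For orientation, here is a minimal, hypothetical usage sketch of the updated metric. It assumes the usual ragas agent-evaluation imports (`ToolCallAccuracy`, `MultiTurnSample`, and the message/tool-call types from `ragas.messages`); the conversation content and the `get_weather` tool are made up for illustration, not taken from this PR.

```python
# Hypothetical usage sketch for the updated ToolCallAccuracy metric.
# Assumes standard ragas agent-evaluation imports; sample data is illustrative.
import asyncio

from ragas.dataset_schema import MultiTurnSample
from ragas.messages import AIMessage, HumanMessage, ToolCall
from ragas.metrics import ToolCallAccuracy


async def main() -> None:
    sample = MultiTurnSample(
        user_input=[
            HumanMessage(content="What's the weather in Paris?"),
            AIMessage(
                content="Let me check.",
                tool_calls=[ToolCall(name="get_weather", args={"city": "Paris"})],
            ),
        ],
        reference_tool_calls=[ToolCall(name="get_weather", args={"city": "Paris"})],
    )

    metric = ToolCallAccuracy()
    # With matching names and arguments this should score 1.0 under the new logic.
    # An extra or missing predicted tool call would instead trigger the length
    # mismatch warning and the coverage penalty shown in the diff above.
    score = await metric.multi_turn_ascore(sample)
    print(score)


asyncio.run(main())
```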

tests/unit/test_async_evaluation.py

Lines changed: 40 additions & 32 deletions
@@ -140,40 +140,48 @@ class TestNestAsyncioNotAppliedInAevaluate:
     @pytest.mark.asyncio
     async def test_aevaluate_no_nest_asyncio_applied(self):
         """Test that aevaluate doesn't call apply_nest_asyncio."""
-        # Mock all the dependencies to avoid actual API calls
-        with patch("ragas.evaluation.EvaluationDataset"):
-            with patch("ragas.evaluation.validate_required_columns"):
-                with patch("ragas.evaluation.validate_supported_metrics"):
-                    with patch("ragas.evaluation.Executor") as mock_executor_class:
-                        with patch("ragas.evaluation.new_group"):
-                            with patch(
-                                "ragas.async_utils.apply_nest_asyncio"
-                            ) as mock_apply:
-                                # Mock executor
-                                mock_executor = MagicMock()
-                                mock_executor.aresults = AsyncMock(return_value=[0.8])
-                                mock_executor_class.return_value = mock_executor
-
-                                # Mock dataset
-                                mock_dataset_instance = MagicMock()
-                                mock_dataset_instance.get_sample_type.return_value = (
-                                    MagicMock()
-                                )
-                                mock_dataset_instance.__iter__ = lambda x: iter([])
-
-                                from ragas import aevaluate
+        with warnings.catch_warnings():
+            # Suppress RuntimeWarning about unawaited coroutines in tests
+            warnings.filterwarnings(
+                "ignore",
+                category=RuntimeWarning,
+                message=".*coroutine.*was never awaited",
+            )
 
-                                try:
-                                    await aevaluate(
-                                        dataset=mock_dataset_instance,
-                                        metrics=[],
-                                        show_progress=False,
+            # Mock all the dependencies to avoid actual API calls
+            with patch("ragas.evaluation.EvaluationDataset"):
+                with patch("ragas.evaluation.validate_required_columns"):
+                    with patch("ragas.evaluation.validate_supported_metrics"):
+                        with patch("ragas.evaluation.Executor") as mock_executor_class:
+                            with patch("ragas.evaluation.new_group"):
+                                with patch(
+                                    "ragas.async_utils.apply_nest_asyncio"
+                                ) as mock_apply:
+                                    # Mock executor
+                                    mock_executor = MagicMock()
+                                    mock_executor.aresults = AsyncMock(
+                                        return_value=[0.8]
                                     )
-                                except Exception:
-                                    pass
-
-                                # aevaluate should never call apply_nest_asyncio
-                                mock_apply.assert_not_called()
+                                    mock_executor_class.return_value = mock_executor
+
+                                    # Mock dataset
+                                    mock_dataset_instance = MagicMock()
+                                    mock_dataset_instance.get_sample_type.return_value = MagicMock()
+                                    mock_dataset_instance.__iter__ = lambda x: iter([])
+
+                                    from ragas import aevaluate
+
+                                    try:
+                                        await aevaluate(
+                                            dataset=mock_dataset_instance,
+                                            metrics=[],
+                                            show_progress=False,
+                                        )
+                                    except Exception:
+                                        pass
+
+                                    # aevaluate should never call apply_nest_asyncio
+                                    mock_apply.assert_not_called()
 
 
 class TestAsyncIntegration:
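A side note on the test refactor: `warnings.catch_warnings()` restores the previous filter state when the block exits, and `warnings.filterwarnings()` matches on both the warning category and a message regex, so only the targeted RuntimeWarning is silenced and only inside that block. A small standalone sketch of this standard-library pattern (independent of ragas; the warning messages are illustrative):

```python
# Minimal sketch of the warnings-filter pattern used in the test above:
# scope an "ignore" filter to one block and match on category plus a
# message regex, so only the targeted RuntimeWarning is suppressed.
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        category=RuntimeWarning,
        message=".*coroutine.*was never awaited",
    )
    # Suppressed: matches both the category and the message pattern.
    warnings.warn("coroutine 'foo' was never awaited", RuntimeWarning)
    # Not suppressed: same category but a different message, so it surfaces.
    warnings.warn("something else happened", RuntimeWarning)

# Outside the block the original filters are restored, so the suppressed
# warning would surface again here.
```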
