90 changes: 63 additions & 27 deletions src/ragas/metrics/_tool_call_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@

 @dataclass
 class ToolCallAccuracy(MultiTurnMetric):
+    """
+    The Tool Call Accuracy metric measures how accurately an LLM agent makes tool calls
+    compared to reference tool calls.
+
+    The metric evaluates two aspects:
+    1. Sequence alignment: whether predicted and reference tool calls match exactly in order
+    2. Argument accuracy: how well tool call arguments match between predicted and reference
+
+    Score calculation:
+    - If sequences don't align exactly: score = 0
+    - If sequences align: score = (average argument accuracy) * sequence_alignment_factor
+    - Length mismatches result in a warning and a proportional penalty
+
+    Edge cases:
+    - No predicted tool calls: returns 0.0
+    - Length mismatch: compares only the overlapping portion and applies a coverage penalty
+    - Missing arguments: contribute 0 to the argument score for that tool call
+
+    The final score is always between 0.0 and 1.0.
+    """
+
     name: str = "tool_call_accuracy"
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {
@@ -55,15 +76,7 @@ async def _get_arg_score(
     def is_sequence_aligned(
         self, pred_sequence: t.List[str], ref_sequence: t.List[str]
     ) -> bool:
-        if len(pred_sequence) != len(ref_sequence):
-            return False
-        ref_index = 0  # Index to track position in reference sequence
-        for pred in pred_sequence:
-            if ref_index < len(ref_sequence) and pred == ref_sequence[ref_index]:
-                ref_index += 1
-        if ref_index == len(ref_sequence):
-            return True
-        return False
+        return pred_sequence == ref_sequence

     async def _multi_turn_ascore(
         self, sample: MultiTurnSample, callbacks: Callbacks
@@ -77,30 +90,53 @@ async def _multi_turn_ascore(
             if isinstance(item, AIMessage) and item.tool_calls is not None:
                 pred_tool_calls.extend(item.tool_calls)

+        reference_tool_calls = sample.reference_tool_calls
+
+        # Handle edge cases
+        if not pred_tool_calls and not reference_tool_calls:
+            # Both empty - perfect match
+            return 1.0
+        elif not pred_tool_calls:
+            warnings.warn("No tool calls found in the user input")
+            return 0.0
+        elif not reference_tool_calls:
+            # Reference is empty but we have predictions - this is typically an error in test data
+            warnings.warn("Reference tool calls are empty but predictions exist")
+            return 0.0
+
+        # Check for length mismatch and warn user
+        if len(pred_tool_calls) != len(reference_tool_calls):
+            warnings.warn(
+                f"Length mismatch: predicted tool calls ({len(pred_tool_calls)}) "
+                f"vs reference tool calls ({len(reference_tool_calls)}). "
+                f"Only the first {min(len(pred_tool_calls), len(reference_tool_calls))} "
+                f"tool calls will be compared."
+            )
+
         tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
-        tool_call_ref_sequence = [
-            tool_call.name for tool_call in sample.reference_tool_calls
-        ]
+        tool_call_ref_sequence = [tool_call.name for tool_call in reference_tool_calls]

         sequence_aligned = int(
             self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence)
         )

-        if pred_tool_calls:
-            score = 0.0
-            reference_tool_calls = sample.reference_tool_calls
-            for ref_tool_call in reference_tool_calls:
-                for pred_tool_call in pred_tool_calls:
-                    if ref_tool_call.name == pred_tool_call.name:
-                        arg_score = await self._get_arg_score(
-                            pred_tool_call.args, ref_tool_call.args, callbacks
-                        )
-                        score += arg_score
-
-            score /= len(reference_tool_calls)
-        else:
-            warnings.warn("No tool calls found in the user input")
-            return 0.0
+        # Calculate score based on paired tool calls
+        score = 0.0
+        compared_count = 0
+
+        for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
+            compared_count += 1
+            if ref_tool_call.name == pred_tool_call.name:
+                arg_score = await self._get_arg_score(
+                    pred_tool_call.args, ref_tool_call.args, callbacks
+                )
+                score += arg_score
+
+        score /= len(reference_tool_calls)
+
+        if compared_count < len(reference_tool_calls):
+            coverage_penalty = compared_count / len(reference_tool_calls)
+            score *= coverage_penalty

         return score * sequence_aligned
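For reviewers who want to see the metric end to end, here is a minimal usage sketch of the updated `ToolCallAccuracy`. It is not part of this diff: the module paths, the `multi_turn_ascore` wrapper, and the assumption that the default argument-comparison metric needs no LLM client are based on current ragas releases and may differ between versions.

```python
# Hedged sketch, not part of the PR: import paths and defaults are assumptions.
import asyncio

from ragas.dataset_schema import MultiTurnSample
from ragas.messages import AIMessage, HumanMessage, ToolCall
from ragas.metrics import ToolCallAccuracy


async def main() -> None:
    sample = MultiTurnSample(
        user_input=[
            HumanMessage(content="What's the weather in Paris?"),
            AIMessage(
                content="Let me check that for you.",
                tool_calls=[
                    ToolCall(name="weather_check", args={"location": "Paris"}),
                ],
            ),
        ],
        # The reference defines the expected call sequence and arguments.
        reference_tool_calls=[
            ToolCall(name="weather_check", args={"location": "Paris"}),
        ],
    )

    metric = ToolCallAccuracy()
    # Exact name sequence and matching args -> score close to 1.0.
    # A wrong order drives the score to 0.0, and a length mismatch
    # triggers the warning and coverage penalty shown in the diff above.
    score = await metric.multi_turn_ascore(sample)
    print(score)


asyncio.run(main())
```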
72 changes: 40 additions & 32 deletions tests/unit/test_async_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,40 +140,48 @@ class TestNestAsyncioNotAppliedInAevaluate:
     @pytest.mark.asyncio
     async def test_aevaluate_no_nest_asyncio_applied(self):
         """Test that aevaluate doesn't call apply_nest_asyncio."""
-        # Mock all the dependencies to avoid actual API calls
-        with patch("ragas.evaluation.EvaluationDataset"):
-            with patch("ragas.evaluation.validate_required_columns"):
-                with patch("ragas.evaluation.validate_supported_metrics"):
-                    with patch("ragas.evaluation.Executor") as mock_executor_class:
-                        with patch("ragas.evaluation.new_group"):
-                            with patch(
-                                "ragas.async_utils.apply_nest_asyncio"
-                            ) as mock_apply:
-                                # Mock executor
-                                mock_executor = MagicMock()
-                                mock_executor.aresults = AsyncMock(return_value=[0.8])
-                                mock_executor_class.return_value = mock_executor
-
-                                # Mock dataset
-                                mock_dataset_instance = MagicMock()
-                                mock_dataset_instance.get_sample_type.return_value = (
-                                    MagicMock()
-                                )
-                                mock_dataset_instance.__iter__ = lambda x: iter([])
-
-                                from ragas import aevaluate
-
-                                try:
-                                    await aevaluate(
-                                        dataset=mock_dataset_instance,
-                                        metrics=[],
-                                        show_progress=False,
-                                    )
-                                except Exception:
-                                    pass
-
-                                # aevaluate should never call apply_nest_asyncio
-                                mock_apply.assert_not_called()
+        with warnings.catch_warnings():
+            # Suppress RuntimeWarning about unawaited coroutines in tests
+            warnings.filterwarnings(
+                "ignore",
+                category=RuntimeWarning,
+                message=".*coroutine.*was never awaited",
+            )
+
+            # Mock all the dependencies to avoid actual API calls
+            with patch("ragas.evaluation.EvaluationDataset"):
+                with patch("ragas.evaluation.validate_required_columns"):
+                    with patch("ragas.evaluation.validate_supported_metrics"):
+                        with patch("ragas.evaluation.Executor") as mock_executor_class:
+                            with patch("ragas.evaluation.new_group"):
+                                with patch(
+                                    "ragas.async_utils.apply_nest_asyncio"
+                                ) as mock_apply:
+                                    # Mock executor
+                                    mock_executor = MagicMock()
+                                    mock_executor.aresults = AsyncMock(
+                                        return_value=[0.8]
+                                    )
+                                    mock_executor_class.return_value = mock_executor
+
+                                    # Mock dataset
+                                    mock_dataset_instance = MagicMock()
+                                    mock_dataset_instance.get_sample_type.return_value = MagicMock()
+                                    mock_dataset_instance.__iter__ = lambda x: iter([])
+
+                                    from ragas import aevaluate
+
+                                    try:
+                                        await aevaluate(
+                                            dataset=mock_dataset_instance,
+                                            metrics=[],
+                                            show_progress=False,
+                                        )
+                                    except Exception:
+                                        pass
+
+                                    # aevaluate should never call apply_nest_asyncio
+                                    mock_apply.assert_not_called()


 class TestAsyncIntegration:
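A side note on the test's structure, not part of this PR: the deeply nested `with patch(...)` blocks could be flattened with `contextlib.ExitStack` from the standard library. A minimal sketch of that pattern, reusing the patch targets from the test above; everything else here is illustrative only.

```python
# Illustrative alternative to deep nesting; not the PR's implementation.
from contextlib import ExitStack
from unittest.mock import AsyncMock, MagicMock, patch

with ExitStack() as stack:
    # Patches whose mocks are never inspected can be entered in a loop.
    for target in (
        "ragas.evaluation.EvaluationDataset",
        "ragas.evaluation.validate_required_columns",
        "ragas.evaluation.validate_supported_metrics",
        "ragas.evaluation.new_group",
    ):
        stack.enter_context(patch(target))

    # Patches whose mocks are needed later are entered individually.
    mock_executor_class = stack.enter_context(patch("ragas.evaluation.Executor"))
    mock_apply = stack.enter_context(patch("ragas.async_utils.apply_nest_asyncio"))

    mock_executor = MagicMock()
    mock_executor.aresults = AsyncMock(return_value=[0.8])
    mock_executor_class.return_value = mock_executor

    # ... run aevaluate here as in the test above, then:
    mock_apply.assert_not_called()
```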