15
15
16
16
@dataclass
17
17
class ToolCallAccuracy (MultiTurnMetric ):
18
+ """
19
+ Tool Call Accuracy metric measures how accurately an LLM agent makes tool calls
20
+ compared to reference tool calls.
21
+
22
+ The metric evaluates two aspects:
23
+ 1. Sequence alignment: Whether predicted and reference tool calls match exactly in order
24
+ 2. Argument accuracy: How well tool call arguments match between predicted and reference
25
+
26
+ Score calculation:
27
+ - If sequences don't align exactly: score = 0
28
+ - If sequences align: score = (average argument accuracy) * sequence_alignment_factor
29
+ - Length mismatches result in warnings and proportional penalty
30
+
31
+ Edge cases:
32
+ - No predicted tool calls: returns 0.0
33
+ - Length mismatch: a warning is emitted and, since exact-order alignment then fails, the score is 0.0
34
+ - Missing arguments: contributes 0 to the argument score for that tool call
35
+
36
+ The final score is always between 0.0 and 1.0.
37
+ """
38
+
18
39
name : str = "tool_call_accuracy"
19
40
_required_columns : t .Dict [MetricType , t .Set [str ]] = field (
20
41
default_factory = lambda : {
@@ -55,15 +76,7 @@ async def _get_arg_score(
55
76
def is_sequence_aligned(
    self, pred_sequence: t.List[str], ref_sequence: t.List[str]
) -> bool:
    """Check whether the predicted tool-name sequence matches the reference exactly.

    Both length and order must agree; any deviation counts as misaligned.
    """
    if len(pred_sequence) != len(ref_sequence):
        return False
    # Same length — compare position by position.
    return all(pred == ref for pred, ref in zip(pred_sequence, ref_sequence))
67
80
68
81
async def _multi_turn_ascore(
    self, sample: MultiTurnSample, callbacks: Callbacks
) -> float:
    """Score how accurately the sample's tool calls match the reference tool calls.

    Args:
        sample: Multi-turn sample whose ``AIMessage`` tool calls are compared
            against ``sample.reference_tool_calls``.
        callbacks: Forwarded to the argument-comparison helper.

    Returns:
        A float in [0.0, 1.0]: the mean per-call argument score, gated to 0.0
        unless the predicted tool-name sequence matches the reference exactly
        (same names, same order, same length).
    """
    # Collect every tool call the agent made across the conversation.
    # NOTE(review): the loop header sits in a diff-hunk gap; reconstructed from
    # the visible isinstance check and the "user input" warning below — confirm.
    pred_tool_calls = []
    for item in sample.user_input:
        if isinstance(item, AIMessage) and item.tool_calls is not None:
            pred_tool_calls.extend(item.tool_calls)

    reference_tool_calls = sample.reference_tool_calls

    # Edge cases: both sides empty is a perfect match; one side empty is a miss.
    if not pred_tool_calls and not reference_tool_calls:
        return 1.0
    if not pred_tool_calls:
        warnings.warn("No tool calls found in the user input")
        return 0.0
    if not reference_tool_calls:
        # Reference is empty but predictions exist — typically bad test data.
        warnings.warn("Reference tool calls are empty but predictions exist")
        return 0.0

    # A length mismatch can never align exactly, so the final score is
    # necessarily 0.0. Warn and return early instead of computing argument
    # scores: the previous "compare the overlap + coverage penalty" path was
    # dead code (the strict alignment factor already zeroed the result) and
    # its warning falsely promised a partial comparison.
    if len(pred_tool_calls) != len(reference_tool_calls):
        warnings.warn(
            f"Length mismatch: predicted tool calls ({len(pred_tool_calls)}) "
            f"vs reference tool calls ({len(reference_tool_calls)}). "
            f"Exact sequence alignment is impossible, so the score is 0.0."
        )
        return 0.0

    tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
    tool_call_ref_sequence = [tool_call.name for tool_call in reference_tool_calls]

    # Gate on exact name-sequence alignment before spending argument
    # comparisons (which may invoke an LLM via _get_arg_score).
    if not self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence):
        return 0.0

    # Sequences align pairwise, so names are guaranteed equal per pair;
    # average the per-call argument accuracy over all reference calls.
    score = 0.0
    for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
        score += await self._get_arg_score(
            pred_tool_call.args, ref_tool_call.args, callbacks
        )
    score /= len(reference_tool_calls)

    return score
106
142
0 commit comments