File tree Expand file tree Collapse file tree 2 files changed +4
-8
lines changed Expand file tree Collapse file tree 2 files changed +4
-8
lines changed Original file line number Diff line number Diff line change @@ -142,7 +142,7 @@ def run(self, server_url: str | None = None) -> tuple:
142
142
agg_score : float = 0.0
143
143
144
144
results = self ._run_mmlu (server_url )
145
- for task , result in results .items ():
145
+ for task , result in results [ 'results' ] .items ():
146
146
agg_score += float (result ["acc,none" ])
147
147
individual_scores [task ] = {
148
148
"score" : float (result ["acc,none" ]),
@@ -154,7 +154,7 @@ def run(self, server_url: str | None = None) -> tuple:
154
154
return overall_score , individual_scores
155
155
156
156
def _run_mmlu (
157
- self , server_url : str | None = None , return_all_results : bool = False
157
+ self , server_url : str | None = None
158
158
) -> dict :
159
159
if server_url is not None :
160
160
# Requires lm_eval >= 0.4.4
@@ -179,11 +179,7 @@ def _run_mmlu(
179
179
device = self .device ,
180
180
task_manager = tm ,
181
181
)
182
- if return_all_results :
183
- results = mmlu_output
184
- else :
185
- results = mmlu_output ["results" ]
186
- return results
182
+ return mmlu_output
187
183
188
184
# This method converts general errors from simple_evaluate
189
185
# into a more user-understandable error
Original file line number Diff line number Diff line change @@ -90,7 +90,7 @@ def run(self, server_url: str | None = None) -> tuple:
90
90
self .prepare_unitxt_files ()
91
91
logger .debug (locals ())
92
92
os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
93
- results = self ._run_mmlu (server_url = server_url , return_all_results = True )
93
+ results = self ._run_mmlu (server_url = server_url )
94
94
taskname = self .tasks [0 ]
95
95
global_scores = results ["results" ][taskname ]
96
96
global_scores .pop ("alias" )
You can’t perform that action at this time.
0 commit comments