Commit c3acd6f

add async evaluate
1 parent 75defa7 commit c3acd6f

File tree: 2 files changed (+135, -39 lines)

dspy/evaluate/evaluate.py

Lines changed: 103 additions & 39 deletions
@@ -6,6 +6,8 @@
 if TYPE_CHECKING:
     import pandas as pd
 
+import asyncio
+
 import tqdm
 
 import dspy
@@ -51,6 +53,7 @@ class EvaluationResult(Prediction):
     - score: An float value (e.g., 67.30) representing the overall performance
     - results: a list of (example, prediction, score) tuples for each example in devset
     """
+
     def __init__(self, score: float, results: list[Tuple["dspy.Example", "dspy.Example", Any]]):
         super().__init__(score=score, results=results)
 
@@ -126,71 +129,132 @@ def __call__(
 
         Returns:
            The evaluation results are returned as a dspy.EvaluationResult object containing the following attributes:
-
+
            - score: A float percentage score (e.g., 67.30) representing overall performance
-
+
            - results: a list of (example, prediction, score) tuples for each example in devset
        """
-        metric = metric if metric is not None else self.metric
-        devset = devset if devset is not None else self.devset
-        num_threads = num_threads if num_threads is not None else self.num_threads
-        display_progress = display_progress if display_progress is not None else self.display_progress
-        display_table = display_table if display_table is not None else self.display_table
-
+        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
+            metric, devset, num_threads, display_progress, display_table
+        )
        if callback_metadata:
            logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}")
-
        tqdm.tqdm._instances.clear()
+        results = self._execute_with_multithreading(program, metric, devset, num_threads, display_progress)
+        return self._process_evaluate_result(devset, results, metric, display_table)
 
-        executor = ParallelExecutor(
-            num_threads=num_threads,
-            disable_progress_bar=not display_progress,
-            max_errors=(
-                self.max_errors
-                if self.max_errors is not None
-                else dspy.settings.max_errors
-            ),
-            provide_traceback=self.provide_traceback,
-            compare_results=True,
+    @with_callbacks
+    async def acall(
+        self,
+        program: "dspy.Module",
+        metric: Callable | None = None,
+        devset: List["dspy.Example"] | None = None,
+        num_threads: int | None = None,
+        display_progress: bool | None = None,
+        display_table: bool | int | None = None,
+        callback_metadata: dict[str, Any] | None = None,
+    ) -> EvaluationResult:
+        """Async version of `Evaluate.__call__`."""
+        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
+            metric, devset, num_threads, display_progress, display_table
+        )
+        if callback_metadata:
+            logger.debug(f"Evaluate.acall is called with callback metadata: {callback_metadata}")
+        tqdm.tqdm._instances.clear()
+        results = await self._execute_with_event_loop(program, metric, devset, num_threads)
+        return self._process_evaluate_result(devset, results, metric, display_table)
+
+    def _resolve_call_args(self, metric, devset, num_threads, display_progress, display_table):
+        return (
+            metric or self.metric,
+            devset or self.devset,
+            num_threads or self.num_threads,
+            display_progress or self.display_progress,
+            display_table or self.display_table,
        )
 
-        def process_item(example):
-            prediction = program(**example.inputs())
-            score = metric(example, prediction)
-
-            # Increment assert and suggest failures to program's attributes
-            if hasattr(program, "_assert_failures"):
-                program._assert_failures += dspy.settings.get("assert_failures")
-            if hasattr(program, "_suggest_failures"):
-                program._suggest_failures += dspy.settings.get("suggest_failures")
-
-            return prediction, score
-
-        results = executor.execute(process_item, devset)
+    def _process_evaluate_result(self, devset, results, metric, display_table):
        assert len(devset) == len(results)
-
        results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results]
        results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)]
        ncorrect, ntotal = sum(score for *_, score in results), len(devset)
-
        logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")
-
        if display_table:
            if importlib.util.find_spec("pandas") is not None:
-                # Rename the 'correct' column to the name of the metric object
                metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
-                # Construct a pandas DataFrame from the results
                result_df = self._construct_result_table(results, metric_name)
-
                self._display_result_table(result_df, display_table, metric_name)
            else:
                logger.warning("Skipping table display since `pandas` is not installed.")
-
        return EvaluationResult(
            score=round(100 * ncorrect / ntotal, 2),
            results=results,
        )
 
+    def _execute_with_multithreading(
+        self,
+        program: "dspy.Module",
+        metric: Callable,
+        devset: List["dspy.Example"],
+        num_threads: int,
+        disable_progress_bar: bool,
+    ):
+        executor = ParallelExecutor(
+            num_threads=num_threads,
+            disable_progress_bar=disable_progress_bar,
+            max_errors=(self.max_errors or dspy.settings.max_errors),
+            provide_traceback=self.provide_traceback,
+            compare_results=True,
+        )
+
+        def process_item(example):
+            prediction = program(**example.inputs())
+            score = metric(example, prediction)
+            return prediction, score
+
+        return executor.execute(process_item, devset)
+
+    async def _execute_with_event_loop(
+        self,
+        program: "dspy.Module",
+        metric: Callable,
+        devset: List["dspy.Example"],
+        num_threads: int,
+    ):
+        queue = asyncio.Queue()
+        results = [None for _ in range(len(devset))]
+        for i, example in enumerate(devset):
+            await queue.put((i, example))
+
+        for _ in range(num_threads):
+            # Add a sentinel value to indicate that the worker should exit
+            await queue.put((-1, None))
+
+        # Create tqdm progress bar
+        pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True)
+
+        async def worker():
+            while True:
+                index, example = await queue.get()
+                if index == -1:
+                    break
+                prediction = await program.acall(**example.inputs())
+                score = metric(example, prediction)
+                results[index] = (prediction, score)
+
+                vals = [r[-1] for r in results if r is not None]
+                nresults = sum(vals)
+                ntotal = len(vals)
+                pct = round(100 * nresults / ntotal, 1) if ntotal else 0
+                pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)")
+                pbar.update(1)
+                queue.task_done()
+
+        workers = [asyncio.create_task(worker()) for _ in range(num_threads)]
+        await asyncio.gather(*workers)
+        pbar.close()
+
+        return results
 
     def _construct_result_table(
        self, results: list[Tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
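
A note on the concurrency pattern: `_execute_with_event_loop` above uses a sentinel-terminated `asyncio.Queue` worker pool that writes results back by index, so output order matches `devset`. Below is a stripped-down, dspy-free sketch of that same pattern; the `run_pool` helper and `handle` coroutine are illustrative names, not part of this commit.

import asyncio


async def run_pool(items, num_workers, handle):
    # handle: an async callable taking one item; stand-in for program.acall + metric.
    queue = asyncio.Queue()
    results = [None] * len(items)

    for i, item in enumerate(items):
        await queue.put((i, item))
    for _ in range(num_workers):
        await queue.put((-1, None))  # sentinel: tells one worker to exit

    async def worker():
        while True:
            index, item = await queue.get()
            if index == -1:
                break
            results[index] = await handle(item)  # index keeps results in input order
            queue.task_done()

    await asyncio.gather(*(asyncio.create_task(worker()) for _ in range(num_workers)))
    return results


async def demo():
    async def handle(x):
        await asyncio.sleep(0.01)  # simulate an async LM call
        return 2 * x

    print(await run_pool(list(range(10)), num_workers=3, handle=handle))


asyncio.run(demo())

Putting one `(-1, None)` sentinel per worker on the queue lets every worker drain the remaining items and then exit cleanly, without task cancellation.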

tests/evaluate/test_evaluate.py

Lines changed: 32 additions & 0 deletions
@@ -1,3 +1,4 @@
+import asyncio
 import signal
 import threading
 from unittest.mock import patch
@@ -57,6 +58,7 @@ def test_evaluate_call():
 @pytest.mark.extra
 def test_construct_result_df():
     import pandas as pd
+
     devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
     ev = Evaluate(
         devset=devset,
@@ -253,6 +255,36 @@ def on_evaluate_end(
     assert callback.end_call_outputs.score == 100.0
     assert callback.end_call_count == 1
 
+
 def test_evaluation_result_repr():
     result = EvaluationResult(score=100.0, results=[(new_example("What is 1+1?", "2"), {"answer": "2"}, 100.0)])
     assert repr(result) == "EvaluationResult(score=100.0, results=<list of 1 results>)"
+
+
+@pytest.mark.asyncio
+async def test_async_evaluate_call():
+    class MyProgram(dspy.Module):
+        async def acall(self, question):
+            # Simulate network delay
+            await asyncio.sleep(0.05)
+            # Just echo the input as the prediction
+            return dspy.Prediction(answer=question)
+
+    def dummy_metric(gold, pred, traces=None):
+        return gold.answer == pred.answer
+
+    devset = [new_example(f"placeholder_{i}", f"placeholder_{i}") for i in range(20)]
+
+    evaluator = Evaluate(devset=devset, metric=dummy_metric, num_threads=10, display_progress=False)
+
+    result = await evaluator.acall(MyProgram())
+    assert isinstance(result.score, float)
+    assert result.score == 100.0  # Both answers should match
+    assert len(result.results) == 20
+
+    # Check the results are in the correct order
+    for i, (example, prediction, score) in enumerate(result.results):
+        assert isinstance(prediction, dspy.Prediction)
+        assert score == 1
+        assert example.question == f"placeholder_{i}"
+        assert prediction.answer == f"placeholder_{i}"
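
For orientation, this is roughly how the new async path is meant to be driven from user code; it mirrors `test_async_evaluate_call` above. The `QAEcho` module, `exact_match` metric, and the example fields are illustrative placeholders, not anything added by this commit.

import asyncio

import dspy


class QAEcho(dspy.Module):
    # Hypothetical module: any dspy.Module that exposes an async `acall`,
    # as the new code path expects, will work here.
    async def acall(self, question):
        return dspy.Prediction(answer=question)


def exact_match(example, prediction, traces=None):
    # Illustrative metric with the same call shape Evaluate already accepts.
    return example.answer == prediction.answer


async def main():
    devset = [
        dspy.Example(question=q, answer=q).with_inputs("question")
        for q in ("2", "4", "6")
    ]
    evaluator = dspy.Evaluate(devset=devset, metric=exact_match, num_threads=4, display_progress=False)
    # New in this commit: `acall` runs the examples as asyncio workers on the
    # running event loop instead of dispatching them to a thread pool.
    result = await evaluator.acall(QAEcho())
    print(result.score, len(result.results))  # expected: 100.0 3


asyncio.run(main())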
