diff --git a/dspy/evaluate/evaluate.py b/dspy/evaluate/evaluate.py
index 3203b7293..cd578d4e4 100644
--- a/dspy/evaluate/evaluate.py
+++ b/dspy/evaluate/evaluate.py
@@ -6,6 +6,7 @@
 if TYPE_CHECKING:
     import pandas as pd
+
 import tqdm
 
 import dspy
@@ -51,6 +52,7 @@ class EvaluationResult(Prediction):
     - score: An float value (e.g., 67.30) representing the overall performance
     - results: a list of (example, prediction, score) tuples for each example in devset
     """
+
     def __init__(self, score: float, results: list[tuple["dspy.Example", "dspy.Example", Any]]):
         super().__init__(score=score, results=results)
 
@@ -126,71 +128,113 @@ def __call__(
         Returns:
             The evaluation results are returned as a dspy.EvaluationResult object containing the following attributes:
-
+
             - score: A float percentage score (e.g., 67.30) representing overall performance
-
+
             - results: a list of (example, prediction, score) tuples for each example in devset
         """
-        metric = metric if metric is not None else self.metric
-        devset = devset if devset is not None else self.devset
-        num_threads = num_threads if num_threads is not None else self.num_threads
-        display_progress = display_progress if display_progress is not None else self.display_progress
-        display_table = display_table if display_table is not None else self.display_table
-
+        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
+            metric, devset, num_threads, display_progress, display_table
+        )
         if callback_metadata:
             logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}")
 
-        tqdm.tqdm._instances.clear()
+        results = self._execute_with_multithreading(program, metric, devset, num_threads, display_progress)
+        return self._process_evaluate_result(devset, results, metric, display_table)
 
-        executor = ParallelExecutor(
-            num_threads=num_threads,
-            disable_progress_bar=not display_progress,
-            max_errors=(
-                self.max_errors
-                if self.max_errors is not None
-                else dspy.settings.max_errors
-            ),
-            provide_traceback=self.provide_traceback,
-            compare_results=True,
+    @with_callbacks
+    async def acall(
+        self,
+        program: "dspy.Module",
+        metric: Callable | None = None,
+        devset: list["dspy.Example"] | None = None,
+        num_threads: int | None = None,
+        display_progress: bool | None = None,
+        display_table: bool | int | None = None,
+        callback_metadata: dict[str, Any] | None = None,
+    ) -> EvaluationResult:
+        """Async version of `Evaluate.__call__`."""
+        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
+            metric, devset, num_threads, display_progress, display_table
+        )
+        if callback_metadata:
+            logger.debug(f"Evaluate.acall is called with callback metadata: {callback_metadata}")
+        tqdm.tqdm._instances.clear()
+        results = await self._execute_with_event_loop(program, metric, devset, num_threads, display_progress)
+        return self._process_evaluate_result(devset, results, metric, display_table)
+
+    def _resolve_call_args(self, metric, devset, num_threads, display_progress, display_table):
+        return (
+            metric or self.metric,
+            devset or self.devset,
+            num_threads or self.num_threads,
+            display_progress or self.display_progress,
+            display_table or self.display_table,
         )
 
-        def process_item(example):
-            prediction = program(**example.inputs())
-            score = metric(example, prediction)
-
-            # Increment assert and suggest failures to program's attributes
-            if hasattr(program, "_assert_failures"):
-                program._assert_failures += dspy.settings.get("assert_failures")
-            if hasattr(program, "_suggest_failures"):
-                program._suggest_failures += dspy.settings.get("suggest_failures")
-
-            return prediction, score
-
-        results = executor.execute(process_item, devset)
+    def _process_evaluate_result(self, devset, results, metric, display_table):
         assert len(devset) == len(results)
-
         results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results]
         results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)]
         ncorrect, ntotal = sum(score for *_, score in results), len(devset)
-
         logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")
-
         if display_table:
             if importlib.util.find_spec("pandas") is not None:
-                # Rename the 'correct' column to the name of the metric object
                 metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
-                # Construct a pandas DataFrame from the results
                 result_df = self._construct_result_table(results, metric_name)
-
                 self._display_result_table(result_df, display_table, metric_name)
             else:
                 logger.warning("Skipping table display since `pandas` is not installed.")
-
         return EvaluationResult(
             score=round(100 * ncorrect / ntotal, 2),
             results=results,
         )
 
+    def _execute_with_multithreading(
+        self,
+        program: "dspy.Module",
+        metric: Callable,
+        devset: list["dspy.Example"],
+        num_threads: int,
+        disable_progress_bar: bool,
+    ):
+        executor = ParallelExecutor(
+            num_threads=num_threads,
+            disable_progress_bar=disable_progress_bar,
+            max_errors=(self.max_errors or dspy.settings.max_errors),
+            provide_traceback=self.provide_traceback,
+            compare_results=True,
+        )
+
+        def process_item(example):
+            prediction = program(**example.inputs())
+            score = metric(example, prediction)
+            return prediction, score
+
+        return executor.execute(process_item, devset)
+
+    async def _execute_with_event_loop(
+        self,
+        program: "dspy.Module",
+        metric: Callable,
+        devset: list["dspy.Example"],
+        num_threads: int,
+        disable_progress_bar: bool,
+    ):
+        executor = ParallelExecutor(
+            num_threads=num_threads,
+            disable_progress_bar=disable_progress_bar,
+            max_errors=(self.max_errors or dspy.settings.max_errors),
+            provide_traceback=self.provide_traceback,
+            compare_results=True,
+        )
+
+        async def process_item(example):
+            prediction = await program.acall(**example.inputs())
+            score = metric(example, prediction)
+            return prediction, score
+
+        return await executor.aexecute(process_item, devset)
+
     def _construct_result_table(
         self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py
index d45df534e..92683195f 100644
--- a/dspy/utils/parallelizer.py
+++ b/dspy/utils/parallelizer.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import copy
 import logging
@@ -42,29 +43,60 @@ def __init__(
         self.error_lock = threading.Lock()
         self.cancel_jobs = threading.Event()
 
+        self.error_lock_async = asyncio.Lock()
+        self.cancel_jobs_async = asyncio.Event()
+
     def execute(self, function, data):
         tqdm.tqdm._instances.clear()
-        wrapped = self._wrap_function(function)
+        wrapped = self._wrap_function(function, async_mode=False)
         return self._execute_parallel(wrapped, data)
 
-    def _wrap_function(self, user_function):
-        def safe_func(item):
+    async def aexecute(self, function, data):
+        tqdm.tqdm._instances.clear()
+        wrapped = self._wrap_function(function, async_mode=True)
+        return await self._execute_parallel_async(wrapped, data)
+
+    def _handle_error(self, item, e):
+        with self.error_lock:
+            self.error_count += 1
+            if self.error_count >= self.max_errors:
+                self.cancel_jobs.set()
+        if self.provide_traceback:
+            logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
+        else:
+            logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
+
+    async def _handle_error_async(self, item, e):
+        async with self.error_lock_async:
+            self.error_count += 1
+            if self.error_count >= self.max_errors:
+                self.cancel_jobs_async.set()
+        if self.provide_traceback:
+            logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
+
+    def _wrap_function(self, user_function, async_mode=False):
+        async def _async_safe_func(item):
+            if self.cancel_jobs.is_set():
+                return None
+            try:
+                return await user_function(item)
+            except Exception as e:
+                await self._handle_error_async(item, e)
+                return None
+
+        def _sync_safe_func(item):
             if self.cancel_jobs.is_set():
                 return None
             try:
                 return user_function(item)
             except Exception as e:
-                with self.error_lock:
-                    self.error_count += 1
-                    if self.error_count >= self.max_errors:
-                        self.cancel_jobs.set()
-                if self.provide_traceback:
-                    logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
-                else:
-                    logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
+                self._handle_error(item, e)
                 return None
 
-        return safe_func
+        if async_mode:
+            return _async_safe_func
+        else:
+            return _sync_safe_func
 
     def _execute_parallel(self, function, data):
         results = [None] * len(data)
@@ -204,6 +236,50 @@ def all_done():
 
         return results
 
+    async def _execute_parallel_async(self, function, data):
+        queue = asyncio.Queue()
+        results = [None] * len(data)
+        for i, example in enumerate(data):
+            await queue.put((i, example))
+
+        for _ in range(self.num_threads):
+            # Add a sentinel value to indicate that the worker should exit
+            await queue.put((-1, None))
+
+        # Create tqdm progress bar
+        pbar = tqdm.tqdm(total=len(data), dynamic_ncols=True)
+
+        async def worker():
+            while True:
+                if self.cancel_jobs_async.is_set():
+                    break
+                index, example = await queue.get()
+                if index == -1:
+                    break
+                function_outputs = await function(example)
+                results[index] = function_outputs
+
+                if self.compare_results:
+                    vals = [r[-1] for r in results if r is not None]
+                    self._update_progress(pbar, sum(vals), len(vals))
+                else:
+                    self._update_progress(
+                        pbar,
+                        len([r for r in results if r is not None]),
+                        len(data),
+                    )
+
+                queue.task_done()
+
+        workers = [asyncio.create_task(worker()) for _ in range(self.num_threads)]
+        await asyncio.gather(*workers)
+        pbar.close()
+        if self.cancel_jobs_async.is_set():
+            logger.warning("Execution cancelled due to errors or interruption.")
+            raise Exception("Execution cancelled due to errors or interruption.")
+
+        return results
+
     def _update_progress(self, pbar, nresults, ntotal):
         if self.compare_results:
             pct = round(100 * nresults / ntotal, 1) if ntotal else 0
diff --git a/tests/evaluate/test_evaluate.py b/tests/evaluate/test_evaluate.py
index 106396c15..9e845773e 100644
--- a/tests/evaluate/test_evaluate.py
+++ b/tests/evaluate/test_evaluate.py
@@ -1,3 +1,4 @@
+import asyncio
 import signal
 import threading
 from unittest.mock import patch
@@ -57,6 +58,7 @@ def test_evaluate_call():
 @pytest.mark.extra
 def test_construct_result_df():
     import pandas as pd
+
     devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
     ev = Evaluate(
         devset=devset,
@@ -253,6 +255,36 @@ def on_evaluate_end(
     assert callback.end_call_outputs.score == 100.0
     assert callback.end_call_count == 1
 
+
 def test_evaluation_result_repr():
     result = EvaluationResult(score=100.0, results=[(new_example("What is 1+1?", "2"), {"answer": "2"}, 100.0)])
     assert repr(result) == "EvaluationResult(score=100.0, results=)"
+
+
+@pytest.mark.asyncio
+async def test_async_evaluate_call():
+    class MyProgram(dspy.Module):
+        async def acall(self, question):
+            # Simulate network delay
+            await asyncio.sleep(0.05)
+            # Just echo the input as the prediction
+            return dspy.Prediction(answer=question)
+
+    def dummy_metric(gold, pred, traces=None):
+        return gold.answer == pred.answer
+
+    devset = [new_example(f"placeholder_{i}", f"placeholder_{i}") for i in range(20)]
+
+    evaluator = Evaluate(devset=devset, metric=dummy_metric, num_threads=10, display_progress=False)
+
+    result = await evaluator.acall(MyProgram())
+    assert isinstance(result.score, float)
+    assert result.score == 100.0  # Both answers should match
+    assert len(result.results) == 20
+
+    # Check the results are in the correct order
+    for i, (example, prediction, score) in enumerate(result.results):
+        assert isinstance(prediction, dspy.Prediction)
+        assert score == 1
+        assert example.question == f"placeholder_{i}"
+        assert prediction.answer == f"placeholder_{i}"
diff --git a/tests/utils/test_parallelizer.py b/tests/utils/test_parallelizer.py
index 28307e4ea..883bddbd9 100644
--- a/tests/utils/test_parallelizer.py
+++ b/tests/utils/test_parallelizer.py
@@ -59,3 +59,63 @@ def task(item):
 
     # Verify that the results exclude the failed task
     assert results == [1, 2, None, 4, 5]
+
+
+@pytest.mark.asyncio
+async def test_worker_threads_independence_async():
+    async def task(item):
+        # Each thread maintains its own state by appending to a thread-local list
+        return item * 2
+
+    data = [1, 2, 3, 4, 5]
+    executor = ParallelExecutor(num_threads=3)
+    results = await executor.aexecute(task, data)
+
+    assert results == [2, 4, 6, 8, 10]
+
+
+@pytest.mark.asyncio
+async def test_parallel_execution_speed_async():
+    async def task(item):
+        time.sleep(0.1)  # Simulate a time-consuming task
+        return item
+
+    data = [1, 2, 3, 4, 5]
+    executor = ParallelExecutor(num_threads=5)
+
+    start_time = time.time()
+    await executor.aexecute(task, data)
+    end_time = time.time()
+
+    assert end_time - start_time < len(data)
+
+
+@pytest.mark.asyncio
+async def test_max_errors_handling_async():
+    async def task(item):
+        if item == 3:
+            raise ValueError("Intentional error")
+        return item
+
+    data = [1, 2, 3, 4, 5]
+    executor = ParallelExecutor(num_threads=3, max_errors=1)
+
+    with pytest.raises(Exception, match="Execution cancelled due to errors or interruption."):
+        await executor.aexecute(task, data)
+
+
+@pytest.mark.asyncio
+async def test_max_errors_not_met_async():
+    async def task(item):
+        if item == 3:
+            raise ValueError("Intentional error")
+        return item
+
+    data = [1, 2, 3, 4, 5]
+    executor = ParallelExecutor(num_threads=3, max_errors=2)
+
+    # Ensure that the execution completes without crashing when max_errors is not met
+    results = await executor.aexecute(task, data)
+
+    # Verify that the results exclude the failed task
+    assert results == [1, 2, None, 4, 5]
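Usage sketch (illustrative only, not part of the patch above): the snippet below shows how the new async entry point, `Evaluate.acall`, would be driven once this diff is applied; it mirrors the `test_async_evaluate_call` test. The `EchoQA` module, the `exact_match` metric, and the tiny devset are hypothetical stand-ins for a real program, metric, and dataset.

import asyncio

import dspy
from dspy.evaluate import Evaluate


# Hypothetical async-capable program: any dspy.Module whose `acall` returns a Prediction works here.
class EchoQA(dspy.Module):
    async def acall(self, question):
        await asyncio.sleep(0.01)  # stand-in for a real async LM call
        return dspy.Prediction(answer=question)


def exact_match(gold, pred, traces=None):
    # Same metric shape used by the tests in this PR: compare gold answer to predicted answer.
    return gold.answer == pred.answer


async def main():
    devset = [
        dspy.Example(question="hi", answer="hi").with_inputs("question"),
        dspy.Example(question="bye", answer="bye").with_inputs("question"),
    ]
    evaluator = Evaluate(devset=devset, metric=exact_match, num_threads=2, display_progress=False)
    # Async counterpart of `evaluator(program)`; returns the same EvaluationResult object.
    result = await evaluator.acall(EchoQA())
    print(result.score, len(result.results))


asyncio.run(main())

As in the new tests, `result.score` is a percentage (100.0 here) and `result.results` preserves devset order, since `_execute_parallel_async` writes each output back to its original index.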