@@ -1662,84 +1662,3 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput:
1662
1662
),
1663
1663
]
1664
1664
)
1665
-
1666
-
1667
- @pytest .mark .skipif (not tenacity_import_successful (), reason = 'tenacity not installed' )
1668
- def test_evaluate_sync_with_retried_task_and_evaluator (
1669
- example_dataset : Dataset [TaskInput , TaskOutput , TaskMetadata ],
1670
- ):
1671
- task_attempt = 0
1672
-
1673
- def mock_sync_task (inputs : TaskInput ) -> TaskOutput :
1674
- nonlocal task_attempt
1675
- if task_attempt < 3 :
1676
- task_attempt += 1
1677
- raise RuntimeError (f'task failure { task_attempt } ' )
1678
- if inputs .query == 'What is 2+2?' :
1679
- return TaskOutput (answer = '4' )
1680
- elif inputs .query == 'What is the capital of France?' :
1681
- return TaskOutput (answer = 'Paris' )
1682
- return TaskOutput (answer = 'Unknown' ) # pragma: no cover
1683
-
1684
- evaluator_attempt = 0
1685
-
1686
- @dataclass
1687
- class RetryEvaluator (Evaluator [TaskInput , TaskOutput , TaskMetadata ]):
1688
- def evaluate (self , ctx : EvaluatorContext [TaskInput , TaskOutput , TaskMetadata ]):
1689
- nonlocal evaluator_attempt
1690
- if evaluator_attempt < 3 :
1691
- evaluator_attempt += 1
1692
- raise RuntimeError (f'evaluator failure { evaluator_attempt } ' )
1693
- if ctx .expected_output is None : # pragma: no cover
1694
- return {'result' : 'no_expected_output' }
1695
- return {
1696
- 'correct' : ctx .output .answer == ctx .expected_output .answer ,
1697
- 'confidence' : ctx .output .confidence ,
1698
- }
1699
-
1700
- example_dataset .add_evaluator (RetryEvaluator ())
1701
-
1702
- report = example_dataset .evaluate_sync (
1703
- mock_sync_task ,
1704
- retry_task = RetryConfig (stop = stop_after_attempt (3 )),
1705
- retry_evaluators = RetryConfig (stop = stop_after_attempt (3 )),
1706
- )
1707
-
1708
- assert task_attempt == 3
1709
- assert evaluator_attempt == 3
1710
-
1711
- assert report is not None
1712
- assert len (report .cases ) == 2
1713
- assert ReportCaseAdapter .dump_python (report .cases [0 ]) == snapshot (
1714
- {
1715
- 'assertions' : {
1716
- 'correct' : {
1717
- 'name' : 'correct' ,
1718
- 'reason' : None ,
1719
- 'source' : {'name' : 'RetryEvaluator' , 'arguments' : None },
1720
- 'value' : True ,
1721
- }
1722
- },
1723
- 'attributes' : {},
1724
- 'evaluator_failures' : [],
1725
- 'expected_output' : {'answer' : '4' , 'confidence' : 1.0 },
1726
- 'inputs' : {'query' : 'What is 2+2?' },
1727
- 'labels' : {},
1728
- 'metadata' : {'category' : 'general' , 'difficulty' : 'easy' },
1729
- 'metrics' : {},
1730
- 'name' : 'case1' ,
1731
- 'output' : {'answer' : '4' , 'confidence' : 1.0 },
1732
- 'scores' : {
1733
- 'confidence' : {
1734
- 'name' : 'confidence' ,
1735
- 'reason' : None ,
1736
- 'source' : {'name' : 'RetryEvaluator' , 'arguments' : None },
1737
- 'value' : 1.0 ,
1738
- }
1739
- },
1740
- 'span_id' : '0000000000000003' ,
1741
- 'task_duration' : IsNumber (),
1742
- 'total_duration' : IsNumber (),
1743
- 'trace_id' : '00000000000000000000000000000001' ,
1744
- }
1745
- )
0 commit comments