From 5792662f251c403e56a374b0b1309f382d652652 Mon Sep 17 00:00:00 2001 From: openhands Date: Sat, 17 May 2025 17:32:33 +0000 Subject: [PATCH] Fix timeout and output parsing issues for local models This commit addresses issue #2044 by: 1. Adding a new is_local_model flag to RunConfig with increased timeout for local models 2. Improving the JSON extraction logic in extract_json to handle common issues with local model outputs 3. Enhancing the RagasOutputParser to be more robust with malformed JSON 4. Adding tests to verify the fixes --- src/ragas/prompt/pydantic_prompt.py | 31 +++++++++++- src/ragas/prompt/utils.py | 78 ++++++++++++++++++++++------- src/ragas/run_config.py | 7 +++ tests/test_local_model_parsing.py | 41 +++++++++++++++ 4 files changed, 137 insertions(+), 20 deletions(-) create mode 100644 tests/test_local_model_parsing.py diff --git a/src/ragas/prompt/pydantic_prompt.py b/src/ragas/prompt/pydantic_prompt.py index 7d424c8c2..22eabed82 100644 --- a/src/ragas/prompt/pydantic_prompt.py +++ b/src/ragas/prompt/pydantic_prompt.py @@ -399,15 +399,19 @@ async def parse_output_string( ) -> OutputModel: callbacks = callbacks or [] try: + # First attempt to extract JSON from the output jsonstr = extract_json(output_string) result = super().parse(jsonstr) except OutputParserException: + # If JSON extraction fails, try more aggressive parsing if retries_left != 0: retry_rm, retry_cb = new_group( name="fix_output_format", inputs={"output_string": output_string}, callbacks=callbacks, ) + + # Add more explicit instructions for the fix_output_format prompt fixed_output_string = await fix_output_format_prompt.generate( llm=llm, data=OutputStringAndPrompt( @@ -418,7 +422,32 @@ async def parse_output_string( retries_left=retries_left - 1, ) retry_rm.on_chain_end({"fixed_output_string": fixed_output_string}) - result = super().parse(fixed_output_string.text) + + try: + # Try to parse the fixed output + fixed_jsonstr = extract_json(fixed_output_string.text) + result = super().parse(fixed_jsonstr) + except OutputParserException: + # If still failing, try one more time with a more lenient approach + # This is especially helpful for local models that might not format JSON perfectly + try: + # Try to find anything that looks like JSON + import re + json_pattern = r'\{(?:[^{}]|(?R))*\}' + json_matches = re.findall(json_pattern, fixed_output_string.text) + + if json_matches: + for potential_json in json_matches: + try: + result = super().parse(potential_json) + return result + except: + continue + + # If we got here, all parsing attempts failed + raise RagasOutputParserException() + except: + raise RagasOutputParserException() else: raise RagasOutputParserException() return result diff --git a/src/ragas/prompt/utils.py b/src/ragas/prompt/utils.py index 4019940ab..8314c8281 100644 --- a/src/ragas/prompt/utils.py +++ b/src/ragas/prompt/utils.py @@ -68,29 +68,49 @@ def replace_string(s: str) -> str: def extract_json(text: str) -> str: """Identify json from a text blob by matching '[]' or '{}'. - - Warning: This will identify the first json structure!""" - - # check for markdown indicator; if present, start there - md_json_idx = text.find("```json") - if md_json_idx != -1: - text = text[md_json_idx:] - - # search for json delimiter pairs + + This function attempts to extract valid JSON from text, handling various formats + including markdown code blocks and malformed JSON that might be produced by local models. + + Warning: This will identify the first json structure! + """ + import json + import re + + # Remove any leading/trailing whitespace + text = text.strip() + + # Check for markdown code blocks (```json or ```), and extract content if present + md_json_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" + md_matches = re.findall(md_json_pattern, text) + + if md_matches: + # Try each markdown code block + for md_content in md_matches: + try: + # Validate if it's proper JSON + json.loads(md_content.strip()) + return md_content.strip() + except json.JSONDecodeError: + # If not valid JSON, continue with next match or fallback to other methods + pass + + # Search for json delimiter pairs left_bracket_idx = text.find("[") left_brace_idx = text.find("{") - + indices = [idx for idx in (left_bracket_idx, left_brace_idx) if idx != -1] - start_idx = min(indices) if indices else None - + # If no delimiter found, return the original text - if start_idx is None: + if not indices: return text - + + start_idx = min(indices) + # Identify the exterior delimiters defining JSON open_char = text[start_idx] close_char = "]" if open_char == "[" else "}" - + # Initialize a count to keep track of delimiter pairs count = 0 for i, char in enumerate(text[start_idx:], start=start_idx): @@ -98,9 +118,29 @@ def extract_json(text: str) -> str: count += 1 elif char == close_char: count -= 1 - + # When count returns to zero, we've found a complete structure if count == 0: - return text[start_idx : i + 1] - - return text # In case of unbalanced JSON, return the original text + potential_json = text[start_idx : i + 1] + try: + # Validate if it's proper JSON + json.loads(potential_json) + return potential_json + except json.JSONDecodeError: + # If not valid JSON, try to fix common issues + fixed_json = potential_json + # Replace single quotes with double quotes (common issue with local models) + fixed_json = re.sub(r"(? WrappedFn: diff --git a/tests/test_local_model_parsing.py b/tests/test_local_model_parsing.py new file mode 100644 index 000000000..6650d521d --- /dev/null +++ b/tests/test_local_model_parsing.py @@ -0,0 +1,41 @@ +import pytest +import json +from ragas.prompt.utils import extract_json +from ragas.run_config import RunConfig + + +def test_extract_json_with_markdown(): + """Test extracting JSON from markdown code blocks.""" + text = """ + Here's the JSON: + ```json + {"key": "value", "array": [1, 2, 3]} + ``` + """ + result = extract_json(text) + assert json.loads(result) == {"key": "value", "array": [1, 2, 3]} + + +def test_extract_json_with_single_quotes(): + """Test extracting JSON with single quotes (common in local model outputs).""" + text = "{'key': 'value', 'array': [1, 2, 3]}" + result = extract_json(text) + assert json.loads(result) == {"key": "value", "array": [1, 2, 3]} + + +def test_extract_json_with_trailing_commas(): + """Test extracting JSON with trailing commas (common in local model outputs).""" + text = '{"key": "value", "array": [1, 2, 3,],}' + result = extract_json(text) + assert json.loads(result) == {"key": "value", "array": [1, 2, 3]} + + +def test_run_config_local_model(): + """Test that local model flag increases timeout.""" + # Default config + config = RunConfig() + assert config.timeout == 180 + + # Local model config + local_config = RunConfig(is_local_model=True) + assert local_config.timeout == 600 # Should be increased to 10 minutes \ No newline at end of file