Skip to content

Fix timeout and output parsing issues for local models #2045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/ragas/prompt/pydantic_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,19 @@ async def parse_output_string(
) -> OutputModel:
callbacks = callbacks or []
try:
# First attempt to extract JSON from the output
jsonstr = extract_json(output_string)
result = super().parse(jsonstr)
except OutputParserException:
# If JSON extraction fails, try more aggressive parsing
if retries_left != 0:
retry_rm, retry_cb = new_group(
name="fix_output_format",
inputs={"output_string": output_string},
callbacks=callbacks,
)

# Add more explicit instructions for the fix_output_format prompt
fixed_output_string = await fix_output_format_prompt.generate(
llm=llm,
data=OutputStringAndPrompt(
Expand All @@ -418,7 +422,32 @@ async def parse_output_string(
retries_left=retries_left - 1,
)
retry_rm.on_chain_end({"fixed_output_string": fixed_output_string})
result = super().parse(fixed_output_string.text)

try:
# Try to parse the fixed output
fixed_jsonstr = extract_json(fixed_output_string.text)
result = super().parse(fixed_jsonstr)
except OutputParserException:
# If still failing, try one more time with a more lenient approach
# This is especially helpful for local models that might not format JSON perfectly
try:
# Try to find anything that looks like JSON
import re
json_pattern = r'\{(?:[^{}]|(?R))*\}'
json_matches = re.findall(json_pattern, fixed_output_string.text)

if json_matches:
for potential_json in json_matches:
try:
result = super().parse(potential_json)
return result
except:
continue

# If we got here, all parsing attempts failed
raise RagasOutputParserException()
except:
raise RagasOutputParserException()
else:
raise RagasOutputParserException()
return result
Expand Down
78 changes: 59 additions & 19 deletions src/ragas/prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,79 @@ def replace_string(s: str) -> str:

def extract_json(text: str) -> str:
    """Extract the first valid JSON document embedded in a text blob.

    Candidates are tried in this order, and the first one that parses wins:

    1. The contents of each markdown code fence (with or without a ``json``
       language tag).
    2. Each balanced ``{...}`` / ``[...]`` span, scanning left to right.
       Delimiters that appear inside quoted strings are ignored while
       balancing.

    A balanced span that is not valid JSON gets two repair attempts aimed at
    output commonly produced by local models: trailing commas are dropped,
    and if that is not enough, unescaped single quotes are additionally
    replaced by double quotes. If the span still does not parse, the search
    resumes at the next opening delimiter.

    Parameters
    ----------
    text : str
        Raw model output that may contain JSON surrounded by other text.

    Returns
    -------
    str
        The extracted (possibly repaired) JSON string, or the
        whitespace-stripped input when no valid JSON could be found.

    Warning: This will identify the first json structure!
    """
    import json
    import re

    closers = {"[": "]", "{": "}"}

    def parses(candidate: str) -> bool:
        """Return True when *candidate* is accepted by ``json.loads``."""
        try:
            json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return False
        return True

    def balanced_end(start: int):
        """Return the index closing the delimiter at *start*, or None."""
        opener = text[start]
        closer = closers[opener]
        depth = 0
        quote = None  # quote character of the string we are inside, if any
        escaped = False
        for pos in range(start, len(text)):
            ch = text[pos]
            if quote is not None:
                # Inside a string: only look for its (unescaped) terminator.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == quote:
                    quote = None
            elif ch in "\"'":
                quote = ch
            elif ch == opener:
                depth += 1
            elif ch == closer:
                depth -= 1
                if depth == 0:
                    return pos
        return None

    # Remove any leading/trailing whitespace
    text = text.strip()

    # Prefer the content of markdown code fences when it is valid JSON.
    for fenced in re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text):
        fenced = fenced.strip()
        if parses(fenced):
            return fenced

    search_from = 0
    while True:
        openers = [
            idx
            for idx in (text.find("[", search_from), text.find("{", search_from))
            if idx != -1
        ]
        # No (further) delimiter found: give back the stripped input.
        if not openers:
            return text

        start_idx = min(openers)
        # A failed candidate must not be retried; resume after its opener.
        search_from = start_idx + 1

        end_idx = balanced_end(start_idx)
        if end_idx is None:
            continue

        candidate = text[start_idx : end_idx + 1]
        if parses(candidate):
            return candidate

        # Least invasive repair first, so apostrophes inside otherwise valid
        # double-quoted strings are left untouched.
        without_trailing_commas = re.sub(r",\s*([}\]])", r"\1", candidate)
        if parses(without_trailing_commas):
            return without_trailing_commas

        # Replace single quotes with double quotes (common with local models).
        requoted = re.sub(r"(?<!\\)'", '"', without_trailing_commas)
        if parses(requoted):
            return requoted
7 changes: 7 additions & 0 deletions src/ragas/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class RunConfig:
Whether to log retry attempts using tenacity, by default False.
seed : int, optional
Random seed for reproducibility, by default 42.
is_local_model : bool, optional
Whether the LLM is a local model (e.g., via Ollama), by default False.
Local models may need longer timeouts.

Attributes
----------
Expand All @@ -58,9 +61,13 @@ class RunConfig:
] = (Exception,)
log_tenacity: bool = False
seed: int = 42
is_local_model: bool = False

def __post_init__(self):
    """Seed the RNG and enforce a minimum timeout for local models.

    Creates ``self.rng`` from ``self.seed``. When ``self.is_local_model`` is
    set, ``self.timeout`` is raised to at least 600 seconds; a timeout that
    is already larger is kept as configured.
    """
    self.rng = np.random.default_rng(seed=self.seed)
    # Local models may need more time: treat 10 minutes as a floor so an
    # explicitly configured longer timeout is never shortened.
    if self.is_local_model:
        self.timeout = max(self.timeout, 600)


def add_retry(fn: WrappedFn, run_config: RunConfig) -> WrappedFn:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_local_model_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import json
from ragas.prompt.utils import extract_json
from ragas.run_config import RunConfig


def test_extract_json_with_markdown():
    """JSON wrapped in a markdown code fence is extracted and parses cleanly."""
    expected = {"key": "value", "array": [1, 2, 3]}
    model_output = """
Here's the JSON:
```json
{"key": "value", "array": [1, 2, 3]}
```
"""
    extracted = extract_json(model_output)
    parsed = json.loads(extracted)
    assert parsed == expected


def test_extract_json_with_single_quotes():
    """Single-quoted pseudo-JSON (common in local model outputs) is repaired."""
    expected = {"key": "value", "array": [1, 2, 3]}
    extracted = extract_json("{'key': 'value', 'array': [1, 2, 3]}")
    parsed = json.loads(extracted)
    assert parsed == expected


def test_extract_json_with_trailing_commas():
    """Trailing commas (common in local model outputs) are removed before parsing."""
    expected = {"key": "value", "array": [1, 2, 3]}
    extracted = extract_json('{"key": "value", "array": [1, 2, 3,],}')
    parsed = json.loads(extracted)
    assert parsed == expected


def test_run_config_local_model():
    """The local-model flag raises the timeout above the stock value."""
    # Baseline: a config built with no arguments reports a 180 s timeout.
    assert RunConfig().timeout == 180

    # With the local-model flag set, the timeout becomes 10 minutes.
    assert RunConfig(is_local_model=True).timeout == 600
Loading