docetl/operations/utils/api.py (148 changes: 140 additions & 8 deletions)
@@ -42,6 +42,20 @@

BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]

# Outlines + Ollama structured generation integration
HAVE_OUTLINES = False
try:
    # Outlines integration layer provides a wrapper over python-ollama client
    from outlines.integrations.ollama import from_ollama as outlines_from_ollama
    from outlines.types import JsonSchema as OutlinesJsonSchema
    from outlines.inputs import Chat as OutlinesChat

    HAVE_OUTLINES = True
except Exception:
    HAVE_OUTLINES = False

try:
    from ollama import Client as OllamaClient
except Exception:
    OllamaClient = None
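
# NOTE: Both integrations above are optional. If either import fails, the
# Outlines/Ollama route stays disabled and structured-output calls fall back to
# the standard LiteLLM path further below. (This assumes the optional
# dependencies are the PyPI packages `outlines` and `ollama`.)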

class OutputMode(Enum):
"""Enumeration of output modes for LLM calls."""
@@ -331,24 +345,41 @@ def _cached_call_llm(
total_cost += completion_cost(validator_response)

# Parse the validator response
-suggestion = json.loads(
-    validator_response.choices[0]
-    .message.tool_calls[0]
-    .function.arguments
-)
-if not suggestion["should_refine"]:
+message0 = validator_response.choices[0].message
+suggestion = None
+try:
+    tool_calls = (
+        message0.tool_calls
+        if hasattr(message0, "tool_calls") and message0.tool_calls
+        else []
+    )
+    if tool_calls:
+        suggestion = json.loads(tool_calls[0].function.arguments)
+    else:
+        # Some providers (e.g., Ollama) may not support tool calls. Try parsing content as JSON.
+        suggestion = json.loads(message0.content)
+except Exception:
+    # Fallback: assume no refinement if we can't parse
+    suggestion = {"should_refine": False, "improvements": ""}
+# Safely read fields with defaults
+if not isinstance(suggestion, dict):
+    suggestion = {}
+should_refine = bool(suggestion.get("should_refine", False))
+improvements = str(suggestion.get("improvements", ""))
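+# For example, a validator reply whose message content is the JSON string
+# '{"should_refine": true, "improvements": "Split multi-value fields into lists"}'
+# is handled by the content-parsing branch above; providers that do return tool
+# calls take the first branch instead. (Values shown are illustrative only.)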

+if not should_refine:
    break

if verbose:
    self.runner.console.log(
-        f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
+        f"Validator improvements (gleaning round {rnd + 1}): {improvements}"
    )

# Prompt for improvement
improvement_prompt = f"""Based on the validation feedback:

```
-{suggestion['improvements']}
+{improvements}
```

Please improve your previous response. Ensure that the output adheres to the required schema and addresses any issues raised in the validation."""
@@ -697,6 +728,107 @@ def _call_llm_with_cache(
},
}

# If using Ollama + structured_output, optionally route via Outlines for stricter schema
use_outlines_for_ollama = (
    model.startswith("ollama/")
    and (self.runner.config.get("use_outlines_for_ollama", False))
    and HAVE_OUTLINES
    and OllamaClient is not None
)
if use_outlines_for_ollama:
    # Build Ollama client with configured base URL if provided
    client_kwargs: dict[str, Any] = {}
    if self.default_lm_api_base:
        # python-ollama uses 'host' parameter
        client_kwargs["host"] = self.default_lm_api_base

    try:
        ollama_client = OllamaClient(**client_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to initialize Ollama client for Outlines: {e}"
        )

    native_model = model.split("/", 1)[1] if "/" in model else model

    # Convert our JSON schema dict into an Outlines JsonSchema object
    outlines_schema = OutlinesJsonSchema(schema=json.dumps(schema))

    persona = self.runner.config.get("system_prompt", {}).get(
        "persona", "a helpful assistant"
    )
    dataset_description = self.runner.config.get(
        "system_prompt", {}
    ).get("dataset_description", "a collection of unstructured documents")
    parethetical_op_instructions = (
        "many inputs:one output" if op_type == "reduce" else "one input:one output"
    )
    base_prompt = (
        f"You are a {persona}, helping the user make sense of their data. "
        f"The dataset description is: {dataset_description}. "
        f"You will be performing a {op_type} operation ({parethetical_op_instructions}). "
        "You will perform the specified task on the provided data, as precisely and "
        "exhaustively (i.e., high recall) as possible."
    )
    system_prompt = base_prompt + " Respond with a JSON object that follows the required schema."
    messages_with_system_prompt = truncate_messages(
        [
            {"role": "system", "content": system_prompt},
        ]
        + messages,
        model,
    )

    # Map LiteLLM-style kwargs to Ollama options
    def _map_litellm_to_ollama_options(src: dict[str, Any]) -> dict[str, Any]:
        options: dict[str, Any] = {}
        if "temperature" in src:
            options["temperature"] = src["temperature"]
        if "top_p" in src:
            options["top_p"] = src["top_p"]
        if "top_k" in src:
            options["top_k"] = src["top_k"]
        if "max_tokens" in src:
            options["num_predict"] = src["max_tokens"]
        if "stop" in src:
            options["stop"] = src["stop"]
        if "seed" in src:
            options["seed"] = src["seed"]
        if "frequency_penalty" in src:
            options["repeat_penalty"] = src["frequency_penalty"]
        return options
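
    # For example, LiteLLM-style kwargs {"temperature": 0.2, "max_tokens": 512} would be
    # mapped to Ollama options {"temperature": 0.2, "num_predict": 512}. Note that the
    # frequency_penalty -> repeat_penalty mapping is only a rough approximation, since
    # Ollama's repeat_penalty is multiplicative (typically around 1.1) rather than additive.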

    options = _map_litellm_to_ollama_options(litellm_completion_kwargs)

    outlines_model = outlines_from_ollama(ollama_client, model_name=native_model)

    # Outlines expects an Outlines Chat object
    outlines_chat = OutlinesChat(messages=messages_with_system_prompt)

    try:
        generated = outlines_model.generate(
            outlines_chat,
            output_type=outlines_schema,
            options=options,
        )
    except Exception as e:
        raise RuntimeError(
            f"Outlines/Ollama structured generation failed: {e}"
        )

    # Parse JSON into dict; return directly (non-ModelResponse path)
    try:
        return json.loads(generated)
    except Exception:
        # If not valid JSON, raise to surface error like other structured paths
        raise InvalidOutputError(
            "Could not decode structured output JSON response",
            generated,
            schema,
            [],
            [],
        )
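
    # Illustrative runner config (assumed key names and values) that would route calls
    # through this Outlines branch instead of the standard LiteLLM path below:
    #   {
    #       "use_outlines_for_ollama": True,
    #       "default_model": "ollama/llama3.1",   # any model name prefixed with "ollama/"
    #       "system_prompt": {
    #           "persona": "a data analyst",
    #           "dataset_description": "a collection of support tickets",
    #       },
    #   }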

persona = self.runner.config.get("system_prompt", {}).get(
    "persona", "a helpful assistant"
)