diff --git a/docetl/operations/utils/api.py b/docetl/operations/utils/api.py
index 85e8358c..7ad31db2 100644
--- a/docetl/operations/utils/api.py
+++ b/docetl/operations/utils/api.py
@@ -42,6 +42,20 @@
 BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]
 
+# Outlines + Ollama structured generation integration
+HAVE_OUTLINES = False
+try:
+    # Outlines integration layer provides a wrapper over python-ollama client
+    from outlines.integrations.ollama import from_ollama as outlines_from_ollama
+    from outlines.types import JsonSchema as OutlinesJsonSchema
+    from outlines.inputs import Chat as OutlinesChat
+    HAVE_OUTLINES = True
+except Exception:
+    HAVE_OUTLINES = False
+try:
+    from ollama import Client as OllamaClient
+except Exception:
+    OllamaClient = None
 
 
 class OutputMode(Enum):
     """Enumeration of output modes for LLM calls."""
@@ -331,24 +345,41 @@
                     total_cost += completion_cost(validator_response)
 
                     # Parse the validator response
-                    suggestion = json.loads(
-                        validator_response.choices[0]
-                        .message.tool_calls[0]
-                        .function.arguments
-                    )
-                    if not suggestion["should_refine"]:
+                    message0 = validator_response.choices[0].message
+                    suggestion = None
+                    try:
+                        tool_calls = (
+                            message0.tool_calls
+                            if hasattr(message0, "tool_calls") and message0.tool_calls
+                            else []
+                        )
+                        if tool_calls:
+                            suggestion = json.loads(tool_calls[0].function.arguments)
+                        else:
+                            # Some providers (e.g., Ollama) may not support tool calls. Try parsing content as JSON.
+                            suggestion = json.loads(message0.content)
+                    except Exception:
+                        # Fallback: assume no refinement if we can't parse
+                        suggestion = {"should_refine": False, "improvements": ""}
+                    # Safely read fields with defaults
+                    if not isinstance(suggestion, dict):
+                        suggestion = {}
+                    should_refine = bool(suggestion.get("should_refine", False))
+                    improvements = str(suggestion.get("improvements", ""))
+
+                    if not should_refine:
                         break
 
                     if verbose:
                         self.runner.console.log(
-                            f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
+                            f"Validator improvements (gleaning round {rnd + 1}): {improvements}"
                         )
 
                     # Prompt for improvement
                     improvement_prompt = f"""Based on the validation feedback:
 
 ```
-{suggestion['improvements']}
+{improvements}
 ```
 
 Please improve your previous response. Ensure that the output adheres to the required schema and addresses any issues raised in the validation."""
@@ -697,6 +728,107 @@
             },
         }
 
+        # If using Ollama + structured_output, optionally route via Outlines for stricter schema
+        use_outlines_for_ollama = (
+            model.startswith("ollama/")
+            and (self.runner.config.get("use_outlines_for_ollama", False))
+            and HAVE_OUTLINES
+            and OllamaClient is not None
+        )
+        if use_outlines_for_ollama:
+            # Build Ollama client with configured base URL if provided
+            client_kwargs: dict[str, Any] = {}
+            if self.default_lm_api_base:
+                # python-ollama uses 'host' parameter
+                client_kwargs["host"] = self.default_lm_api_base
+
+            try:
+                ollama_client = OllamaClient(**client_kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to initialize Ollama client for Outlines: {e}"
+                )
+
+            native_model = model.split("/", 1)[1] if "/" in model else model
+
+            # Convert our JSON schema dict into an Outlines JsonSchema object
+            outlines_schema = OutlinesJsonSchema(schema=json.dumps(schema))
+
+            persona = self.runner.config.get("system_prompt", {}).get(
+                "persona", "a helpful assistant"
+            )
+            dataset_description = self.runner.config.get(
+                "system_prompt", {}
+            ).get("dataset_description", "a collection of unstructured documents")
+            parethetical_op_instructions = (
+                "many inputs:one output" if op_type == "reduce" else "one input:one output"
+            )
+            base_prompt = (
+                f"You are a {persona}, helping the user make sense of their data. "
+                f"The dataset description is: {dataset_description}. "
+                f"You will be performing a {op_type} operation ({parethetical_op_instructions}). "
+                "You will perform the specified task on the provided data, as precisely and "
+                "exhaustively (i.e., high recall) as possible."
+            )
+            system_prompt = base_prompt + " Respond with a JSON object that follows the required schema."
+            messages_with_system_prompt = truncate_messages(
+                [
+                    {"role": "system", "content": system_prompt},
+                ]
+                + messages,
+                model,
+            )
+
+            # Map LiteLLM-style kwargs to Ollama options
+            def _map_litellm_to_ollama_options(src: dict[str, Any]) -> dict[str, Any]:
+                options: dict[str, Any] = {}
+                if "temperature" in src:
+                    options["temperature"] = src["temperature"]
+                if "top_p" in src:
+                    options["top_p"] = src["top_p"]
+                if "top_k" in src:
+                    options["top_k"] = src["top_k"]
+                if "max_tokens" in src:
+                    options["num_predict"] = src["max_tokens"]
+                if "stop" in src:
+                    options["stop"] = src["stop"]
+                if "seed" in src:
+                    options["seed"] = src["seed"]
+                if "frequency_penalty" in src:
+                    options["repeat_penalty"] = src["frequency_penalty"]
+                return options
+
+            options = _map_litellm_to_ollama_options(litellm_completion_kwargs)
+
+            outlines_model = outlines_from_ollama(ollama_client, model_name=native_model)
+
+            # Outlines expects an Outlines Chat object
+            outlines_chat = OutlinesChat(messages=messages_with_system_prompt)
+
+            try:
+                generated = outlines_model.generate(
+                    outlines_chat,
+                    output_type=outlines_schema,
+                    options=options,
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    f"Outlines/Ollama structured generation failed: {e}"
+                )
+
+            # Parse JSON into dict; return directly (non-ModelResponse path)
+            try:
+                return json.loads(generated)
+            except Exception:
+                # If not valid JSON, raise to surface error like other structured paths
+                raise InvalidOutputError(
+                    "Could not decode structured output JSON response",
+                    generated,
+                    schema,
+                    [],
+                    [],
+                )
+
         persona = self.runner.config.get("system_prompt", {}).get(
             "persona", "a helpful assistant"
         )
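
For reference, below is a minimal standalone sketch of the structured-generation path this patch routes `ollama/` models through when `use_outlines_for_ollama` is set in the runner config. It mirrors the import paths and call signatures the patch itself assumes for the Outlines and ollama packages (`from_ollama`, `Chat`, `JsonSchema`, `generate(..., output_type=..., options=...)`), which the patch already guards behind the `HAVE_OUTLINES` / `OllamaClient is not None` checks; the host, model name, prompt, and schema here are illustrative placeholders, not taken from DocETL.

# Minimal sketch of the Outlines + Ollama call that _call_llm_with_cache now
# performs for ollama/ models. Import paths and call signatures mirror the
# patch's assumptions; host, model name, and schema are illustrative.
import json

from ollama import Client as OllamaClient
from outlines.inputs import Chat as OutlinesChat
from outlines.integrations.ollama import from_ollama as outlines_from_ollama
from outlines.types import JsonSchema as OutlinesJsonSchema

schema = {
    "type": "object",
    "properties": {"summary": {"type": "string"}},
    "required": ["summary"],
}

# Equivalent of the default_lm_api_base -> host mapping in the patch.
client = OllamaClient(host="http://localhost:11434")
# "ollama/llama3.1" with the provider prefix stripped, as the patch does.
model = outlines_from_ollama(client, model_name="llama3.1")

chat = OutlinesChat(
    messages=[
        {"role": "system", "content": "Respond with a JSON object that follows the required schema."},
        {"role": "user", "content": "Summarize: DocETL pipelines process unstructured documents."},
    ]
)

generated = model.generate(
    chat,
    output_type=OutlinesJsonSchema(schema=json.dumps(schema)),
    # Same LiteLLM -> Ollama option mapping the patch applies (max_tokens -> num_predict).
    options={"temperature": 0.0, "num_predict": 128},
)
print(json.loads(generated))  # plain dict, the non-ModelResponse return path

Because the branch is gated on both the `use_outlines_for_ollama` config flag and the optional imports succeeding, environments without outlines or ollama installed continue through the existing LiteLLM code path unchanged. Note that this branch returns a plain dict rather than a LiteLLM ModelResponse, so downstream code that inspects the response object will see the difference.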