diff --git a/README.md b/README.md index 311c689..293135d 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,12 @@ [badge-zenodo]: https://zenodo.org/badge/899554552.svg -🧬 CellAnnotator is an [scverse ecosystem package](https://scverse.org/packages/#ecosystem), designed to annotate cell types in scRNA-seq data based on marker genes using large language models (LLMs). It supports OpenAI, Google Gemini, and Anthropic Claude models out of the box, with more providers planned for the future. +🧬 CellAnnotator is an [scverse ecosystem package](https://scverse.org/packages/#ecosystem), designed to annotate cell types in scRNA-seq data based on marker genes using large language models (LLMs). It supports OpenAI, Google Gemini, Anthropic Claude, and OpenRouter models out of the box. ## ✨ Key Features -- 🤖 **LLM-agnostic backend**: Seamlessly use models from OpenAI, Anthropic (Claude), and Gemini (Google) — just set your provider and API key. +- 🤖 **LLM-agnostic backend**: Seamlessly use models from OpenAI, Anthropic (Claude), Gemini (Google), or OpenRouter — just set your provider and API key. - 🧬 **Automatically annotate cells** including type, state, and confidence fields. - 🔄 **Consistent annotations** across all samples in your study. - 🧠 **Infuse prior knowledge** by providing information about your biological system. @@ -60,6 +60,7 @@ After installation, head over to the LLM provider of your choice to generate an - OpenAI: [API key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key) - Google (Gemini): [API key](https://ai.google.dev/gemini-api/docs/api-key) - Anthropic (Claude): [API key](https://docs.anthropic.com/en/docs/get-started) +- OpenRouter: [API key](https://openrouter.ai/settings/keys) 🔒 Keep this key private and don't share it with anyone. 
`CellAnnotator` will try to read the key as an environmental variable - either expose it to the environment yourself, or store it as an `.env` file anywhere within the repository where you conduct your analysis and plan to run `CellAnnotator`. The package will then use [dotenv](https://pypi.org/project/python-dotenv/) to export the key from the `env` file as an environmental variable. @@ -78,6 +79,31 @@ cell_ann = CellAnnotator( By default, this will store annotations in `adata.obs['cell_type_predicted']`. Head over to our 📚 [tutorials](https://cell-annotator.readthedocs.io/en/latest/notebooks/tutorials/index.html) to see more advanced use cases, and learn how to adapt this to your own data. You can run `CellAnnotator` for just a single sample of data, or across multiple samples. In the latter case, it will attempt to harmonize annotations across samples. +### Advanced provider options + +Any supported provider can be selected explicitly via `provider`, `model`, and (optionally) `api_key`. The example below uses OpenRouter and also runs in single-sample mode by setting `sample_key=None`. + +Example: + +```python +from cell_annotator import CellAnnotator + +cell_ann = CellAnnotator( + adata=adata, + species="human", + tissue="pancreas", + cluster_key="leiden_1", + sample_key=None, # single-sample mode + provider="openrouter", + model="openai/gpt-4o-mini", + api_key="YOUR_OPENROUTER_API_KEY", +) + +cell_ann.get_expected_cell_type_markers(n_markers=3) +cell_ann.get_cluster_markers() +cell_ann.annotate_clusters(key_added="cell_type_predicted") +``` + ## 💸 Costs and models @@ -89,15 +115,19 @@ CellAnnotator is LLM-agnostic and works with multiple providers: - **Anthropic Claude:** Claude models are supported. See the [Anthropic pricing page](https://docs.anthropic.com/claude/docs/pricing) for details. +- **OpenRouter:** OpenRouter routes requests to many model families (including OpenAI, Anthropic, and others) behind a single API key. Use `provider="openrouter"` and pass a model slug such as `openai/gpt-4o-mini` or `anthropic/claude-3.5-sonnet`. 
+ You can select your provider and model by setting the appropriate parameters. More providers may be supported in the future as the LLM ecosystem evolves. ## 🔐 Data privacy -This package sends cluster marker genes, and the `species` and `tissue` you define, to the selected LLM provider (e.g., OpenAI, Google, or Anthropic). **No actual gene expression values are sent.** +This package sends cluster marker genes, and the `species` and `tissue` you define, to the selected LLM provider (e.g., OpenAI, Google, Anthropic, or OpenRouter and the upstream provider it routes to). **No actual gene expression values are sent.** Please ensure your usage of this package aligns with your institution's guidelines on data privacy and the use of external AI models. Each provider has its own privacy policy and terms of service. Review these carefully before using CellAnnotator with sensitive or regulated data. +When using OpenRouter, requests are forwarded to the upstream provider implied by your model slug (e.g. `openai/...`, `anthropic/...`). Review both [OpenRouter's privacy policy](https://openrouter.ai/privacy) and the upstream provider's. Some OpenRouter model tiers may log prompts by default; users who need privacy guarantees should configure this via their OpenRouter account settings. + ## 🙏 Credits This tool was inspired by [Hou et al., Nature Methods 2024](https://www.nature.com/articles/s41592-024-02235-4) and [https://github.com/VPetukhov/GPTCellAnnotator](https://github.com/VPetukhov/GPTCellAnnotator). 
diff --git a/src/cell_annotator/_constants.py b/src/cell_annotator/_constants.py index df5091c..d67d654 100644 --- a/src/cell_annotator/_constants.py +++ b/src/cell_annotator/_constants.py @@ -19,9 +19,10 @@ class PackageConstants: "openai": "gpt-4o-mini", "gemini": "gemini-2.5-flash-lite", "anthropic": "claude-haiku-4-5", + "openrouter": "openai/gpt-4o-mini", } # Supported LLM providers - supported_providers: list[str] = ["openai", "gemini", "anthropic"] + supported_providers: list[str] = ["openai", "gemini", "anthropic", "openrouter"] default_cluster_key: str = "leiden" cell_type_key: str = "cell_type_harmonized" diff --git a/src/cell_annotator/model/_api_keys.py b/src/cell_annotator/model/_api_keys.py index 4e82a83..c801ce3 100644 --- a/src/cell_annotator/model/_api_keys.py +++ b/src/cell_annotator/model/_api_keys.py @@ -15,22 +15,32 @@ class APIKeyManager: on setup for different providers. """ - # Provider configurations + # Provider configurations. ``model_keywords`` feed ``detect_provider_from_model``; + # OpenRouter is detected via the ``provider/model`` slash heuristic instead. 
PROVIDER_CONFIG = { "openai": { "env_var": "OPENAI_API_KEY", "setup_url": "https://platform.openai.com/api-keys", "description": "OpenAI models (GPT, o1, etc.)", + "model_keywords": ("gpt", "o1", "davinci", "curie", "babbage", "ada"), }, "gemini": { "env_var": "GEMINI_API_KEY", "setup_url": "https://aistudio.google.com/apikey", "description": "Google Gemini models", + "model_keywords": ("gemini", "bison"), }, "anthropic": { "env_var": "ANTHROPIC_API_KEY", "setup_url": "https://console.anthropic.com/settings/keys", "description": "Anthropic Claude models", + "model_keywords": ("claude", "anthropic", "sonnet", "haiku", "opus"), + }, + "openrouter": { + "env_var": "OPENROUTER_API_KEY", + "setup_url": "https://openrouter.ai/settings/keys", + "description": "OpenRouter models (aggregated providers)", + "model_keywords": (), }, } @@ -168,8 +178,6 @@ def validate_model_access(self, model: str) -> tuple[bool, str | None]: """ Check if a specific model is accessible by detecting its provider. - Uses heuristics to detect provider from model name. 
- Parameters ---------- model @@ -179,23 +187,8 @@ def validate_model_access(self, model: str) -> tuple[bool, str | None]: ------- Tuple of (is_accessible, provider_name) """ - # Detect provider from model name using heuristics - model_lower = model.lower() - - if any(gemini_name in model_lower for gemini_name in ["gemini", "bison"]): - provider = "gemini" - elif any(claude_name in model_lower for claude_name in ["claude", "anthropic"]): - provider = "anthropic" - elif any(openai_name in model_lower for openai_name in ["gpt", "o1", "davinci", "curie", "babbage", "ada"]): - provider = "openai" - else: - # Default to OpenAI for unknown models (most common) - provider = "openai" - - if self.validate_provider(provider): - return True, provider - else: - return False, provider + provider = detect_provider_from_model(model) + return self.validate_provider(provider), provider def check_and_warn(self, provider: str | None = None, model: str | None = None) -> bool: """ @@ -249,6 +242,42 @@ def check_and_warn(self, provider: str | None = None, model: str | None = None) return False +def detect_provider_from_model(model: str) -> str: + """ + Auto-detect the LLM provider from a model name string. + + OpenRouter slugs follow ``provider/model`` (e.g. ``openai/gpt-4o-mini``); + the ``models/`` prefix that Gemini IDs sometimes carry is excluded so it + does not false-match. Otherwise, match keywords from + ``APIKeyManager.PROVIDER_CONFIG[*].model_keywords`` in priority order + (gemini, anthropic, openai). Defaults to ``"openai"`` if nothing matches. + + Parameters + ---------- + model + Model name or slug. + + Returns + ------- + Provider name. + """ + model_lower = model.lower() + + # OpenRouter uses '/' slugs (e.g. 'openai/gpt-4o-mini'). + # The 'models/' guard avoids false-matching Gemini IDs like 'models/gemini-1.5-flash'. 
+ if "/" in model and not model_lower.startswith("models/"): + return "openrouter" + + # Priority order matters: a model name like "ada-claude-experiment" should + # route to anthropic, not openai (anthropic-specific keywords win). + for provider in ("gemini", "anthropic", "openai"): + keywords = APIKeyManager.PROVIDER_CONFIG[provider].get("model_keywords", ()) + if any(keyword in model_lower for keyword in keywords): + return provider + + return "openai" + + class APIKeyMixin: """Mixin class to add API key management capabilities to other classes.""" diff --git a/src/cell_annotator/model/_providers.py b/src/cell_annotator/model/_providers.py index 97a6977..6232869 100644 --- a/src/cell_annotator/model/_providers.py +++ b/src/cell_annotator/model/_providers.py @@ -1,5 +1,7 @@ """LLM provider abstraction layer.""" +import json +import os from abc import ABC, abstractmethod from dotenv import load_dotenv @@ -72,7 +74,7 @@ def _list_models_impl(self) -> list[str]: class OpenAIProvider(LLMProvider): """OpenAI provider implementation.""" - def __init__(self, api_key: str | None = None) -> None: + def __init__(self, api_key: str | None = None, enable_text_repair: bool = False) -> None: """ Initialize OpenAI provider with dependency check. @@ -80,10 +82,19 @@ def __init__(self, api_key: str | None = None) -> None: ---------- api_key Optional API key. If None, uses environment variable. + enable_text_repair + If True, the JSON fallback may ask the model to rewrite its own + free-form output into schema-valid JSON as a last resort. Off by + default to preserve the structured-outputs invariant; subclasses + targeting routers with weak structured-output support (e.g. + OpenRouter) may opt in. 
""" check_deps("openai") self._client = None self._api_key = api_key + self._base_url = None + self._default_headers = None + self._enable_text_repair = enable_text_repair @property def client(self): @@ -95,7 +106,14 @@ def client(self): from openai import OpenAI # Use manual API key if provided, otherwise use environment/default - self._client = OpenAI(api_key=self._api_key) if self._api_key else OpenAI() + client_kwargs = {} + if self._api_key: + client_kwargs["api_key"] = self._api_key + if self._base_url: + client_kwargs["base_url"] = self._base_url + if self._default_headers: + client_kwargs["default_headers"] = self._default_headers + self._client = OpenAI(**client_kwargs) return self._client def __repr__(self) -> str: @@ -115,10 +133,12 @@ def __repr__(self) -> str: def _list_models_impl(self) -> list[str]: """List available OpenAI models.""" models = self.client.models.list() + return self._filter_chat_model_ids(models.data) - # Filter to only chat models (exclude embeddings, TTS, etc.) + def _filter_chat_model_ids(self, model_data: list) -> list[str]: + """Filter a model list to chat-capable OpenAI models.""" chat_models = [] - for model in models.data: + for model in model_data: model_id = model.id.lower() if ( any(prefix in model_id for prefix in ["gpt", "o1"]) @@ -146,11 +166,13 @@ def query( if other_messages is None: other_messages = [] - try: - messages = [{"role": "user", "content": instruction}] - if other_messages: - messages.extend(other_messages) + messages = self._build_messages( + agent_description=agent_description, + instruction=instruction, + other_messages=other_messages, + ) + try: completion = self.client.chat.completions.parse( model=model, messages=messages, # type: ignore[arg-type] @@ -174,7 +196,277 @@ def query( logger.warning(failure_reason) return response_format.default_failure(failure_reason=failure_reason) except openai.OpenAIError as e: - raise e + logger.warning( + "Structured parse failed for model '%s'. 
Falling back to JSON-mode query. Error: %s", model, str(e) + ) + return self._query_with_json_fallback( + model=model, + response_format=response_format, + messages=messages, + max_completion_tokens=max_completion_tokens, + fallback_error=str(e), + ) + except (ValueError, TypeError) as e: + if not self._enable_text_repair: + raise + logger.warning( + "Non-OpenAI parse failure for model '%s'. Falling back to JSON-mode query. Error: %s", model, str(e) + ) + return self._query_with_json_fallback( + model=model, + response_format=response_format, + messages=messages, + max_completion_tokens=max_completion_tokens, + fallback_error=str(e), + ) + + def _build_messages(self, agent_description: str, instruction: str, other_messages: list) -> list[dict[str, str]]: + """Build chat messages with system prompt and optional history.""" + messages: list[dict[str, str]] = [] + if agent_description: + messages.append({"role": "system", "content": agent_description}) + if other_messages: + messages.extend(other_messages) + messages.append({"role": "user", "content": instruction}) + return messages + + def _query_with_json_fallback( + self, + model: str, + response_format: type[BaseOutput], + messages: list[dict[str, str]], + max_completion_tokens: int | None, + fallback_error: str, + ) -> BaseOutput: + """ + Fallback for providers/models that do not support `.parse(...)`. + + Tries successive structured-output strategies, from strongest signal + (json_schema via ``extra_body``) to weakest (free-form text repaired + into JSON, gated on ``self._enable_text_repair``). + """ + schema = response_format.model_json_schema() + + # Tier 1: json_schema via extra_body. Canonical OpenRouter structured-output + # path (per OpenRouter docs and LiteLLM); a strong "this must be JSON + # matching the schema" signal that many upstream models honour even when + # the SDK's `.parse()` helper does not work end-to-end. 
+ try: + tier1_kwargs: dict = { + "model": model, + "messages": messages, # type: ignore[arg-type] + "extra_body": { + "response_format": { + "type": "json_schema", + "json_schema": { + "name": response_format.__name__, + "strict": True, + "schema": schema, + }, + } + }, + } + if max_completion_tokens is not None: + tier1_kwargs["max_tokens"] = max_completion_tokens + completion = self.client.chat.completions.create(**tier1_kwargs) + text = self._coerce_text_content(completion.choices[0].message.content) + if text: + return response_format.model_validate_json(text) + except Exception: # noqa: BLE001 + pass + + # Tier 2: plain json_object mode with the schema interpolated into the prompt. + schema_json = json.dumps(schema, ensure_ascii=True) + fallback_instruction = ( + "Return only valid JSON that matches this schema exactly. " + "Do not include markdown fences or extra text.\n" + f"JSON schema: {schema_json}" + ) + fallback_messages = [*messages, {"role": "user", "content": fallback_instruction}] + tier2_kwargs: dict = { + "model": model, + "messages": fallback_messages, # type: ignore[arg-type] + "response_format": {"type": "json_object"}, + } + if max_completion_tokens is not None: + # `max_tokens` is the most widely supported field across OpenAI-compatible APIs. + tier2_kwargs["max_tokens"] = max_completion_tokens + + try: + completion = self.client.chat.completions.create(**tier2_kwargs) + raw_content = completion.choices[0].message.content + text = self._coerce_text_content(raw_content) + if not text: + return response_format.default_failure( + failure_reason=( + f"Model returned empty content during JSON fallback. Original parse error: {fallback_error}" + ) + ) + + # Strict JSON parsing first. + try: + return response_format.model_validate_json(text) + except Exception: # noqa: BLE001 + pass + + # TODO: _extract_json_candidate uses find("{") / rfind("}"), which is + # brittle when the model emits prose with embedded braces or multiple + # JSON blocks. 
Revisit if this becomes a real failure mode. + json_candidate = self._extract_json_candidate(text) + if json_candidate is not None: + try: + return response_format.model_validate_json(json_candidate) + except Exception: # noqa: BLE001 + pass + + # Tier 3: ask the model to repair its own output. Off by default + # (project invariant: never parse free-form LLM text in production paths); + # opted into by OpenRouterProvider where upstream-model variability + # justifies a last-resort recovery. + if self._enable_text_repair: + repaired_text = self._repair_text_to_json( + model=model, + raw_text=text, + schema_json=schema_json, + max_completion_tokens=max_completion_tokens, + ) + if repaired_text: + try: + return response_format.model_validate_json(repaired_text) + except Exception: # noqa: BLE001 + repaired_candidate = self._extract_json_candidate(repaired_text) + if repaired_candidate is not None: + try: + return response_format.model_validate_json(repaired_candidate) + except Exception: # noqa: BLE001 + pass + + return response_format.default_failure( + failure_reason=( + "Could not parse structured JSON response from model output. " + f"Original parse error: {fallback_error}" + ) + ) + except Exception as fallback_exception: # noqa: BLE001 + return response_format.default_failure( + failure_reason=( + "Fallback JSON query failed. " + f"Original parse error: {fallback_error}. 
" + f"Fallback error: {str(fallback_exception)}" + ) + ) + + def _coerce_text_content(self, content) -> str: + """Coerce OpenAI-compatible response content into plain text.""" + if content is None: + return "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + chunks = [] + for item in content: + if isinstance(item, dict) and "text" in item: + chunks.append(str(item["text"])) + elif hasattr(item, "text"): + chunks.append(str(item.text)) + return "\n".join(chunks).strip() + return str(content).strip() + + def _extract_json_candidate(self, text: str) -> str | None: + """Extract the first JSON object-like substring from text.""" + start = text.find("{") + end = text.rfind("}") + if start == -1 or end == -1 or end <= start: + return None + return text[start : end + 1] + + def _repair_text_to_json( + self, + model: str, + raw_text: str, + schema_json: str, + max_completion_tokens: int | None, + ) -> str: + """Ask the model to convert plain text into schema-valid JSON.""" + logger.warning( + "Last-resort text-to-JSON repair engaged for model '%s'. " + "This bypasses the structured-output contract; verify the response.", + model, + ) + repair_instruction = ( + "Convert the following assistant output into valid JSON matching this schema exactly. 
" + "Return JSON only, with no markdown or explanation.\n" + f"Schema: {schema_json}\n" + f"Assistant output: {raw_text}" + ) + repair_kwargs: dict = { + "model": model, + "messages": [{"role": "user", "content": repair_instruction}], + "response_format": {"type": "json_object"}, + } + if max_completion_tokens is not None: + repair_kwargs["max_tokens"] = max_completion_tokens + repair_completion = self.client.chat.completions.create(**repair_kwargs) + repair_content = repair_completion.choices[0].message.content + return self._coerce_text_content(repair_content) + + +class OpenRouterProvider(OpenAIProvider): + """OpenRouter provider implementation (OpenAI-compatible API).""" + + def __init__(self, api_key: str | None = None) -> None: + # When no manual key is supplied, resolve the OpenRouter key from the + # environment explicitly. The underlying OpenAI client otherwise picks + # up ``OPENAI_API_KEY`` (since both providers share the SDK), causing + # 401s against ``https://openrouter.ai/api/v1`` whenever both keys + # are configured side-by-side. + if api_key is None: + load_dotenv() + api_key = os.getenv("OPENROUTER_API_KEY") + super().__init__(api_key=api_key, enable_text_repair=True) + self._base_url = "https://openrouter.ai/api/v1" + + # Optional headers recommended by OpenRouter for request attribution. + referer = os.getenv("OPENROUTER_SITE_URL") + title = os.getenv("OPENROUTER_APP_NAME") + headers = {} + if referer: + headers["HTTP-Referer"] = referer + if title: + headers["X-Title"] = title + self._default_headers = headers if headers else None + + def __repr__(self) -> str: + """Return a string representation of the OpenRouter provider.""" + try: + models = self.list_available_models()[:5] # Show first 5 models + if models: + model_preview = ", ".join(models) + if len(self.list_available_models()) > 5: + model_preview += ", ..." + return f"OpenRouterProvider(models: {model_preview}). Call .list_available_models() for complete list." 
+ else: + return "OpenRouterProvider(models: none available). Call .list_available_models() for complete list." + except Exception: # noqa: BLE001 + return "OpenRouterProvider(models: unavailable). Call .list_available_models() for details." + + def _list_models_impl(self) -> list[str]: + """ + List available OpenRouter models. + + OpenRouter exposes models from many upstream providers, so model IDs + are not restricted to OpenAI prefixes like "gpt" and "o1". + """ + models = self.client.models.list() + + filtered_models = [] + for model in models.data: + model_id = model.id.lower() + if any(keyword in model_id for keyword in ["embedding", "tts", "whisper", "dall", "moderation", "rerank"]): + continue + filtered_models.append(model.id) + + return sorted(filtered_models) class GeminiProvider(LLMProvider): @@ -412,7 +704,7 @@ def get_provider(provider_name: str, api_key: str | None = None) -> LLMProvider: Parameters ---------- provider_name - Name of the provider ('openai', 'gemini', or 'anthropic'). + Name of the provider ('openai', 'gemini', 'anthropic', or 'openrouter'). api_key Optional API key. If provided, creates a new provider instance with this key. If None, uses cached provider instance with environment variables. @@ -429,8 +721,10 @@ def get_provider(provider_name: str, api_key: str | None = None) -> LLMProvider: return GeminiProvider(api_key=api_key) elif provider_name == "anthropic": return AnthropicProvider(api_key=api_key) + elif provider_name == "openrouter": + return OpenRouterProvider(api_key=api_key) else: - available = ["openai", "gemini", "anthropic"] + available = ["openai", "gemini", "anthropic", "openrouter"] raise ValueError(f"Unknown provider '{provider_name}'. 
Available: {', '.join(available)}") # Use cached provider instance for environment-based keys @@ -441,8 +735,10 @@ def get_provider(provider_name: str, api_key: str | None = None) -> LLMProvider: _PROVIDERS[provider_name] = GeminiProvider() elif provider_name == "anthropic": _PROVIDERS[provider_name] = AnthropicProvider() + elif provider_name == "openrouter": + _PROVIDERS[provider_name] = OpenRouterProvider() else: - available = ["openai", "gemini", "anthropic"] + available = ["openai", "gemini", "anthropic", "openrouter"] raise ValueError(f"Unknown provider '{provider_name}'. Available: {', '.join(available)}") return _PROVIDERS[provider_name] diff --git a/src/cell_annotator/model/base_annotator.py b/src/cell_annotator/model/base_annotator.py index 10c2a84..7989ecf 100644 --- a/src/cell_annotator/model/base_annotator.py +++ b/src/cell_annotator/model/base_annotator.py @@ -88,7 +88,11 @@ def __repr__(self) -> str: @d.dedent def query_llm( - self, instruction: str, response_format: type[BaseOutput], other_messages: list | None = None + self, + instruction: str, + response_format: type[BaseOutput], + agent_description: str | None = None, + other_messages: list | None = None, ) -> BaseOutput: """ Query the LLM with a given instruction. @@ -97,11 +101,15 @@ def query_llm( ---------- %(instruction)s %(response_format)s + agent_description + Optional system prompt override. If None, uses the default + cell-annotation prompt from `self.prompts`. 
%(other_messages)s %(returns_parsed_response)s """ - agent_description = self.prompts.get_cell_type_prompt() + if agent_description is None: + agent_description = self.prompts.get_cell_type_prompt() response = self._provider.query( agent_description=agent_description, diff --git a/src/cell_annotator/model/cell_annotator.py b/src/cell_annotator/model/cell_annotator.py index b950e14..226a955 100644 --- a/src/cell_annotator/model/cell_annotator.py +++ b/src/cell_annotator/model/cell_annotator.py @@ -49,7 +49,16 @@ def __init__( provider: str | None = None, api_key: str | None = None, ): - super().__init__(species, tissue, stage, cluster_key, model, max_completion_tokens, provider, api_key) + super().__init__( + species, + tissue, + stage, + cluster_key, + model, + max_completion_tokens, + provider, + api_key, + ) self.adata = adata self.sample_key = sample_key self._api_key = api_key # Store API key for passing to SampleAnnotators diff --git a/src/cell_annotator/model/llm_interface.py b/src/cell_annotator/model/llm_interface.py index 7e14743..125d7a6 100644 --- a/src/cell_annotator/model/llm_interface.py +++ b/src/cell_annotator/model/llm_interface.py @@ -4,7 +4,7 @@ from cell_annotator._docs import d from cell_annotator._logging import logger from cell_annotator._response_formats import BaseOutput -from cell_annotator.model._api_keys import APIKeyMixin +from cell_annotator.model._api_keys import APIKeyMixin, detect_provider_from_model from cell_annotator.model._providers import get_provider @@ -42,9 +42,11 @@ def __init__( # Determine provider and model if provider is None and model is None: - # Auto-select the first available provider and its default model - if api_key is None: - # Only check environment keys if no manual key provided + # Auto-select the first available provider and its default model. 
+ # When ``_skip_validation`` is set or a manual API key is supplied, + # we must not raise on missing env keys — fork-PR CI runs and + # ``LLMInterface(_skip_validation=True)`` callers both rely on this. + if api_key is None and not _skip_validation: available_providers = self.api_keys.get_available_providers() if not available_providers: raise ValueError( @@ -53,7 +55,6 @@ def __init__( ) provider = available_providers[0] else: - # If manual API key provided but no provider specified, default to OpenAI provider = "openai" model = PackageConstants.default_models[provider] elif provider is None and model is not None: @@ -107,27 +108,8 @@ def __repr__(self) -> str: return "\n".join(lines) def _detect_provider_from_model(self, model: str) -> str: - """ - Auto-detect provider from model name. - - Parameters - ---------- - model - Model name. - - Returns - ------- - Provider name. - """ - if any(keyword in model.lower() for keyword in ["gpt", "o1"]): - return "openai" - elif any(keyword in model.lower() for keyword in ["gemini", "bison"]): - return "gemini" - elif any(keyword in model.lower() for keyword in ["claude", "sonnet", "haiku", "opus"]): - return "anthropic" - else: - # Default to OpenAI for unknown models - return "openai" + """Auto-detect provider from model name. Thin wrapper around the shared helper.""" + return detect_provider_from_model(model) @d.dedent def query_llm( @@ -174,8 +156,11 @@ def test_query(self, return_details: bool = False) -> bool | tuple[bool, str]: """ Test if the LLM setup is working correctly. - Performs a simple query to verify that the API key is valid - and the model can be accessed successfully. + Performs a simple structured-output query against the configured model. 
+ For OpenRouter slugs whose upstream model does not implement OpenAI's + ``.parse()`` endpoint, the provider's fallback chain + (``extra_body`` json_schema → plain ``json_object`` → optional text-repair) + carries the request, so the same code path works for every provider. Parameters ---------- diff --git a/src/cell_annotator/model/sample_annotator.py b/src/cell_annotator/model/sample_annotator.py index 099e532..dd70e0c 100644 --- a/src/cell_annotator/model/sample_annotator.py +++ b/src/cell_annotator/model/sample_annotator.py @@ -54,7 +54,15 @@ def __init__( _skip_validation: bool = False, ): super().__init__( - species, tissue, stage, cluster_key, model, max_completion_tokens, provider, api_key, _skip_validation + species, + tissue, + stage, + cluster_key, + model, + max_completion_tokens, + provider, + api_key, + _skip_validation, ) self.adata = adata self.sample_name = sample_name diff --git a/tests/model/test_api_keys.py b/tests/model/test_api_keys.py index b1f421f..79d91f8 100644 --- a/tests/model/test_api_keys.py +++ b/tests/model/test_api_keys.py @@ -21,7 +21,7 @@ def test_initialization(self): def test_supported_providers(self): """Test that all expected providers are supported.""" manager = APIKeyManager() - expected_providers = {"openai", "gemini", "anthropic"} + expected_providers = {"openai", "gemini", "anthropic", "openrouter"} assert set(manager.PROVIDER_CONFIG.keys()) == expected_providers @patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai_key"}, clear=False) @@ -45,6 +45,13 @@ def test_anthropic_key_detection(self): availability = manager.check_key_availability() assert availability["anthropic"] is True + @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_openrouter_key"}, clear=False) + def test_openrouter_key_detection(self): + """Test OpenRouter key detection from environment.""" + manager = APIKeyManager() + availability = manager.check_key_availability() + assert availability["openrouter"] is True + @patch.dict(os.environ, {}, 
clear=True) @patch("cell_annotator.model._api_keys.load_dotenv") def test_no_keys_available(self, _mock_load_dotenv): @@ -54,6 +61,7 @@ def test_no_keys_available(self, _mock_load_dotenv): assert availability["openai"] is False assert availability["gemini"] is False assert availability["anthropic"] is False + assert availability["openrouter"] is False assert manager.get_available_providers() == [] @patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai", "ANTHROPIC_API_KEY": "test_anthropic"}, clear=True) @@ -69,7 +77,12 @@ def test_partial_keys_available(self, _mock_load_dotenv): @patch.dict( os.environ, - {"OPENAI_API_KEY": "test_openai", "GEMINI_API_KEY": "test_gemini", "ANTHROPIC_API_KEY": "test_anthropic"}, + { + "OPENAI_API_KEY": "test_openai", + "GEMINI_API_KEY": "test_gemini", + "ANTHROPIC_API_KEY": "test_anthropic", + "OPENROUTER_API_KEY": "test_openrouter", + }, clear=True, ) @patch("cell_annotator.model._api_keys.load_dotenv") @@ -77,7 +90,7 @@ def test_all_keys_available(self, _mock_load_dotenv): """Test behavior when all API keys are available.""" manager = APIKeyManager() available = manager.get_available_providers() - expected = {"openai", "gemini", "anthropic"} + expected = {"openai", "gemini", "anthropic", "openrouter"} assert set(available) == expected def test_validate_provider(self): @@ -154,7 +167,7 @@ def test_mixin_check_api_access(self): mixin = APIKeyMixin() # Test with valid providers - for provider in ["openai", "gemini", "anthropic"]: + for provider in ["openai", "gemini", "anthropic", "openrouter"]: result = mixin.check_api_access(provider) assert isinstance(result, bool) diff --git a/tests/model/test_base_annotator.py b/tests/model/test_base_annotator.py index 36353aa..6d329b5 100644 --- a/tests/model/test_base_annotator.py +++ b/tests/model/test_base_annotator.py @@ -1,3 +1,4 @@ +import os from unittest.mock import patch import pytest @@ -37,6 +38,12 @@ def test_query_llm_real(self, base_annotator): print(f"✅ 
{base_annotator._provider_name} provider test passed with model: {base_annotator.model}") print(f"Response: {response.parsed_response}") + @pytest.mark.skipif( + not any( + os.getenv(key) for key in ["OPENAI_API_KEY", "GEMINI_API_KEY", "ANTHROPIC_API_KEY", "OPENROUTER_API_KEY"] + ), + reason="Auto-detection requires at least one provider key in env (fork-PR CI runs don't have access to secrets).", + ) def test_provider_auto_detection(self): """Test automatic provider detection when none specified.""" annotator = BaseAnnotator( @@ -52,7 +59,7 @@ def test_provider_auto_detection(self): def test_explicit_provider_selection(self): """Test explicit provider selection.""" - for provider_name in ["openai", "gemini", "anthropic"]: + for provider_name in ["openai", "gemini", "anthropic", "openrouter"]: try: annotator = BaseAnnotator( species="human", @@ -74,6 +81,7 @@ def test_provider_model_combination(self): ("openai", "gpt-4o-mini"), ("gemini", "gemini-1.5-flash"), ("anthropic", "claude-3-haiku-20240307"), + ("openrouter", "openai/gpt-4o-mini"), ] for provider_name, model in test_cases: diff --git a/tests/model/test_llm_interface.py b/tests/model/test_llm_interface.py index 2a873c1..6fcf007 100644 --- a/tests/model/test_llm_interface.py +++ b/tests/model/test_llm_interface.py @@ -1,5 +1,6 @@ """Tests for LLMInterface class.""" +import os from unittest.mock import patch import pytest @@ -38,6 +39,9 @@ def test_repr(self, provider_name): ("claude-3-5-sonnet-20241022", "anthropic"), ("claude-3-haiku-20240307", "anthropic"), ("claude-3-opus-20240229", "anthropic"), + ("openai/gpt-4o-mini", "openrouter"), + ("anthropic/claude-3.5-sonnet", "openrouter"), + ("models/gemini-1.5-flash", "gemini"), # 'models/' prefix must NOT trigger OpenRouter ("unknown-model", "openai"), # Should default to openai ], ) @@ -47,6 +51,22 @@ def test_detect_provider_from_model(self, model_name, expected_provider): detected_provider = interface._detect_provider_from_model(model_name) assert 
detected_provider == expected_provider + def test_skip_validation_with_no_env_keys(self): + """``LLMInterface(_skip_validation=True)`` constructs cleanly even when no env API keys are set. + + Regression test: previously the auto-select branch raised ``ValueError("No API keys found")`` + even when ``_skip_validation=True``, breaking fork-PR CI runs and any caller that relied on + the validation skip. + """ + with patch.dict( + os.environ, + {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "", "ANTHROPIC_API_KEY": "", "OPENROUTER_API_KEY": ""}, + clear=False, + ): + interface = LLMInterface(_skip_validation=True) + assert interface._provider_name == "openai" + assert interface.model is not None + @patch("cell_annotator.model.llm_interface.LLMInterface.query_llm") def test_query_success(self, mock_query_llm, provider_name): """Test successful test_query scenario.""" diff --git a/tests/model/test_providers.py b/tests/model/test_providers.py index 3a53968..8aef61e 100644 --- a/tests/model/test_providers.py +++ b/tests/model/test_providers.py @@ -7,7 +7,7 @@ from cell_annotator._constants import PackageConstants from cell_annotator._response_formats import BaseOutput -from cell_annotator.model._providers import AnthropicProvider, GeminiProvider, OpenAIProvider +from cell_annotator.model._providers import AnthropicProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider class SimpleOutput(BaseOutput): @@ -52,16 +52,29 @@ def test_anthropic_provider_initialization(self): provider_with_key = AnthropicProvider(api_key="test-key") assert provider_with_key._api_key == "test-key" + def test_openrouter_provider_initialization(self): + """Test OpenRouter provider initialization.""" + # Test with no API key — constructor resolves OPENROUTER_API_KEY from env + # (so the SDK doesn't fall through to OPENAI_API_KEY when both are set). + provider = OpenRouterProvider() + assert provider is not None + + # Test with manual API key — explicit override wins over env. 
+ provider_with_key = OpenRouterProvider(api_key="test-key") + assert provider_with_key._api_key == "test-key" + def test_provider_repr(self): """Test string representation of providers.""" openai_provider = OpenAIProvider() gemini_provider = GeminiProvider() anthropic_provider = AnthropicProvider() + openrouter_provider = OpenRouterProvider() # Should contain provider name assert "OpenAIProvider" in repr(openai_provider) assert "GeminiProvider" in repr(gemini_provider) assert "AnthropicProvider" in repr(anthropic_provider) + assert "OpenRouterProvider" in repr(openrouter_provider) class TestOpenAIProvider: @@ -213,6 +226,51 @@ def test_anthropic_initialization(self): assert provider_with_key._api_key == "test-key" +class TestOpenRouterProvider: + """Isolated tests for OpenRouter provider.""" + + @flaky + @pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not available") + @pytest.mark.real_llm_query() + def test_openrouter_list_models_real(self): + """Test OpenRouter model listing with real API.""" + provider = OpenRouterProvider() + models = provider.list_available_models() + + assert isinstance(models, list) + assert len(models) > 0 + model_names = [model.lower() for model in models] + assert any("/" in model for model in model_names) + + @flaky + @pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not available") + @pytest.mark.real_llm_query() + def test_openrouter_query_real(self): + """Test OpenRouter query with real API call.""" + provider = OpenRouterProvider() + + response = provider.query( + agent_description="You are a helpful assistant.", + instruction="Say hello in exactly one word.", + model=PackageConstants.default_models["openrouter"], + response_format=SimpleOutput, + max_completion_tokens=50, + ) + + assert isinstance(response, SimpleOutput) + assert hasattr(response, "text") + assert len(response.text.strip()) > 0 + + def test_openrouter_initialization(self): + """Test OpenRouter 
provider initialization and basic properties.""" + provider = OpenRouterProvider() + assert provider is not None + + # Test with manual API key — explicit override wins over env. + provider_with_key = OpenRouterProvider(api_key="test-key") + assert provider_with_key._api_key == "test-key" + + class TestProviderIntegration: """Integration tests for provider functionality.""" @@ -222,6 +280,7 @@ def test_provider_factory_pattern(self): "openai": OpenAIProvider, "gemini": GeminiProvider, "anthropic": AnthropicProvider, + "openrouter": OpenRouterProvider, } for name, provider_class in providers.items(): @@ -231,7 +290,9 @@ def test_provider_factory_pattern(self): @flaky @pytest.mark.skipif( - not any(os.getenv(key) for key in ["OPENAI_API_KEY", "GEMINI_API_KEY", "ANTHROPIC_API_KEY"]), + not any( + os.getenv(key) for key in ["OPENAI_API_KEY", "GEMINI_API_KEY", "ANTHROPIC_API_KEY", "OPENROUTER_API_KEY"] + ), reason="No API keys available for testing", ) @pytest.mark.real_llm_query() @@ -245,6 +306,8 @@ def test_available_provider_models(self): providers_to_test.append(("gemini", GeminiProvider())) if os.getenv("ANTHROPIC_API_KEY"): providers_to_test.append(("anthropic", AnthropicProvider())) + if os.getenv("OPENROUTER_API_KEY"): + providers_to_test.append(("openrouter", OpenRouterProvider())) for provider_name, provider in providers_to_test: print(f"Testing {provider_name} models...") @@ -259,10 +322,11 @@ def test_provider_error_handling(self): OpenAIProvider(api_key="invalid-openai-key"), GeminiProvider(api_key="invalid-gemini-key"), AnthropicProvider(api_key="invalid-anthropic-key"), + OpenRouterProvider(api_key="invalid-openrouter-key"), ] # Just verify they were created successfully - assert len(invalid_providers) == 3 + assert len(invalid_providers) == 4 for provider in invalid_providers: assert provider is not None @@ -272,6 +336,7 @@ def test_provider_response_format_consistency(self): OpenAIProvider(api_key="fake-key"), GeminiProvider(api_key="fake-key"), 
AnthropicProvider(api_key="fake-key"), + OpenRouterProvider(api_key="fake-key"), ] # Just test that providers can be created and have the expected interface @@ -281,3 +346,144 @@ def test_provider_response_format_consistency(self): assert hasattr(provider, "list_available_models") assert callable(provider.query) assert callable(provider.list_available_models) + + +class _FallbackOutput(BaseOutput): + """Test response format whose required field has a default, so ``default_failure`` can construct it.""" + + text: str = "" + + +class TestJSONFallback: + """Tests for the tiered JSON-fallback chain in OpenAIProvider/OpenRouterProvider.""" + + @pytest.mark.parametrize( + ("provider_cls", "expected"), + [(OpenAIProvider, False), (OpenRouterProvider, True)], + ) + def test_text_repair_default(self, provider_cls, expected): + """``enable_text_repair`` is off for OpenAI, on for OpenRouter — the structured-outputs invariant.""" + assert provider_cls()._enable_text_repair is expected + + @staticmethod + def _build_mock_client(*, tier_responses): + """Build a mock OpenAI client whose .parse() raises and whose .create() is programmable. + + Each entry in ``tier_responses`` is either a string (returned as ``message.content``) + or ``None`` (raises ``RuntimeError`` to fall through to the next tier). 
+ """ + from unittest.mock import MagicMock + + import openai + + client = MagicMock() + client.chat.completions.parse.side_effect = openai.OpenAIError("simulated parse failure") + + call_idx = {"i": 0} + + def fake_create(**kwargs): + idx = call_idx["i"] + call_idx["i"] += 1 + response = tier_responses[idx] if idx < len(tier_responses) else None + if response is None: + raise RuntimeError(f"tier {idx + 1} simulated failure") + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].message.content = response + return mock_completion + + client.chat.completions.create.side_effect = fake_create + return client + + def test_chain_tries_extra_body_first_then_json_object(self, monkeypatch): + """Tier 1 uses ``extra_body`` json_schema; on failure tier 2 uses plain ``json_object``.""" + provider = OpenAIProvider(api_key="fake", enable_text_repair=False) + # Tier 1 fails, tier 2 returns valid JSON. + client = self._build_mock_client(tier_responses=[None, '{"text": "ok"}']) + monkeypatch.setattr(provider, "_client", client) + + result = provider.query( + agent_description="x", + instruction="y", + model="m", + response_format=_FallbackOutput, + ) + + assert isinstance(result, _FallbackOutput) + assert result.text == "ok" + calls = client.chat.completions.create.call_args_list + assert len(calls) == 2 + # Tier 1: extra_body json_schema + assert "extra_body" in calls[0].kwargs + assert calls[0].kwargs["extra_body"]["response_format"]["type"] == "json_schema" + # Tier 2: plain json_object + assert calls[1].kwargs.get("response_format") == {"type": "json_object"} + + @pytest.mark.parametrize("enable_text_repair", [False, True]) + def test_text_repair_flag_gates_tier_three(self, monkeypatch, enable_text_repair): + """Tier 3 (text-repair) only runs when ``enable_text_repair`` is True.""" + provider = OpenAIProvider(api_key="fake", enable_text_repair=enable_text_repair) + # Tier 1 fails, tier 2 returns prose without JSON braces, tier 3 
returns valid JSON. + client = self._build_mock_client( + tier_responses=[None, "no json here at all", '{"text": "repaired"}'], + ) + monkeypatch.setattr(provider, "_client", client) + + result = provider.query( + agent_description="x", + instruction="y", + model="m", + response_format=_FallbackOutput, + ) + + if enable_text_repair: + assert isinstance(result, _FallbackOutput) + assert result.text == "repaired" + assert client.chat.completions.create.call_count == 3 + else: + assert result.reason_for_failure is not None + assert client.chat.completions.create.call_count == 2 + + def test_query_reraises_local_error_when_text_repair_disabled(self): + """ValueError raised before any API call must propagate when text-repair is off.""" + from unittest.mock import MagicMock + + provider = OpenAIProvider(api_key="fake", enable_text_repair=False) + client = MagicMock() + client.chat.completions.parse.side_effect = ValueError("local validation issue") + provider._client = client + + with pytest.raises(ValueError, match="local validation issue"): + provider.query( + agent_description="x", + instruction="y", + model="m", + response_format=_FallbackOutput, + ) + + +class TestOpenRouterFallback: + """Live tests exercising the fallback chain against a real OpenRouter slug.""" + + @flaky + @pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not available") + @pytest.mark.real_llm_query() + def test_openrouter_query_real_with_fallback(self): + """Pick a slug whose upstream model commonly fails ``.parse()`` to exercise the new ``extra_body`` tier.""" + provider = OpenRouterProvider() + # anthropic-via-OpenRouter consistently triggered .parse() failure during PR #71 development; + # haiku is the cheapest known-to-need-fallback option. 
+ # If model behavior on OpenRouter ever drifts and this slug succeeds via .parse(), the + # underlying chain still returns a valid response so the assertion holds — mark xfail + # only if a different failure mode emerges. + response = provider.query( + agent_description="You are a helpful assistant.", + instruction="Say hello in exactly one word.", + model="anthropic/claude-haiku-4-5", + response_format=SimpleOutput, + max_completion_tokens=50, + ) + + assert isinstance(response, SimpleOutput) + assert hasattr(response, "text") + assert len(response.text.strip()) > 0