From c8254b71773bff8a3f5b11046af1f54602ae8610 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 14:11:06 -0700 Subject: [PATCH 1/9] Parametrize ModelProvider fixture directly --- llm-service/app/tests/services/test_models.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/llm-service/app/tests/services/test_models.py b/llm-service/app/tests/services/test_models.py index 1649c8eba..891674234 100644 --- a/llm-service/app/tests/services/test_models.py +++ b/llm-service/app/tests/services/test_models.py @@ -55,7 +55,7 @@ def get_all_env_var_names() -> set[str]: ) -@pytest.fixture() +@pytest.fixture(params=ModelProvider.__subclasses__()) def EnabledModelProvider( request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, @@ -71,11 +71,6 @@ def EnabledModelProvider( return ModelProviderSubcls -@pytest.mark.parametrize( - "EnabledModelProvider", - ModelProvider.__subclasses__(), - indirect=True, -) class TestListAvailableModels: @pytest.fixture(autouse=True) def caii_get_models(self, monkeypatch: pytest.MonkeyPatch) -> None: From 9b0f0a188e7797b08eafd0e3bbf556888b7ef7f8 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 14:39:57 -0700 Subject: [PATCH 2/9] Simplify clearing env vars --- llm-service/app/tests/services/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm-service/app/tests/services/test_models.py b/llm-service/app/tests/services/test_models.py index 891674234..3b9fcfe5e 100644 --- a/llm-service/app/tests/services/test_models.py +++ b/llm-service/app/tests/services/test_models.py @@ -63,10 +63,10 @@ def EnabledModelProvider( """Sets and unsets environment variables for the given model provider.""" ModelProviderSubcls: type[ModelProvider] = request.param + for name in get_all_env_var_names(): + monkeypatch.delenv(name, raising=False) for name in ModelProviderSubcls.get_env_var_names(): monkeypatch.setenv(name, "test") - for name in get_all_env_var_names() - 
ModelProviderSubcls.get_env_var_names(): - monkeypatch.delenv(name, raising=False) return ModelProviderSubcls From e2f5240af31581cee4027abdaf599d4f8bb21856 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 14:46:38 -0700 Subject: [PATCH 3/9] Move get_model_source() to the tops of classes --- llm-service/app/services/models/_model_source.py | 2 -- .../app/services/models/providers/_model_provider.py | 10 +++++----- llm-service/app/services/models/providers/azure.py | 8 ++++---- llm-service/app/services/models/providers/bedrock.py | 8 ++++---- llm-service/app/services/models/providers/caii.py | 11 ++++++----- llm-service/app/services/models/providers/openai.py | 8 ++++---- 6 files changed, 23 insertions(+), 24 deletions(-) diff --git a/llm-service/app/services/models/_model_source.py b/llm-service/app/services/models/_model_source.py index 625662ea5..618b2646e 100644 --- a/llm-service/app/services/models/_model_source.py +++ b/llm-service/app/services/models/_model_source.py @@ -44,5 +44,3 @@ class ModelSource(str, Enum): CAII = "CAII" AZURE = "Azure" OPENAI = "OpenAI" - - diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index 954d93b2f..3ee5a6964 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -82,6 +82,11 @@ def get_provider_class() -> type["ModelProvider"]: return BedrockModelProvider return CAIIModelProvider + @staticmethod + @abc.abstractmethod + def get_model_source() -> ModelSource: + raise NotImplementedError + @staticmethod @abc.abstractmethod def get_env_var_names() -> set[str]: @@ -123,8 +128,3 @@ def get_embedding_model(name: str) -> BaseEmbedding: def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor: """Return reranking model with `name`.""" raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def get_model_source() -> 
ModelSource: - raise NotImplementedError diff --git a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index c079e1aa3..f59d55203 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -47,6 +47,10 @@ class AzureModelProvider(ModelProvider): + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.AZURE + @staticmethod def get_env_var_names() -> set[str]: return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"} @@ -105,10 +109,6 @@ def get_embedding_model(name: str) -> AzureOpenAIEmbedding: def get_reranking_model(name: str, top_n: int) -> SimpleReranker: return SimpleReranker(top_n=top_n) - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.AZURE - # ensure interface is implemented _ = AzureModelProvider() diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index f74ab7f70..fe5026aa4 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -74,6 +74,10 @@ class BedrockModelProvider(ModelProvider): + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.BEDROCK + @staticmethod def get_env_var_names() -> set[str]: return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"} @@ -309,10 +313,6 @@ def get_embedding_model(name: str) -> BedrockEmbedding: def get_reranking_model(name: str, top_n: int) -> AWSBedrockRerank: return AWSBedrockRerank(rerank_model_name=name, top_n=top_n) - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.BEDROCK - # ensure interface is implemented _ = BedrockModelProvider() diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index c5d37c5f6..750f07c69 100644 --- 
a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -51,7 +51,8 @@ get_llm as get_caii_llm_model, get_embedding_model as get_caii_embedding_model, get_reranking_model as get_caii_reranking_model, - describe_endpoint, get_models_with_task, + describe_endpoint, + get_models_with_task, ) from ...caii.types import ModelResponse from ...caii.utils import get_cml_version_from_sense_bootstrap @@ -60,6 +61,10 @@ class CAIIModelProvider(ModelProvider): + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.CAII + @staticmethod def get_env_var_names() -> set[str]: return {"CAII_DOMAIN"} @@ -112,10 +117,6 @@ def is_enabled(cls) -> bool: return super().is_enabled() - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.CAII - # ensure interface is implemented _ = CAIIModelProvider() diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 62d258790..24801d0bb 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -51,6 +51,10 @@ class OpenAiModelProvider(ModelProvider): + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.OPENAI + @staticmethod def get_env_var_names() -> set[str]: return {"OPENAI_API_KEY"} @@ -119,10 +123,6 @@ def get_embedding_model(name: str) -> OpenAIEmbedding: def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor: raise NotImplementedError("No reranking models available") - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.OPENAI - # ensure interface is implemented _ = OpenAiModelProvider() From a95f051a3d628c4048c0751cbe19a59f38058610 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 14:58:06 -0700 Subject: [PATCH 4/9] Factor out get_provider_class() --- .../app/ai/indexing/summary_indexer.py | 6 ++-- 
llm-service/app/services/models/__init__.py | 4 +-- llm-service/app/services/models/embedding.py | 6 ++-- llm-service/app/services/models/llm.py | 6 ++-- .../app/services/models/providers/__init__.py | 25 +++++++++++++++- .../models/providers/_model_provider.py | 30 ------------------- llm-service/app/services/models/reranking.py | 8 ++--- 7 files changed, 37 insertions(+), 48 deletions(-) diff --git a/llm-service/app/ai/indexing/summary_indexer.py b/llm-service/app/ai/indexing/summary_indexer.py index d89b1d51c..bfacda142 100644 --- a/llm-service/app/ai/indexing/summary_indexer.py +++ b/llm-service/app/ai/indexing/summary_indexer.py @@ -75,7 +75,7 @@ from ..vector_stores.vector_store_factory import VectorStoreFactory from ...config import settings from ...services.metadata_apis import data_sources_metadata_api -from ...services.models.providers import ModelProvider +from ...services.models.providers import get_provider_class from ...services.models import ModelSource logger = logging.getLogger(__name__) @@ -133,9 +133,7 @@ def __index_configuration( embed_summaries: bool = True, ) -> Dict[str, Any]: prompt_helper: Optional[PromptHelper] = None - model_source: ModelSource = ( - ModelProvider.get_provider_class().get_model_source() - ) + model_source: ModelSource = get_provider_class().get_model_source() if model_source == "CAII": # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context prompt_helper = PromptHelper(context_window=3000) diff --git a/llm-service/app/services/models/__init__.py b/llm-service/app/services/models/__init__.py index da248cc22..3dcfc48af 100644 --- a/llm-service/app/services/models/__init__.py +++ b/llm-service/app/services/models/__init__.py @@ -37,7 +37,7 @@ # from .embedding import Embedding from .llm import LLM -from .providers import ModelProvider +from .providers import get_provider_class from .reranking import Reranking from ._model_source import ModelSource @@ -45,4 +45,4 @@ 
def get_model_source() -> ModelSource: - return ModelProvider.get_provider_class().get_model_source() + return get_provider_class().get_model_source() diff --git a/llm-service/app/services/models/embedding.py b/llm-service/app/services/models/embedding.py index 9dd8f6b94..551308699 100644 --- a/llm-service/app/services/models/embedding.py +++ b/llm-service/app/services/models/embedding.py @@ -41,7 +41,7 @@ from llama_index.core.base.embeddings.base import BaseEmbedding from . import _model_type, _noop -from .providers._model_provider import ModelProvider +from .providers import get_provider_class from ..caii.types import ModelResponse @@ -51,7 +51,7 @@ def get(cls, model_name: Optional[str] = None) -> BaseEmbedding: if model_name is None: model_name = cls.list_available()[0].model_id - return ModelProvider.get_provider_class().get_embedding_model(model_name) + return get_provider_class().get_embedding_model(model_name) @staticmethod def get_noop() -> BaseEmbedding: @@ -59,7 +59,7 @@ def get_noop() -> BaseEmbedding: @staticmethod def list_available() -> list[ModelResponse]: - return ModelProvider.get_provider_class().list_embedding_models() + return get_provider_class().list_embedding_models() @classmethod def test(cls, model_name: str) -> str: diff --git a/llm-service/app/services/models/llm.py b/llm-service/app/services/models/llm.py index e8283ac32..15f126a63 100644 --- a/llm-service/app/services/models/llm.py +++ b/llm-service/app/services/models/llm.py @@ -41,7 +41,7 @@ from llama_index.core.base.llms.types import ChatMessage, MessageRole from . 
import _model_type, _noop -from .providers._model_provider import ModelProvider +from .providers import get_provider_class from ..caii.types import ModelResponse @@ -51,7 +51,7 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM: if not model_name: model_name = cls.list_available()[0].model_id - return ModelProvider.get_provider_class().get_llm_model(model_name) + return get_provider_class().get_llm_model(model_name) @staticmethod def get_noop() -> llms.LLM: @@ -59,7 +59,7 @@ def get_noop() -> llms.LLM: @staticmethod def list_available() -> list[ModelResponse]: - return ModelProvider.get_provider_class().list_llm_models() + return get_provider_class().list_llm_models() @classmethod def test(cls, model_name: str) -> Literal["ok"]: diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py index 47ecc8c78..3d04b1049 100644 --- a/llm-service/app/services/models/providers/__init__.py +++ b/llm-service/app/services/models/providers/__init__.py @@ -35,6 +35,7 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. 
# +from app.config import settings from .azure import AzureModelProvider from .bedrock import BedrockModelProvider from .caii import CAIIModelProvider @@ -46,5 +47,27 @@ "BedrockModelProvider", "CAIIModelProvider", "OpenAiModelProvider", - "ModelProvider", + "get_provider_class", ] + + +def get_provider_class() -> type[ModelProvider]: + """Return the ModelProvider subclass for the given provider name.""" + model_provider = settings.model_provider + if model_provider == "Azure": + return AzureModelProvider + elif model_provider == "CAII": + return CAIIModelProvider + elif model_provider == "OpenAI": + return OpenAiModelProvider + elif model_provider == "Bedrock": + return BedrockModelProvider + + # Fallback to priority order if no specific provider is set + if AzureModelProvider.is_enabled(): + return AzureModelProvider + elif OpenAiModelProvider.is_enabled(): + return OpenAiModelProvider + elif BedrockModelProvider.is_enabled(): + return BedrockModelProvider + return CAIIModelProvider diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index 3ee5a6964..77262ef0d 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -42,7 +42,6 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor.types import BaseNodePostprocessor -from app.config import settings from .._model_source import ModelSource from ...caii.types import ModelResponse @@ -53,35 +52,6 @@ def is_enabled(cls) -> bool: """Return whether this model provider is enabled, based on the presence of required env vars.""" return all(map(os.environ.get, cls.get_env_var_names())) - @staticmethod - def get_provider_class() -> type["ModelProvider"]: - """Return the ModelProvider subclass for the given provider name.""" - from . 
import ( - AzureModelProvider, - CAIIModelProvider, - OpenAiModelProvider, - BedrockModelProvider, - ) - - model_provider = settings.model_provider - if model_provider == "Azure": - return AzureModelProvider - elif model_provider == "CAII": - return CAIIModelProvider - elif model_provider == "OpenAI": - return OpenAiModelProvider - elif model_provider == "Bedrock": - return BedrockModelProvider - - # Fallback to priority order if no specific provider is set - if AzureModelProvider.is_enabled(): - return AzureModelProvider - elif OpenAiModelProvider.is_enabled(): - return OpenAiModelProvider - elif BedrockModelProvider.is_enabled(): - return BedrockModelProvider - return CAIIModelProvider - @staticmethod @abc.abstractmethod def get_model_source() -> ModelSource: diff --git a/llm-service/app/services/models/reranking.py b/llm-service/app/services/models/reranking.py index 95cc6f870..8aad6959d 100644 --- a/llm-service/app/services/models/reranking.py +++ b/llm-service/app/services/models/reranking.py @@ -42,7 +42,7 @@ from llama_index.core.schema import NodeWithScore, TextNode from . 
import _model_type -from .providers._model_provider import ModelProvider +from .providers import get_provider_class from ..caii.types import ModelResponse from ..query.simple_reranker import SimpleReranker @@ -57,9 +57,7 @@ def get( if not model_name: return SimpleReranker(top_n=top_n) - return ModelProvider.get_provider_class().get_reranking_model( - name=model_name, top_n=top_n - ) + return get_provider_class().get_reranking_model(name=model_name, top_n=top_n) @staticmethod def get_noop() -> BaseNodePostprocessor: @@ -67,7 +65,7 @@ def get_noop() -> BaseNodePostprocessor: @staticmethod def list_available() -> list[ModelResponse]: - return ModelProvider.get_provider_class().list_reranking_models() + return get_provider_class().list_reranking_models() @classmethod def test(cls, model_name: str) -> str: From 077f855db05eb096f5ce6803940c972cf6a3525d Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 15:01:53 -0700 Subject: [PATCH 5/9] Rename is_enabled() to env_vars_are_set() --- llm-service/app/services/amp_metadata/__init__.py | 10 +++++----- .../app/services/models/providers/__init__.py | 6 +++--- .../services/models/providers/_model_provider.py | 10 +++++----- llm-service/app/services/models/providers/azure.py | 8 ++++---- .../app/services/models/providers/bedrock.py | 8 ++++---- llm-service/app/services/models/providers/caii.py | 14 +++++++------- .../app/services/models/providers/openai.py | 8 ++++---- .../services/query/agents/tool_calling_querier.py | 2 +- 8 files changed, 33 insertions(+), 33 deletions(-) diff --git a/llm-service/app/services/amp_metadata/__init__.py b/llm-service/app/services/amp_metadata/__init__.py index 3baa8fb2b..da1328afa 100644 --- a/llm-service/app/services/amp_metadata/__init__.py +++ b/llm-service/app/services/amp_metadata/__init__.py @@ -216,13 +216,13 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult: f"Preferred provider {preferred_provider} is properly configured. 
\n" ) if preferred_provider == "Bedrock": - valid_model_config_exists = BedrockModelProvider.is_enabled() + valid_model_config_exists = BedrockModelProvider.env_vars_are_set() elif preferred_provider == "Azure": - valid_model_config_exists = AzureModelProvider.is_enabled() + valid_model_config_exists = AzureModelProvider.env_vars_are_set() elif preferred_provider == "OpenAI": - valid_model_config_exists = OpenAiModelProvider.is_enabled() + valid_model_config_exists = OpenAiModelProvider.env_vars_are_set() elif preferred_provider == "CAII": - valid_model_config_exists = CAIIModelProvider.is_enabled() + valid_model_config_exists = CAIIModelProvider.env_vars_are_set() return ValidationResult( valid=valid_model_config_exists, message=valid_message if valid_model_config_exists else message, @@ -276,7 +276,7 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult: if message == "": # check to see if CAII models are available via discovery - if CAIIModelProvider.is_enabled(): + if CAIIModelProvider.env_vars_are_set(): message = "CAII models are available." 
valid_model_config_exists = True else: diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py index 3d04b1049..367ce04a3 100644 --- a/llm-service/app/services/models/providers/__init__.py +++ b/llm-service/app/services/models/providers/__init__.py @@ -64,10 +64,10 @@ def get_provider_class() -> type[ModelProvider]: return BedrockModelProvider # Fallback to priority order if no specific provider is set - if AzureModelProvider.is_enabled(): + if AzureModelProvider.env_vars_are_set(): return AzureModelProvider - elif OpenAiModelProvider.is_enabled(): + elif OpenAiModelProvider.env_vars_are_set(): return OpenAiModelProvider - elif BedrockModelProvider.is_enabled(): + elif BedrockModelProvider.env_vars_are_set(): return BedrockModelProvider return CAIIModelProvider diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index 77262ef0d..7c0002392 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -48,19 +48,19 @@ class ModelProvider(abc.ABC): @classmethod - def is_enabled(cls) -> bool: - """Return whether this model provider is enabled, based on the presence of required env vars.""" + def env_vars_are_set(cls) -> bool: + """Return whether this model provider's env vars have set values.""" return all(map(os.environ.get, cls.get_env_var_names())) @staticmethod @abc.abstractmethod - def get_model_source() -> ModelSource: + def get_env_var_names() -> set[str]: + """Return the names of the env vars required by this model provider.""" raise NotImplementedError @staticmethod @abc.abstractmethod - def get_env_var_names() -> set[str]: - """Return the names of the env vars required by this model provider.""" + def get_model_source() -> ModelSource: raise NotImplementedError @staticmethod diff --git 
a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index f59d55203..1798389c8 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -47,14 +47,14 @@ class AzureModelProvider(ModelProvider): - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.AZURE - @staticmethod def get_env_var_names() -> set[str]: return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"} + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.AZURE + @staticmethod def list_llm_models() -> list[ModelResponse]: return [ diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index fe5026aa4..0ccc03f35 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -74,14 +74,14 @@ class BedrockModelProvider(ModelProvider): - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.BEDROCK - @staticmethod def get_env_var_names() -> set[str]: return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"} + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.BEDROCK + @staticmethod def get_foundation_models( modality: Optional[BedrockModality] = None, diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index 750f07c69..aa6095641 100644 --- a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -61,14 +61,14 @@ class CAIIModelProvider(ModelProvider): - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.CAII - @staticmethod def get_env_var_names() -> set[str]: return {"CAII_DOMAIN"} + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.CAII + @staticmethod 
@timed_lru_cache(maxsize=1, seconds=300) def list_llm_models() -> list[ModelResponse]: @@ -105,17 +105,17 @@ def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor: return get_caii_reranking_model(name, top_n) @classmethod - def is_enabled(cls) -> bool: + def env_vars_are_set(cls) -> bool: version: Optional[str] = get_cml_version_from_sense_bootstrap() if not version: - return super().is_enabled() + return super().env_vars_are_set() cml_version = Version(version) if cml_version >= Version("2.0.50-b68"): available_models = get_models_with_task("TEXT_GENERATION") if available_models: return True - return super().is_enabled() + return super().env_vars_are_set() # ensure interface is implemented diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 24801d0bb..541c55239 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -51,14 +51,14 @@ class OpenAiModelProvider(ModelProvider): - @staticmethod - def get_model_source() -> ModelSource: - return ModelSource.OPENAI - @staticmethod def get_env_var_names() -> set[str]: return {"OPENAI_API_KEY"} + @staticmethod + def get_model_source() -> ModelSource: + return ModelSource.OPENAI + @staticmethod def list_llm_models() -> list[ModelResponse]: return [ diff --git a/llm-service/app/services/query/agents/tool_calling_querier.py b/llm-service/app/services/query/agents/tool_calling_querier.py index 954f1877d..6ad49b08c 100644 --- a/llm-service/app/services/query/agents/tool_calling_querier.py +++ b/llm-service/app/services/query/agents/tool_calling_querier.py @@ -409,7 +409,7 @@ async def agen() -> AsyncGenerator[ChatResponse, None]: # if delta is empty and response is empty, # it is a start to a tool call stream - if BedrockModelProvider.is_enabled(): + if BedrockModelProvider.env_vars_are_set(): delta = event.delta or "" if ( isinstance(event.raw, dict) From 
53b3d8bc5a56332c373a130279c5896d3023dfc1 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 15:16:25 -0700 Subject: [PATCH 6/9] Privatize ModelProvider base class --- .../app/services/models/providers/__init__.py | 4 ++-- .../services/models/providers/_model_provider.py | 2 +- .../app/services/models/providers/azure.py | 4 ++-- .../app/services/models/providers/bedrock.py | 4 ++-- .../app/services/models/providers/caii.py | 4 ++-- .../app/services/models/providers/openai.py | 4 ++-- llm-service/app/tests/services/test_models.py | 16 ++++++++-------- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py index 367ce04a3..c7235db63 100644 --- a/llm-service/app/services/models/providers/__init__.py +++ b/llm-service/app/services/models/providers/__init__.py @@ -40,7 +40,7 @@ from .bedrock import BedrockModelProvider from .caii import CAIIModelProvider from .openai import OpenAiModelProvider -from ._model_provider import ModelProvider +from ._model_provider import _ModelProvider __all__ = [ "AzureModelProvider", @@ -51,7 +51,7 @@ ] -def get_provider_class() -> type[ModelProvider]: +def get_provider_class() -> type[_ModelProvider]: """Return the ModelProvider subclass for the given provider name.""" model_provider = settings.model_provider if model_provider == "Azure": diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index 7c0002392..71b5400ac 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -46,7 +46,7 @@ from ...caii.types import ModelResponse -class ModelProvider(abc.ABC): +class _ModelProvider(abc.ABC): @classmethod def env_vars_are_set(cls) -> bool: """Return whether this model provider's env vars have set values.""" diff --git 
a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index 1798389c8..6e6ba0f4b 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -38,7 +38,7 @@ from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding from llama_index.llms.azure_openai import AzureOpenAI -from ._model_provider import ModelProvider +from ._model_provider import _ModelProvider from .._model_source import ModelSource from ...caii.types import ModelResponse from ...llama_utils import completion_to_prompt, messages_to_prompt @@ -46,7 +46,7 @@ from ....config import settings -class AzureModelProvider(ModelProvider): +class AzureModelProvider(_ModelProvider): @staticmethod def get_env_var_names() -> set[str]: return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"} diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index 0ccc03f35..b51a37621 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -52,7 +52,7 @@ from pydantic import TypeAdapter from app.config import settings -from ._model_provider import ModelProvider +from ._model_provider import _ModelProvider from .._model_source import ModelSource from ...caii.types import ModelResponse from ...llama_utils import completion_to_prompt, messages_to_prompt @@ -73,7 +73,7 @@ } -class BedrockModelProvider(ModelProvider): +class BedrockModelProvider(_ModelProvider): @staticmethod def get_env_var_names() -> set[str]: return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"} diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index aa6095641..d90b6433e 100644 --- a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -42,7 
+42,7 @@ from llama_index.core.postprocessor.types import BaseNodePostprocessor from packaging.version import Version -from ._model_provider import ModelProvider +from ._model_provider import _ModelProvider from .._model_source import ModelSource from ...caii.caii import ( get_caii_llm_models, @@ -60,7 +60,7 @@ from ...utils import timed_lru_cache -class CAIIModelProvider(ModelProvider): +class CAIIModelProvider(_ModelProvider): @staticmethod def get_env_var_names() -> set[str]: return {"CAII_DOMAIN"} diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 541c55239..9c330994e 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -43,14 +43,14 @@ from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI -from ._model_provider import ModelProvider +from ._model_provider import _ModelProvider from .._model_source import ModelSource from ...caii.types import ModelResponse from ...llama_utils import completion_to_prompt, messages_to_prompt from ....config import settings -class OpenAiModelProvider(ModelProvider): +class OpenAiModelProvider(_ModelProvider): @staticmethod def get_env_var_names() -> set[str]: return {"OPENAI_API_KEY"} diff --git a/llm-service/app/tests/services/test_models.py b/llm-service/app/tests/services/test_models.py index 3b9fcfe5e..2397f18f5 100644 --- a/llm-service/app/tests/services/test_models.py +++ b/llm-service/app/tests/services/test_models.py @@ -43,25 +43,25 @@ from app.services.caii import caii from app.services.caii.types import ListEndpointEntry from app.services.models.providers import BedrockModelProvider -from app.services.models.providers._model_provider import ModelProvider +from app.services.models.providers._model_provider import _ModelProvider def get_all_env_var_names() -> set[str]: """Return the names of all the env vars required by all model 
providers.""" return set( itertools.chain.from_iterable( - subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__() + subcls.get_env_var_names() for subcls in _ModelProvider.__subclasses__() ) ) -@pytest.fixture(params=ModelProvider.__subclasses__()) +@pytest.fixture(params=_ModelProvider.__subclasses__()) def EnabledModelProvider( request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, -) -> type[ModelProvider]: +) -> type[_ModelProvider]: """Sets and unsets environment variables for the given model provider.""" - ModelProviderSubcls: type[ModelProvider] = request.param + ModelProviderSubcls: type[_ModelProvider] = request.param for name in get_all_env_var_names(): monkeypatch.delenv(name, raising=False) @@ -98,18 +98,18 @@ def get_foundation_models(self, monkeypatch: pytest.MonkeyPatch) -> None: BedrockModelProvider, "get_foundation_models", lambda modality: [] ) - def test_embedding(self, EnabledModelProvider: type[ModelProvider]) -> None: + def test_embedding(self, EnabledModelProvider: type[_ModelProvider]) -> None: """Verify models.Embedding.list_available() only returns models from the enabled model provider.""" assert ( models.Embedding.list_available() == EnabledModelProvider.list_embedding_models() ) - def test_llm(self, EnabledModelProvider: type[ModelProvider]) -> None: + def test_llm(self, EnabledModelProvider: type[_ModelProvider]) -> None: """Verify models.LLM.list_available() only returns models from the enabled model provider.""" assert models.LLM.list_available() == EnabledModelProvider.list_llm_models() - def test_reranking(self, EnabledModelProvider: type[ModelProvider]) -> None: + def test_reranking(self, EnabledModelProvider: type[_ModelProvider]) -> None: """Verify models.Reranking.list_available() only returns models from the enabled model provider.""" assert ( models.Reranking.list_available() From d46955ccd35943a3c46bc5309cb2e7ea9594fe28 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 15:29:14 
-0700 Subject: [PATCH 7/9] Rework get_provider_class() --- llm-service/app/services/models/__init__.py | 2 +- .../app/services/models/_model_source.py | 4 +- .../app/services/models/providers/__init__.py | 43 ++++++++++++------- .../models/providers/_model_provider.py | 11 +++++ .../app/services/models/providers/azure.py | 4 ++ .../app/services/models/providers/bedrock.py | 4 ++ .../app/services/models/providers/caii.py | 4 ++ .../app/services/models/providers/openai.py | 4 ++ 8 files changed, 58 insertions(+), 18 deletions(-) diff --git a/llm-service/app/services/models/__init__.py b/llm-service/app/services/models/__init__.py index 3dcfc48af..8ca22af2f 100644 --- a/llm-service/app/services/models/__init__.py +++ b/llm-service/app/services/models/__init__.py @@ -41,7 +41,7 @@ from .reranking import Reranking from ._model_source import ModelSource -__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"] +__all__ = ["Embedding", "LLM", "Reranking", "ModelSource", "get_model_source"] def get_model_source() -> ModelSource: diff --git a/llm-service/app/services/models/_model_source.py b/llm-service/app/services/models/_model_source.py index 618b2646e..fec3d2b58 100644 --- a/llm-service/app/services/models/_model_source.py +++ b/llm-service/app/services/models/_model_source.py @@ -40,7 +40,7 @@ class ModelSource(str, Enum): - BEDROCK = "Bedrock" - CAII = "CAII" AZURE = "Azure" OPENAI = "OpenAI" + BEDROCK = "Bedrock" + CAII = "CAII" diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py index c7235db63..d7d93311d 100644 --- a/llm-service/app/services/models/providers/__init__.py +++ b/llm-service/app/services/models/providers/__init__.py @@ -35,6 +35,8 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. 
# +import logging + from app.config import settings from .azure import AzureModelProvider from .bedrock import BedrockModelProvider @@ -42,6 +44,8 @@ from .openai import OpenAiModelProvider from ._model_provider import _ModelProvider +logger = logging.getLogger(__name__) + __all__ = [ "AzureModelProvider", "BedrockModelProvider", @@ -53,21 +57,30 @@ def get_provider_class() -> type[_ModelProvider]: """Return the ModelProvider subclass for the given provider name.""" + model_providers: list[type[_ModelProvider]] = sorted( + _ModelProvider.__subclasses__(), + key=lambda ModelProviderSubcls: ModelProviderSubcls.get_priority(), + ) + model_provider = settings.model_provider - if model_provider == "Azure": - return AzureModelProvider - elif model_provider == "CAII": - return CAIIModelProvider - elif model_provider == "OpenAI": - return OpenAiModelProvider - elif model_provider == "Bedrock": - return BedrockModelProvider + for ModelProviderSubcls in model_providers: + if model_provider == ModelProviderSubcls.get_model_source(): + logger.debug( + "using model provider %s based on `MODEL_PROVIDER` env var: %s", + ModelProviderSubcls, + model_provider, + ) + return ModelProviderSubcls + + # Fallback if no specific provider is set + for ModelProviderSubcls in model_providers: + if ModelProviderSubcls.env_vars_are_set(): + logger.debug( + "falling back to model provider %s based on env vars: %s", + ModelProviderSubcls, + ModelProviderSubcls.get_env_var_names(), + ) + return ModelProviderSubcls - # Fallback to priority order if no specific provider is set - if AzureModelProvider.env_vars_are_set(): - return AzureModelProvider - elif OpenAiModelProvider.env_vars_are_set(): - return OpenAiModelProvider - elif BedrockModelProvider.env_vars_are_set(): - return BedrockModelProvider + logger.debug("falling back to model provider CAII") return CAIIModelProvider diff --git a/llm-service/app/services/models/providers/_model_provider.py 
b/llm-service/app/services/models/providers/_model_provider.py index 71b5400ac..f83106c07 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -61,6 +61,17 @@ def get_env_var_names() -> set[str]: @staticmethod @abc.abstractmethod def get_model_source() -> ModelSource: + """Return the ModelSource of this model provider.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def get_priority() -> int: + """Return the priority of this model provider relative to the others. + + 1 is the highest priority. + + """ raise NotImplementedError @staticmethod diff --git a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index 6e6ba0f4b..ca06d390f 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -55,6 +55,10 @@ def get_env_var_names() -> set[str]: def get_model_source() -> ModelSource: return ModelSource.AZURE + @staticmethod + def get_priority() -> int: + return 1 + @staticmethod def list_llm_models() -> list[ModelResponse]: return [ diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index b51a37621..7b45371be 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -82,6 +82,10 @@ def get_env_var_names() -> set[str]: def get_model_source() -> ModelSource: return ModelSource.BEDROCK + @staticmethod + def get_priority() -> int: + return 3 + @staticmethod def get_foundation_models( modality: Optional[BedrockModality] = None, diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index d90b6433e..85da16acf 100644 --- a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -69,6 +69,10 @@ def
get_env_var_names() -> set[str]: def get_model_source() -> ModelSource: return ModelSource.CAII + @staticmethod + def get_priority() -> int: + return 4 + @staticmethod @timed_lru_cache(maxsize=1, seconds=300) def list_llm_models() -> list[ModelResponse]: diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 9c330994e..1e4416df4 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -59,6 +59,10 @@ def get_env_var_names() -> set[str]: def get_model_source() -> ModelSource: return ModelSource.OPENAI + @staticmethod + def get_priority() -> int: + return 2 + @staticmethod def list_llm_models() -> list[ModelResponse]: return [ From 4af55d77fa9a4652e9fff99aab2879368cb19b94 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 16:03:58 -0700 Subject: [PATCH 8/9] Remove redundant ModelProviderType --- .../app/ai/indexing/summary_indexer.py | 3 +- llm-service/app/config.py | 18 ++++++-- .../app/routers/index/models/__init__.py | 4 +- .../app/services/amp_metadata/__init__.py | 6 +-- llm-service/app/services/models/__init__.py | 4 +- .../app/services/models/_model_source.py | 46 ------------------- .../models/providers/_model_provider.py | 2 +- .../app/services/models/providers/azure.py | 3 +- .../app/services/models/providers/bedrock.py | 3 +- .../app/services/models/providers/caii.py | 2 +- .../app/services/models/providers/openai.py | 3 +- llm-service/app/services/query/querier.py | 2 +- 12 files changed, 27 insertions(+), 69 deletions(-) delete mode 100644 llm-service/app/services/models/_model_source.py diff --git a/llm-service/app/ai/indexing/summary_indexer.py b/llm-service/app/ai/indexing/summary_indexer.py index bfacda142..83a18d974 100644 --- a/llm-service/app/ai/indexing/summary_indexer.py +++ b/llm-service/app/ai/indexing/summary_indexer.py @@ -73,10 +73,9 @@ from .base import BaseTextIndexer from 
.readers.base_reader import ReaderConfig, ChunksResult from ..vector_stores.vector_store_factory import VectorStoreFactory -from ...config import settings +from ...config import settings, ModelSource from ...services.metadata_apis import data_sources_metadata_api from ...services.models.providers import get_provider_class -from ...services.models import ModelSource logger = logging.getLogger(__name__) diff --git a/llm-service/app/config.py b/llm-service/app/config.py index 85aa60430..ff882223b 100644 --- a/llm-service/app/config.py +++ b/llm-service/app/config.py @@ -46,6 +46,7 @@ import logging import os.path +from enum import Enum from typing import cast, Optional, Literal @@ -53,7 +54,13 @@ ChatStoreProviderType = Literal["Local", "S3"] VectorDbProviderType = Literal["QDRANT", "OPENSEARCH"] MetadataDbProviderType = Literal["H2", "PostgreSQL"] -ModelProviderType = Literal["Azure", "CAII", "OpenAI", "Bedrock"] + + +class ModelSource(str, Enum): + AZURE = "Azure" + OPENAI = "OpenAI" + BEDROCK = "Bedrock" + CAII = "CAII" class _Settings: @@ -185,14 +192,15 @@ def openai_api_base(self) -> Optional[str]: return os.environ.get("OPENAI_API_BASE") @property - def model_provider(self) -> Optional[ModelProviderType]: + def model_provider(self) -> Optional[ModelSource]: """The preferred model provider to use. 
Options: 'AZURE', 'CAII', 'OPENAI', 'BEDROCK' If not set, will use the first available provider in priority order.""" provider = os.environ.get("MODEL_PROVIDER") - if provider and provider in ["Azure", "CAII", "OpenAI", "Bedrock"]: - return cast(ModelProviderType, provider) - return None + try: + return ModelSource(provider) + except ValueError: + return None settings = _Settings() diff --git a/llm-service/app/routers/index/models/__init__.py b/llm-service/app/routers/index/models/__init__.py index ddddb2a0b..5e8185830 100644 --- a/llm-service/app/routers/index/models/__init__.py +++ b/llm-service/app/routers/index/models/__init__.py @@ -40,7 +40,7 @@ from fastapi import APIRouter import app.services.models -import app.services.models._model_source +from app.config import ModelSource from .... import exceptions from ....services import models from ....services.caii.caii import describe_endpoint, build_model_response @@ -71,7 +71,7 @@ def get_reranking_models() -> List[ModelResponse]: "/model_source", summary="Model source enabled - Bedrock, CAII, OpenAI or Azure" ) @exceptions.propagates -def get_model() -> app.services.models._model_source.ModelSource: +def get_model() -> ModelSource: return app.services.models.get_model_source() diff --git a/llm-service/app/services/amp_metadata/__init__.py b/llm-service/app/services/amp_metadata/__init__.py index da1328afa..0b8b86277 100644 --- a/llm-service/app/services/amp_metadata/__init__.py +++ b/llm-service/app/services/amp_metadata/__init__.py @@ -50,7 +50,7 @@ ChatStoreProviderType, VectorDbProviderType, MetadataDbProviderType, - ModelProviderType, + ModelSource, ) from app.services.models.providers import ( CAIIModelProvider, @@ -136,7 +136,7 @@ class ProjectConfig(BaseModel): chat_store_provider: ChatStoreProviderType vector_db_provider: VectorDbProviderType metadata_db_provider: MetadataDbProviderType - model_provider: Optional[ModelProviderType] = None + model_provider: Optional[ModelSource] = None aws_config: 
AwsConfig azure_config: AzureConfig caii_config: CaiiConfig @@ -388,7 +388,7 @@ def build_configuration( validate_config = validate(frozenset(env.items())) model_provider = ( - TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER")) + TypeAdapter(ModelSource).validate_python(env.get("MODEL_PROVIDER")) if env.get("MODEL_PROVIDER") else None ) diff --git a/llm-service/app/services/models/__init__.py b/llm-service/app/services/models/__init__.py index 8ca22af2f..d6d6e0211 100644 --- a/llm-service/app/services/models/__init__.py +++ b/llm-service/app/services/models/__init__.py @@ -39,9 +39,9 @@ from .llm import LLM from .providers import get_provider_class from .reranking import Reranking -from ._model_source import ModelSource +from ...config import ModelSource -__all__ = ["Embedding", "LLM", "Reranking", "ModelSource", "get_model_source"] +__all__ = ["Embedding", "LLM", "Reranking", "get_model_source"] def get_model_source() -> ModelSource: diff --git a/llm-service/app/services/models/_model_source.py b/llm-service/app/services/models/_model_source.py deleted file mode 100644 index fec3d2b58..000000000 --- a/llm-service/app/services/models/_model_source.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) -# (C) Cloudera, Inc. 2025 -# All rights reserved. -# -# Applicable Open Source License: Apache 2.0 -# -# NOTE: Cloudera open source products are modular software products -# made up of hundreds of individual components, each of which was -# individually copyrighted. Each Cloudera open source product is a -# collective work under U.S. Copyright Law. Your license to use the -# collective work is as provided in your written agreement with -# Cloudera. Used apart from the collective work, this file is -# licensed for your use pursuant to the open source license -# identified above. -# -# This code is provided to you pursuant a written agreement with -# (i) Cloudera, Inc. 
or (ii) a third-party authorized to distribute -# this code. If you do not have a written agreement with Cloudera nor -# with an authorized and properly licensed third party, you do not -# have any rights to access nor to use this code. -# -# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the -# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY -# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED -# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO -# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND -# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, -# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS -# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE -# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY -# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR -# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES -# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF -# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF -# DATA. 
-# - -from enum import Enum - - -class ModelSource(str, Enum): - AZURE = "Azure" - OPENAI = "OpenAI" - BEDROCK = "Bedrock" - CAII = "CAII" diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index f83106c07..e591c5d40 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -42,7 +42,7 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor.types import BaseNodePostprocessor -from .._model_source import ModelSource +from app.config import ModelSource from ...caii.types import ModelResponse diff --git a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index ca06d390f..da222b5f6 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -39,11 +39,10 @@ from llama_index.llms.azure_openai import AzureOpenAI from ._model_provider import _ModelProvider -from .._model_source import ModelSource from ...caii.types import ModelResponse from ...llama_utils import completion_to_prompt, messages_to_prompt from ...query.simple_reranker import SimpleReranker -from ....config import settings +from ....config import settings, ModelSource class AzureModelProvider(_ModelProvider): diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index 7b45371be..241137cfc 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -51,9 +51,8 @@ from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank from pydantic import TypeAdapter -from app.config import settings +from app.config import settings, ModelSource from ._model_provider import _ModelProvider -from .._model_source import ModelSource from ...caii.types import ModelResponse from 
...llama_utils import completion_to_prompt, messages_to_prompt from ...utils import raise_for_http_error, timed_lru_cache diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index 85da16acf..1f0de10e0 100644 --- a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -43,7 +43,7 @@ from packaging.version import Version from ._model_provider import _ModelProvider -from .._model_source import ModelSource +from app.config import ModelSource from ...caii.caii import ( get_caii_llm_models, get_caii_embedding_models, diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 1e4416df4..0f67099fe 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -44,10 +44,9 @@ from llama_index.llms.openai import OpenAI from ._model_provider import _ModelProvider -from .._model_source import ModelSource from ...caii.types import ModelResponse from ...llama_utils import completion_to_prompt, messages_to_prompt -from ....config import settings +from ....config import settings, ModelSource class OpenAiModelProvider(_ModelProvider): diff --git a/llm-service/app/services/query/querier.py b/llm-service/app/services/query/querier.py index 299e67a77..bf706341c 100644 --- a/llm-service/app/services/query/querier.py +++ b/llm-service/app/services/query/querier.py @@ -53,7 +53,7 @@ from .flexible_retriever import FlexibleRetriever from .multi_retriever import MultiSourceRetriever from ..metadata_apis.session_metadata_api import Session -from ..models._model_source import ModelSource +from ...config import ModelSource from ..models import get_model_source if TYPE_CHECKING: From cc81c1a36f8486c1a69026939d64da9d9c46f3f7 Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Fri, 22 Aug 2025 16:12:30 -0700 Subject: [PATCH 9/9] Clean up logs --- 
.../app/services/models/providers/__init__.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py index d7d93311d..7dfb6d34c 100644 --- a/llm-service/app/services/models/providers/__init__.py +++ b/llm-service/app/services/models/providers/__init__.py @@ -65,22 +65,21 @@ def get_provider_class() -> type[_ModelProvider]: model_provider = settings.model_provider for ModelProviderSubcls in model_providers: if model_provider == ModelProviderSubcls.get_model_source(): - logger.debug( - "using model provider %s based on `MODEL_PROVIDER` env var: %s", - ModelProviderSubcls, - model_provider, + logger.info( + 'using model provider "%s" based on `MODEL_PROVIDER` env var', + ModelProviderSubcls.get_model_source().value, ) return ModelProviderSubcls # Fallback if no specific provider is set for ModelProviderSubcls in model_providers: if ModelProviderSubcls.env_vars_are_set(): - logger.debug( - "falling back to model provider %s based on env vars: %s", - ModelProviderSubcls, + logger.info( + 'falling back to model provider "%s" based on env vars %s', + ModelProviderSubcls.get_model_source().value, ModelProviderSubcls.get_env_var_names(), ) return ModelProviderSubcls - logger.debug("falling back to model provider CAII") + logger.info('falling back to model provider "CAII"') return CAIIModelProvider