9 changes: 3 additions & 6 deletions llm-service/app/ai/indexing/summary_indexer.py
@@ -73,10 +73,9 @@
from .base import BaseTextIndexer
from .readers.base_reader import ReaderConfig, ChunksResult
from ..vector_stores.vector_store_factory import VectorStoreFactory
-from ...config import settings
+from ...config import settings, ModelSource
from ...services.metadata_apis import data_sources_metadata_api
-from ...services.models.providers import ModelProvider
-from ...services.models import ModelSource
+from ...services.models.providers import get_provider_class

logger = logging.getLogger(__name__)

@@ -133,9 +132,7 @@ def __index_configuration(
        embed_summaries: bool = True,
    ) -> Dict[str, Any]:
        prompt_helper: Optional[PromptHelper] = None
-        model_source: ModelSource = (
-            ModelProvider.get_provider_class().get_model_source()
-        )
+        model_source: ModelSource = get_provider_class().get_model_source()
        if model_source == "CAII":
            # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
            prompt_helper = PromptHelper(context_window=3000)
18 changes: 13 additions & 5 deletions llm-service/app/config.py
@@ -46,14 +46,21 @@

import logging
import os.path
+from enum import Enum
from typing import cast, Optional, Literal


SummaryStorageProviderType = Literal["Local", "S3"]
ChatStoreProviderType = Literal["Local", "S3"]
VectorDbProviderType = Literal["QDRANT", "OPENSEARCH"]
MetadataDbProviderType = Literal["H2", "PostgreSQL"]
-ModelProviderType = Literal["Azure", "CAII", "OpenAI", "Bedrock"]
+
+
+class ModelSource(str, Enum):
+    AZURE = "Azure"
+    OPENAI = "OpenAI"
+    BEDROCK = "Bedrock"
+    CAII = "CAII"
Comment on lines -56 to +63 (Collaborator Author):
This probably isn't the best place for ModelSource (I moved it here from llm-service/app/services/models/ to avoid a circular import), but I wanted to replace and remove the redundant ModelProviderType.

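Aside (not part of the diff): because ModelSource subclasses str, existing string comparisons keep working after the Literal-to-Enum swap, including the `if model_source == "CAII"` check in summary_indexer.py above. A minimal standalone sketch:

from enum import Enum


class ModelSource(str, Enum):  # same shape as the enum added in config.py
    AZURE = "Azure"
    OPENAI = "OpenAI"
    BEDROCK = "Bedrock"
    CAII = "CAII"


assert ModelSource.CAII == "CAII"                  # str comparison still succeeds
assert ModelSource("Azure") is ModelSource.AZURE   # lookup by value round-trips
assert ModelSource.CAII.value == "CAII"            # .value is the plain string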


class _Settings:
@@ -185,14 +192,15 @@ def openai_api_base(self) -> Optional[str]:
        return os.environ.get("OPENAI_API_BASE")

    @property
-    def model_provider(self) -> Optional[ModelProviderType]:
+    def model_provider(self) -> Optional[ModelSource]:
        """The preferred model provider to use.
        Options: 'AZURE', 'CAII', 'OPENAI', 'BEDROCK'
        If not set, will use the first available provider in priority order."""
        provider = os.environ.get("MODEL_PROVIDER")
-        if provider and provider in ["Azure", "CAII", "OpenAI", "Bedrock"]:
-            return cast(ModelProviderType, provider)
-        return None
+        try:
+            return ModelSource(provider)
+        except ValueError:
+            return None


settings = _Settings()
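Aside (not part of the diff): a rough sketch of the new property's behavior, assuming app.config is importable in a scratch environment. An unset or unrecognized MODEL_PROVIDER now falls through the ValueError handler and yields None instead of a cast string.

import os

from app.config import ModelSource, settings

os.environ["MODEL_PROVIDER"] = "Bedrock"
assert settings.model_provider is ModelSource.BEDROCK

os.environ["MODEL_PROVIDER"] = "bedrock"  # not a valid enum value; lookup by value is case-sensitive
assert settings.model_provider is None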
4 changes: 2 additions & 2 deletions llm-service/app/routers/index/models/__init__.py
@@ -40,7 +40,7 @@
from fastapi import APIRouter

import app.services.models
-import app.services.models._model_source
+from app.config import ModelSource
from .... import exceptions
from ....services import models
from ....services.caii.caii import describe_endpoint, build_model_response
@@ -71,7 +71,7 @@ def get_reranking_models() -> List[ModelResponse]:
    "/model_source", summary="Model source enabled - Bedrock, CAII, OpenAI or Azure"
)
@exceptions.propagates
-def get_model() -> app.services.models._model_source.ModelSource:
+def get_model() -> ModelSource:
    return app.services.models.get_model_source()


16 changes: 8 additions & 8 deletions llm-service/app/services/amp_metadata/__init__.py
@@ -50,7 +50,7 @@
    ChatStoreProviderType,
    VectorDbProviderType,
    MetadataDbProviderType,
-    ModelProviderType,
+    ModelSource,
)
from app.services.models.providers import (
    CAIIModelProvider,
@@ -136,7 +136,7 @@ class ProjectConfig(BaseModel):
    chat_store_provider: ChatStoreProviderType
    vector_db_provider: VectorDbProviderType
    metadata_db_provider: MetadataDbProviderType
-    model_provider: Optional[ModelProviderType] = None
+    model_provider: Optional[ModelSource] = None
Comment (Collaborator Author):
This still comes back as a string when the frontend calls /llm-service/amp/config 👍

    aws_config: AwsConfig
    azure_config: AzureConfig
    caii_config: CaiiConfig
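Aside (not part of the diff): a minimal sketch, assuming Pydantic v2, of why the frontend still sees a plain string in the comment above; when the response model is dumped to JSON, the str-backed enum member serializes as its value.

from typing import Optional

from pydantic import BaseModel, ConfigDict

from app.config import ModelSource  # assumes app.config is importable


class ProjectConfig(BaseModel):
    # protected_namespaces is cleared only to silence Pydantic's "model_" warning in this standalone sketch
    model_config = ConfigDict(protected_namespaces=())

    model_provider: Optional[ModelSource] = None


config = ProjectConfig(model_provider="Azure")  # plain strings validate into the enum
print(config.model_dump(mode="json"))           # {'model_provider': 'Azure'}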
@@ -216,13 +216,13 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
            f"Preferred provider {preferred_provider} is properly configured. \n"
        )
        if preferred_provider == "Bedrock":
-            valid_model_config_exists = BedrockModelProvider.is_enabled()
+            valid_model_config_exists = BedrockModelProvider.env_vars_are_set()
        elif preferred_provider == "Azure":
-            valid_model_config_exists = AzureModelProvider.is_enabled()
+            valid_model_config_exists = AzureModelProvider.env_vars_are_set()
        elif preferred_provider == "OpenAI":
-            valid_model_config_exists = OpenAiModelProvider.is_enabled()
+            valid_model_config_exists = OpenAiModelProvider.env_vars_are_set()
        elif preferred_provider == "CAII":
-            valid_model_config_exists = CAIIModelProvider.is_enabled()
+            valid_model_config_exists = CAIIModelProvider.env_vars_are_set()
        return ValidationResult(
            valid=valid_model_config_exists,
            message=valid_message if valid_model_config_exists else message,
@@ -276,7 +276,7 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:

    if message == "":
        # check to see if CAII models are available via discovery
-        if CAIIModelProvider.is_enabled():
+        if CAIIModelProvider.env_vars_are_set():
            message = "CAII models are available."
            valid_model_config_exists = True
        else:
@@ -388,7 +388,7 @@ def build_configuration(
    validate_config = validate(frozenset(env.items()))

    model_provider = (
-        TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER"))
+        TypeAdapter(ModelSource).validate_python(env.get("MODEL_PROVIDER"))
        if env.get("MODEL_PROVIDER")
        else None
    )
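Aside (not part of the diff): the build_configuration change relies on Pydantic's TypeAdapter (a v2 API) validating the raw env value into the enum; a sketch, assuming app.config is importable:

from pydantic import TypeAdapter, ValidationError

from app.config import ModelSource

adapter = TypeAdapter(ModelSource)
assert adapter.validate_python("CAII") is ModelSource.CAII  # known value -> enum member

try:
    adapter.validate_python("Gemini")  # not a ModelSource value
except ValidationError:
    print("unknown provider names are rejected")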
8 changes: 4 additions & 4 deletions llm-service/app/services/models/__init__.py
@@ -37,12 +37,12 @@
#
from .embedding import Embedding
from .llm import LLM
-from .providers import ModelProvider
+from .providers import get_provider_class
from .reranking import Reranking
-from ._model_source import ModelSource
+from ...config import ModelSource

-__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"]
+__all__ = ["Embedding", "LLM", "Reranking", "get_model_source"]


def get_model_source() -> ModelSource:
-    return ModelProvider.get_provider_class().get_model_source()
+    return get_provider_class().get_model_source()
48 changes: 0 additions & 48 deletions llm-service/app/services/models/_model_source.py

This file was deleted.

6 changes: 3 additions & 3 deletions llm-service/app/services/models/embedding.py
@@ -41,7 +41,7 @@
from llama_index.core.base.embeddings.base import BaseEmbedding

from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
from ..caii.types import ModelResponse


@@ -51,15 +51,15 @@ def get(cls, model_name: Optional[str] = None) -> BaseEmbedding:
        if model_name is None:
            model_name = cls.list_available()[0].model_id

-        return ModelProvider.get_provider_class().get_embedding_model(model_name)
+        return get_provider_class().get_embedding_model(model_name)

    @staticmethod
    def get_noop() -> BaseEmbedding:
        return _noop.DummyEmbeddingModel()

    @staticmethod
    def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_embedding_models()
+        return get_provider_class().list_embedding_models()

    @classmethod
    def test(cls, model_name: str) -> str:
6 changes: 3 additions & 3 deletions llm-service/app/services/models/llm.py
@@ -41,7 +41,7 @@
from llama_index.core.base.llms.types import ChatMessage, MessageRole

from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
from ..caii.types import ModelResponse


@@ -51,15 +51,15 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM:
        if not model_name:
            model_name = cls.list_available()[0].model_id

-        return ModelProvider.get_provider_class().get_llm_model(model_name)
+        return get_provider_class().get_llm_model(model_name)

    @staticmethod
    def get_noop() -> llms.LLM:
        return _noop.DummyLlm()

    @staticmethod
    def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_llm_models()
+        return get_provider_class().list_llm_models()

    @classmethod
    def test(cls, model_name: str) -> Literal["ok"]:
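Aside (not part of the diff): with both facades routing through get_provider_class(), a hypothetical usage sketch; it assumes the corresponding provider credentials are already configured in the environment.

import os

from app.services.models import Embedding, LLM

os.environ["MODEL_PROVIDER"] = "OpenAI"  # or Azure / Bedrock / CAII

llm = LLM.get()             # first available model from the resolved provider
embedder = Embedding.get()  # the same provider class serves embeddings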
39 changes: 37 additions & 2 deletions llm-service/app/services/models/providers/__init__.py
@@ -35,16 +35,51 @@
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
+import logging

+from app.config import settings
from .azure import AzureModelProvider
from .bedrock import BedrockModelProvider
from .caii import CAIIModelProvider
from .openai import OpenAiModelProvider
-from ._model_provider import ModelProvider
+from ._model_provider import _ModelProvider
+
+logger = logging.getLogger(__name__)

__all__ = [
    "AzureModelProvider",
    "BedrockModelProvider",
    "CAIIModelProvider",
    "OpenAiModelProvider",
-    "ModelProvider",
+    "get_provider_class",
]
+
+
+def get_provider_class() -> type[_ModelProvider]:
Comment (Collaborator Author):
Revamped to loop through ModelProvider's subclasses rather than have them hard-coded.

+    """Return the ModelProvider subclass for the given provider name."""
+    model_providers: list[type[_ModelProvider]] = sorted(
+        _ModelProvider.__subclasses__(),
+        key=lambda ModelProviderSubcls: ModelProviderSubcls.get_priority(),
+    )
+
+    model_provider = settings.model_provider
+    for ModelProviderSubcls in model_providers:
+        if model_provider == ModelProviderSubcls.get_model_source():
+            logger.info(
+                'using model provider "%s" based on `MODEL_PROVIDER` env var',
+                ModelProviderSubcls.get_model_source().value,
+            )
+            return ModelProviderSubcls
+
+    # Fallback if no specific provider is set
+    for ModelProviderSubcls in model_providers:
+        if ModelProviderSubcls.env_vars_are_set():
+            logger.info(
+                'falling back to model provider "%s" based on env vars %s',
+                ModelProviderSubcls.get_model_source().value,
+                ModelProviderSubcls.get_env_var_names(),
+            )
+            return ModelProviderSubcls
+
+    logger.info('falling back to model provider "CAII"')
+    return CAIIModelProvider
Comment on lines +84 to +85 (Collaborator Author):
Do we still want this last-resort fallback, or should we error out?

Reply (Collaborator):
I think we should keep it for backwards compatibility.

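Aside (not part of the diff): a self-contained sketch, with toy classes rather than the real providers, of the fallback mechanism get_provider_class() now uses: discover subclasses via __subclasses__(), sort them by get_priority(), and return the first one whose env vars are set.

import abc
import os


class Provider(abc.ABC):
    @classmethod
    def env_vars_are_set(cls) -> bool:
        return all(map(os.environ.get, cls.get_env_var_names()))

    @staticmethod
    @abc.abstractmethod
    def get_env_var_names() -> set[str]: ...

    @staticmethod
    @abc.abstractmethod
    def get_priority() -> int: ...


class Alpha(Provider):
    @staticmethod
    def get_env_var_names() -> set[str]:
        return {"ALPHA_API_KEY"}

    @staticmethod
    def get_priority() -> int:
        return 1


class Beta(Provider):
    @staticmethod
    def get_env_var_names() -> set[str]:
        return {"BETA_API_KEY"}

    @staticmethod
    def get_priority() -> int:
        return 2


def pick_provider() -> type[Provider]:
    candidates = sorted(Provider.__subclasses__(), key=lambda p: p.get_priority())
    for cls in candidates:
        if cls.env_vars_are_set():
            return cls
    return Beta  # last-resort default, mirroring the CAII fallback above


os.environ["BETA_API_KEY"] = "token"  # only the lower-priority provider is configured
assert pick_provider() is Beta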
57 changes: 19 additions & 38 deletions llm-service/app/services/models/providers/_model_provider.py
@@ -42,50 +42,36 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor

-from app.config import settings
-from .._model_source import ModelSource
+from app.config import ModelSource
from ...caii.types import ModelResponse


-class ModelProvider(abc.ABC):
+class _ModelProvider(abc.ABC):
Comment on lines -50 to +49 (Collaborator Author):
With get_provider_class() moved out, this class should no longer need to be invoked directly except for subclassing.

    @classmethod
-    def is_enabled(cls) -> bool:
-        """Return whether this model provider is enabled, based on the presence of required env vars."""
+    def env_vars_are_set(cls) -> bool:
+        """Return whether this model provider's env vars have set values."""
Comment (Collaborator Author):
"is enabled" is a bit ambiguous between "is available for use" and "is the active model provider", since we now have the MODEL_PROVIDER env var that the method doesn't even account for.

        return all(map(os.environ.get, cls.get_env_var_names()))

    @staticmethod
-    def get_provider_class() -> type["ModelProvider"]:
Comment (Collaborator Author):
Moved to llm-service/app/services/models/providers/__init__.py since it doesn't make sense for it to be inherited.

-        """Return the ModelProvider subclass for the given provider name."""
-        from . import (
-            AzureModelProvider,
-            CAIIModelProvider,
-            OpenAiModelProvider,
-            BedrockModelProvider,
-        )
-
-        model_provider = settings.model_provider
-        if model_provider == "Azure":
-            return AzureModelProvider
-        elif model_provider == "CAII":
-            return CAIIModelProvider
-        elif model_provider == "OpenAI":
-            return OpenAiModelProvider
-        elif model_provider == "Bedrock":
-            return BedrockModelProvider
+    @abc.abstractmethod
+    def get_env_var_names() -> set[str]:
+        """Return the names of the env vars required by this model provider."""
+        raise NotImplementedError

-        # Fallback to priority order if no specific provider is set
-        if AzureModelProvider.is_enabled():
-            return AzureModelProvider
-        elif OpenAiModelProvider.is_enabled():
-            return OpenAiModelProvider
-        elif BedrockModelProvider.is_enabled():
-            return BedrockModelProvider
-        return CAIIModelProvider
+    @staticmethod
+    @abc.abstractmethod
+    def get_model_source() -> ModelSource:
+        """Return the name of this model provider"""
+        raise NotImplementedError
Comment on lines +61 to +65 (Collaborator Author):
Moved here from the bottom of the class definition to group it with the other "metadata" methods (as opposed to the "get model" methods).


    @staticmethod
    @abc.abstractmethod
-    def get_env_var_names() -> set[str]:
-        """Return the names of the env vars required by this model provider."""
+    def get_priority() -> int:
+        """Return the priority of this model provider relative to the others.
+
+        1 is the highest priority.
+
+        """
Comment on lines +69 to +74 (Collaborator Author):
I set the values for this based on the order we had in get_provider_class():

# Fallback to priority order if no specific provider is set
if AzureModelProvider.is_enabled():
    return AzureModelProvider
elif OpenAiModelProvider.is_enabled():
    return OpenAiModelProvider
elif BedrockModelProvider.is_enabled():
    return BedrockModelProvider
return CAIIModelProvider

        raise NotImplementedError

    @staticmethod
@@ -123,8 +109,3 @@ def get_embedding_model(name: str) -> BaseEmbedding:
    def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
        """Return reranking model with `name`."""
        raise NotImplementedError
-
-    @staticmethod
-    @abc.abstractmethod
-    def get_model_source() -> ModelSource:
-        raise NotImplementedError
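Aside (not part of the diff): one subtlety of the unchanged env_vars_are_set() implementation shown earlier, all(map(os.environ.get, ...)), is that an env var set to an empty string counts as unset, since the empty string is falsy:

import os

os.environ["DEMO_API_KEY"] = ""  # present but empty
assert not all(map(os.environ.get, {"DEMO_API_KEY"}))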