Clean up logic for determining the active model provider #309
@@ -50,7 +50,7 @@
     ChatStoreProviderType,
     VectorDbProviderType,
     MetadataDbProviderType,
-    ModelProviderType,
+    ModelSource,
 )
 from app.services.models.providers import (
     CAIIModelProvider,
@@ -136,7 +136,7 @@ class ProjectConfig(BaseModel):
     chat_store_provider: ChatStoreProviderType
     vector_db_provider: VectorDbProviderType
     metadata_db_provider: MetadataDbProviderType
-    model_provider: Optional[ModelProviderType] = None
+    model_provider: Optional[ModelSource] = None
     aws_config: AwsConfig
     azure_config: AzureConfig
     caii_config: CaiiConfig

Comment on the model_provider line:
> This still comes back as a string when the frontend calls …
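A quick illustration of the point above: if ModelSource is a string-valued enum (an assumption; its definition is not part of this diff), a Pydantic model field typed with it still serializes to a bare string, which is what the frontend receives. ExampleConfig is a hypothetical stand-in for ProjectConfig:

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel


    class ModelSource(str, Enum):  # assumed shape of the real enum in app.config
        AZURE = "Azure"
        OPENAI = "OpenAI"
        BEDROCK = "Bedrock"
        CAII = "CAII"


    class ExampleConfig(BaseModel):  # hypothetical stand-in for ProjectConfig
        model_provider: Optional[ModelSource] = None


    # The enum member round-trips through JSON as its plain string value.
    print(ExampleConfig(model_provider=ModelSource.BEDROCK).model_dump_json())
    # -> {"model_provider":"Bedrock"}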
@@ -216,13 +216,13 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
         f"Preferred provider {preferred_provider} is properly configured. \n"
     )
     if preferred_provider == "Bedrock":
-        valid_model_config_exists = BedrockModelProvider.is_enabled()
+        valid_model_config_exists = BedrockModelProvider.env_vars_are_set()
     elif preferred_provider == "Azure":
-        valid_model_config_exists = AzureModelProvider.is_enabled()
+        valid_model_config_exists = AzureModelProvider.env_vars_are_set()
     elif preferred_provider == "OpenAI":
-        valid_model_config_exists = OpenAiModelProvider.is_enabled()
+        valid_model_config_exists = OpenAiModelProvider.env_vars_are_set()
    elif preferred_provider == "CAII":
-        valid_model_config_exists = CAIIModelProvider.is_enabled()
+        valid_model_config_exists = CAIIModelProvider.env_vars_are_set()
     return ValidationResult(
         valid=valid_model_config_exists,
         message=valid_message if valid_model_config_exists else message,
@@ -276,7 +276,7 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
 
     if message == "":
         # check to see if CAII models are available via discovery
-        if CAIIModelProvider.is_enabled():
+        if CAIIModelProvider.env_vars_are_set():
             message = "CAII models are available."
             valid_model_config_exists = True
         else:
@@ -388,7 +388,7 @@ def build_configuration(
     validate_config = validate(frozenset(env.items()))
 
     model_provider = (
-        TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER"))
+        TypeAdapter(ModelSource).validate_python(env.get("MODEL_PROVIDER"))
         if env.get("MODEL_PROVIDER")
        else None
     )
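For reference, a small sketch of what the TypeAdapter call does with the raw MODEL_PROVIDER string; the two-member ModelSource enum here is a trimmed, hypothetical version of the real one:

    from enum import Enum

    from pydantic import TypeAdapter, ValidationError


    class ModelSource(str, Enum):  # trimmed, hypothetical version of the real enum
        BEDROCK = "Bedrock"
        CAII = "CAII"


    adapter = TypeAdapter(ModelSource)
    print(adapter.validate_python("Bedrock"))  # ModelSource.BEDROCK

    try:
        adapter.validate_python("NotAProvider")
    except ValidationError:
        print("unrecognized MODEL_PROVIDER values are rejected")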
This file was deleted.
@@ -35,16 +35,51 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
+import logging
+
+from app.config import settings
 from .azure import AzureModelProvider
 from .bedrock import BedrockModelProvider
 from .caii import CAIIModelProvider
 from .openai import OpenAiModelProvider
-from ._model_provider import ModelProvider
+from ._model_provider import _ModelProvider
+
+logger = logging.getLogger(__name__)
 
 __all__ = [
     "AzureModelProvider",
     "BedrockModelProvider",
     "CAIIModelProvider",
     "OpenAiModelProvider",
-    "ModelProvider",
+    "get_provider_class",
 ]
+
+
+def get_provider_class() -> type[_ModelProvider]:
+    """Return the ModelProvider subclass for the given provider name."""
+    model_providers: list[type[_ModelProvider]] = sorted(
+        _ModelProvider.__subclasses__(),
+        key=lambda ModelProviderSubcls: ModelProviderSubcls.get_priority(),
+    )
+
+    model_provider = settings.model_provider
+    for ModelProviderSubcls in model_providers:
+        if model_provider == ModelProviderSubcls.get_model_source():
+            logger.info(
+                'using model provider "%s" based on `MODEL_PROVIDER` env var',
+                ModelProviderSubcls.get_model_source().value,
+            )
+            return ModelProviderSubcls
+
+    # Fallback if no specific provider is set
+    for ModelProviderSubcls in model_providers:
+        if ModelProviderSubcls.env_vars_are_set():
+            logger.info(
+                'falling back to model provider "%s" based on env vars %s',
+                ModelProviderSubcls.get_model_source().value,
+                ModelProviderSubcls.get_env_var_names(),
+            )
+            return ModelProviderSubcls
+
+    logger.info('falling back to model provider "CAII"')
+    return CAIIModelProvider

Comment on def get_provider_class():
> Revamped to loop through …
Comment on lines +84 to +85:
> Do we still want this last-resort fallback, or should we error out?

Reply:
> I think we should keep it for backwards compatibility.
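To make the discovery mechanism concrete, here is a minimal, self-contained sketch of how a provider class gets picked up by __subclasses__() and ordered by get_priority(). The FakeBedrockProvider class, its env var names, and the priority values are illustrative assumptions; the old fallback order (Azure, OpenAI, Bedrock, then CAII) suggests priorities 1-4 in that order, but the real values are not shown in this diff:

    import abc
    import os


    class _FakeModelProvider(abc.ABC):  # stand-in for _ModelProvider
        @staticmethod
        @abc.abstractmethod
        def get_env_var_names() -> set[str]: ...

        @staticmethod
        @abc.abstractmethod
        def get_priority() -> int: ...

        @classmethod
        def env_vars_are_set(cls) -> bool:
            # Same check as in _model_provider.py: every required env var must be set and non-empty.
            return all(map(os.environ.get, cls.get_env_var_names()))


    class FakeBedrockProvider(_FakeModelProvider):  # hypothetical concrete provider
        @staticmethod
        def get_env_var_names() -> set[str]:
            return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"}  # illustrative names

        @staticmethod
        def get_priority() -> int:
            return 3  # assumed: Azure=1, OpenAI=2, Bedrock=3, CAII=4, mirroring the old fallback order


    # __subclasses__() only sees classes that have already been imported, which is why
    # providers/__init__.py imports every provider module before get_provider_class runs.
    ordered = sorted(_FakeModelProvider.__subclasses__(), key=lambda cls: cls.get_priority())
    print(ordered)                                 # [<class '__main__.FakeBedrockProvider'>]
    print(FakeBedrockProvider.env_vars_are_set())  # True only if both AWS vars are set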
@@ -42,50 +42,36 @@
 from llama_index.core.llms import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 
-from app.config import settings
-from .._model_source import ModelSource
+from app.config import ModelSource
 from ...caii.types import ModelResponse
 
 
-class ModelProvider(abc.ABC):
+class _ModelProvider(abc.ABC):
Comment on lines -50 to +49:
> With …
     @classmethod
-    def is_enabled(cls) -> bool:
-        """Return whether this model provider is enabled, based on the presence of required env vars."""
+    def env_vars_are_set(cls) -> bool:
+        """Return whether this model provider's env vars have set values."""
         return all(map(os.environ.get, cls.get_env_var_names()))

Comment:
> "is enabled" is a bit ambiguous between "is available for use" and "is the active model provider", since we now have the …
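One nuance worth flagging in the return all(...) check above: os.environ.get returns None for a missing variable and the raw string otherwise, so a variable that is set to the empty string also fails the check. A tiny standalone demonstration with a hypothetical variable name:

    import os


    def env_vars_are_set(names: set[str]) -> bool:
        # Same expression as the classmethod above: any falsy value (missing or "") fails.
        return all(map(os.environ.get, names))


    os.environ["EXAMPLE_API_KEY"] = ""            # present but empty
    print(env_vars_are_set({"EXAMPLE_API_KEY"}))  # False -- empty string is falsy
    os.environ["EXAMPLE_API_KEY"] = "secret"
    print(env_vars_are_set({"EXAMPLE_API_KEY"}))  # True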
 
     @staticmethod
-    def get_provider_class() -> type["ModelProvider"]:
-        """Return the ModelProvider subclass for the given provider name."""
-        from . import (
-            AzureModelProvider,
-            CAIIModelProvider,
-            OpenAiModelProvider,
-            BedrockModelProvider,
-        )
-
-        model_provider = settings.model_provider
-        if model_provider == "Azure":
-            return AzureModelProvider
-        elif model_provider == "CAII":
-            return CAIIModelProvider
-        elif model_provider == "OpenAI":
-            return OpenAiModelProvider
-        elif model_provider == "Bedrock":
-            return BedrockModelProvider
+    @abc.abstractmethod
+    def get_env_var_names() -> set[str]:
+        """Return the names of the env vars required by this model provider."""
+        raise NotImplementedError
 
-        # Fallback to priority order if no specific provider is set
-        if AzureModelProvider.is_enabled():
-            return AzureModelProvider
-        elif OpenAiModelProvider.is_enabled():
-            return OpenAiModelProvider
-        elif BedrockModelProvider.is_enabled():
-            return BedrockModelProvider
-        return CAIIModelProvider
+    @staticmethod
+    @abc.abstractmethod
+    def get_model_source() -> ModelSource:
+        """Return the name of this model provider"""
+        raise NotImplementedError

Comment on def get_provider_class():
> Moved to …
Comment on lines +61 to +65:
> Moved here from the bottom of the class definition to group it with the other "metadata" methods (as opposed to the "get model" methods).
 
     @staticmethod
     @abc.abstractmethod
-    def get_env_var_names() -> set[str]:
-        """Return the names of the env vars required by this model provider."""
+    def get_priority() -> int:
+        """Return the priority of this model provider relative to the others.
+
+        1 is the highest priority.
+
+        """
         raise NotImplementedError
 
     @staticmethod

Comment on lines +69 to +74:
> I set the values for this based on the order we had in:
>
>     # Fallback to priority order if no specific provider is set
>     if AzureModelProvider.is_enabled():
>         return AzureModelProvider
>     elif OpenAiModelProvider.is_enabled():
>         return OpenAiModelProvider
>     elif BedrockModelProvider.is_enabled():
>         return BedrockModelProvider
>     return CAIIModelProvider
@@ -123,8 +109,3 @@ def get_embedding_model(name: str) -> BaseEmbedding:
     def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         """Return reranking model with `name`."""
         raise NotImplementedError
-
-    @staticmethod
-    @abc.abstractmethod
-    def get_model_source() -> ModelSource:
-        raise NotImplementedError
Comment:
> This probably isn't the best place for `ModelSource` (I moved it here from `llm-service/app/services/models/` to avoid a circular import), but I wanted to replace and remove the redundant `ModelProviderType`.
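As a generic illustration of the circular-import concern (the module names below are made up, not this repo's layout): when a shared type lives in a higher-level package that the lower-level config module also needs, the two modules end up importing each other; moving the type down into the config module keeps the imports one-way. A comment-only sketch:

    # pkg/config.py   (lower-level: defines the shared ModelSource enum and settings)
    #
    # pkg/providers.py (higher-level)
    #     from pkg.config import ModelSource, settings   # one-way import, no cycle
    #
    # If ModelSource instead lived in pkg/providers.py, pkg/config.py would need
    #     from pkg.providers import ModelSource
    # while pkg/providers.py still needs
    #     from pkg.config import settings
    # and Python would hit a partially initialized module at import time.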