-
Notifications
You must be signed in to change notification settings - Fork 8
Clean up logic for determining the active model provider #309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors the test configuration by moving the parametrization from the test class decorator to the fixture definition itself, leveraging pytest's direct fixture parametrization feature.
- Removes the
@pytest.mark.parametrize
decorator from the test class - Adds
params=ModelProvider.__subclasses__()
parameter directly to the@pytest.fixture
decorator
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
def is_enabled(cls) -> bool: | ||
"""Return whether this model provider is enabled, based on the presence of required env vars.""" | ||
def env_vars_are_set(cls) -> bool: | ||
"""Return whether this model provider's env vars have set values.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"is enabled" is a bit ambiguous between "is available for use" and "is the active model provider", since we now have the MODEL_PROVIDER
env var that the method doesn't even account for.
@staticmethod | ||
@abc.abstractmethod | ||
def get_model_source() -> ModelSource: | ||
"""Return the name of this model provider""" | ||
raise NotImplementedError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved here from the bottom of the class definition to group it with the other "metadata" methods (as opposed to the "get model" methods).
def get_priority() -> int: | ||
"""Return the priority of this model provider relative to the others. | ||
|
||
1 is the highest priority. | ||
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I set the values for this based on the order we had in get_provider_class()
:
# Fallback to priority order if no specific provider is set
if AzureModelProvider.is_enabled():
return AzureModelProvider
elif OpenAiModelProvider.is_enabled():
return OpenAiModelProvider
elif BedrockModelProvider.is_enabled():
return BedrockModelProvider
return CAIIModelProvider
return all(map(os.environ.get, cls.get_env_var_names())) | ||
|
||
@staticmethod | ||
def get_provider_class() -> type["ModelProvider"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to llm-service/app/services/models/providers/__init__.py
since it doesn't make sense for it to be inherited.
] | ||
|
||
|
||
def get_provider_class() -> type[_ModelProvider]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revamped to loop through ModelProvider
's subclasses rather than have them hard-coded.
class ModelProvider(abc.ABC): | ||
class _ModelProvider(abc.ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With get_provider_class()
moved out, this class should no longer need to be invoked directly except for subclassing.
) | ||
) | ||
|
||
|
||
@pytest.fixture() | ||
@pytest.fixture(params=_ModelProvider.__subclasses__()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've discovered that we can do this instead of the
@pytest.mark.parametrize(
"EnabledModelProvider",
ModelProvider.__subclasses__(),
indirect=True,
)
below! which is a lot simpler and easier on a test writer.
for name in get_all_env_var_names(): | ||
monkeypatch.delenv(name, raising=False) | ||
for name in ModelProviderSubcls.get_env_var_names(): | ||
monkeypatch.setenv(name, "test") | ||
for name in get_all_env_var_names() - ModelProviderSubcls.get_env_var_names(): | ||
monkeypatch.delenv(name, raising=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've realized this is probably simpler?
ModelProviderType = Literal["Azure", "CAII", "OpenAI", "Bedrock"] | ||
|
||
|
||
class ModelSource(str, Enum): | ||
AZURE = "Azure" | ||
OPENAI = "OpenAI" | ||
BEDROCK = "Bedrock" | ||
CAII = "CAII" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This probably isn't the best place for ModelSource
(I moved it here from llm-service/app/services/models/
to avoid a circular import), but I wanted to replace and remove the redundant ModelProviderType
.
@@ -136,7 +136,7 @@ class ProjectConfig(BaseModel): | |||
chat_store_provider: ChatStoreProviderType | |||
vector_db_provider: VectorDbProviderType | |||
metadata_db_provider: MetadataDbProviderType | |||
model_provider: Optional[ModelProviderType] = None | |||
model_provider: Optional[ModelSource] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still comes back as a string when the frontend calls /llm-service/amp/config
👍
logger.info('falling back to model provider "CAII"') | ||
return CAIIModelProvider |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should keep it for backwards compatibility.
Also helps lay groundwork for #307.