@@ -75,7 +75,8 @@
 from ..vector_stores.vector_store_factory import VectorStoreFactory
 from ...config import settings
 from ...services.metadata_apis import data_sources_metadata_api
-from ...services.models.providers import CAIIModelProvider, AzureModelProvider, OpenAiModelProvider
+from ...services.models.providers import ModelProvider
+from ...services.models import ModelSource
 
 logger = logging.getLogger(__name__)
 
@@ -132,11 +133,16 @@ def __index_configuration(
     embed_summaries: bool = True,
 ) -> Dict[str, Any]:
     prompt_helper: Optional[PromptHelper] = None
-    # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
-    if CAIIModelProvider.is_enabled():
+    model_source: ModelSource = (
+        ModelProvider.get_provider_class().get_model_source()
+    )
+    if model_source == "CAII":
+        # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
         prompt_helper = PromptHelper(context_window=3000)
-    if AzureModelProvider.is_enabled() or OpenAiModelProvider.is_enabled():
-        prompt_helper = PromptHelper(context_window=min(llm.metadata.context_window, 10000))
+    if model_source == "Azure" or model_source == "OpenAI":
+        prompt_helper = PromptHelper(
+            context_window=min(llm.metadata.context_window, 10000)
+        )
     return {
         "llm": llm,
         "response_synthesizer": get_response_synthesizer(
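Note: the `model_source == "CAII"` comparisons above only hold if `ModelSource` compares equal to plain strings. A minimal sketch of one shape that satisfies this, assuming a `str`-backed enum (the real definition in `...services.models` may differ, e.g. it could be a `typing.Literal`):

from enum import Enum


class ModelSource(str, Enum):
    """Sketch of the assumed ModelSource type; values inferred from the diff."""

    CAII = "CAII"
    AZURE = "Azure"
    OPENAI = "OPENAI".replace("PENAI", "penAI")  # i.e. "OpenAI"


# Because ModelSource subclasses str, equality with plain strings holds,
# which is what the branches in __index_configuration rely on:
assert ModelSource.CAII == "CAII"
assert ModelSource.AZURE == "Azure" or ModelSource.AZURE == "OpenAI"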