
Commit 5bf0b04

Merge pull request #152 from cloudera/mob/main
Use the NVidia library for CAII models, and tweak summarization config to account for Mistral's small context window
2 parents cc0bad3 + cf0d7a7 commit 5bf0b04

7 files changed: +49 -106 lines changed

llm-service/app/ai/indexing/summary_indexer.py

Lines changed: 7 additions & 1 deletion
@@ -47,7 +47,7 @@
     DocumentSummaryIndex,
     StorageContext,
     get_response_synthesizer,
-    load_index_from_storage,
+    load_index_from_storage, PromptHelper,
 )
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -66,6 +66,7 @@
 from .readers.base_reader import ReaderConfig, ChunksResult
 from ..vector_stores.qdrant import QdrantVectorStore
 from ...config import Settings
+from ...services.models import CAIIModelProvider

 logger = logging.getLogger(__name__)

@@ -117,13 +118,18 @@ def __index_configuration(
         data_source_id: int,
         embed_summaries: bool = True,
     ) -> Dict[str, Any]:
+        prompt_helper: Optional[PromptHelper] = None
+        # if we're using CAII, let's be conservative, and use a small context window to account for mistral's small context
+        if CAIIModelProvider.is_enabled():
+            prompt_helper=PromptHelper(context_window=3000)
         return {
             "llm": llm,
             "response_synthesizer": get_response_synthesizer(
                 response_mode=ResponseMode.TREE_SUMMARIZE,
                 llm=llm,
                 use_async=True,
                 verbose=True,
+                prompt_helper=prompt_helper
             ),
             "show_progress": True,
             "embed_model": embedding_model,

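For context, PromptHelper is what the response synthesizer uses to budget how much text goes into each LLM call, so a deliberately small context_window makes TREE_SUMMARIZE split and recurse instead of overflowing a short-context model. A minimal, self-contained sketch of the pattern this hunk wires in (MockLLM stands in for the real CAII/Bedrock model, which is resolved elsewhere in the service):

from llama_index.core import PromptHelper, get_response_synthesizer
from llama_index.core.llms import MockLLM  # stand-in for the CAII-hosted model
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.schema import NodeWithScore, TextNode

llm = MockLLM()

# A conservative window (well under Mistral's limit) forces more aggressive
# chunking before the tree-summarize passes run.
prompt_helper = PromptHelper(context_window=3000)

synthesizer = get_response_synthesizer(
    response_mode=ResponseMode.TREE_SUMMARIZE,
    llm=llm,
    use_async=True,
    verbose=True,
    prompt_helper=prompt_helper,  # None leaves the library defaults in place
)

nodes = [NodeWithScore(node=TextNode(text="Example passage to summarize."), score=1.0)]
print(synthesizer.synthesize("Summarize the passage.", nodes=nodes))

Passing prompt_helper=None for non-CAII providers keeps the default behavior, so only the CAII path is affected by the smaller window.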
llm-service/app/services/caii/CaiiModel.py

Lines changed: 0 additions & 39 deletions
@@ -38,7 +38,6 @@
 from typing import Callable, Dict, Sequence, Any

 from llama_index.core.base.llms.types import ChatMessage, LLMMetadata, ChatResponse, CompletionResponse
-from llama_index.llms.mistralai.base import MistralAI
 from llama_index.llms.openai import OpenAI
 from pydantic import Field

@@ -108,41 +107,3 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         content: str = raw_response.message.content or ""
         raw_response.message.content = content.split("</think>")[-1]
         return raw_response
-
-
-class CaiiModelMistral(MistralAI):
-    def __init__(
-        self,
-        model: str,
-        context: int,
-        api_base: str,
-        messages_to_prompt: Callable[[Sequence[ChatMessage]], str],
-        completion_to_prompt: Callable[[str], str],
-        default_headers: Dict[str, str],
-    ):
-        super().__init__(
-            api_key=default_headers.get("Authorization"),
-            model=model,
-            endpoint=api_base.removesuffix(
-                "/v1"
-            ),  # mistral expects the base url without the /v1
-            messages_to_prompt=messages_to_prompt,
-            completion_to_prompt=completion_to_prompt,
-        )
-
-    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
-        all_kwargs = super()._get_all_kwargs(**kwargs)
-        # apparently, this key is no longer acceptable to the API that is implemented by the Nvidia NIMs for mistral.
-        all_kwargs.pop('random_seed', None)
-        return all_kwargs
-
-    @property
-    def metadata(self) -> LLMMetadata:
-        ## todo: pull this info from somewhere
-        return LLMMetadata(
-            context_window=32000,  ## this is the minimum mistral context window from utils.py
-            num_output=self.max_tokens or -1,
-            is_chat_model=False,
-            is_function_calling_model=True,
-            model_name=self.model,
-        )

llm-service/app/services/caii/caii.py

Lines changed: 10 additions & 23 deletions
@@ -45,9 +45,10 @@
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.llms import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
+from llama_index.llms.nvidia import NVIDIA

 from .CaiiEmbeddingModel import CaiiEmbeddingModel
-from .CaiiModel import CaiiModel, CaiiModelMistral, DeepseekModel
+from .CaiiModel import DeepseekModel
 from .caii_reranking import CaiiRerankingModel
 from .types import Endpoint, ListEndpointEntry, ModelResponse
 from .utils import build_auth_headers, get_caii_access_token
@@ -103,43 +104,29 @@ def get_llm(
 ) -> LLM:
     endpoint = describe_endpoint(endpoint_name=endpoint_name)
     api_base = endpoint.url.removesuffix("/chat/completions")
-    headers = build_auth_headers()

     model = endpoint.model_name
+    # todo: test if the NVIDIA impl works with deepseek, too
     if "deepseek" in endpoint_name.lower():
         return DeepseekModel(
             model=model,
             context=128000,
             messages_to_prompt=messages_to_prompt,
             completion_to_prompt=completion_to_prompt,
             api_base=api_base,
-            default_headers=headers,
-        )
-
-    if "mistral" in endpoint_name.lower():
-        return CaiiModelMistral(
-            model=model,
-            messages_to_prompt=messages_to_prompt,
-            completion_to_prompt=completion_to_prompt,
-            api_base=api_base,
-            context=128000,
-            default_headers=headers,
-        )
-
-    else:
-        return CaiiModel(
-            model=model,
-            context=128000,
-            messages_to_prompt=messages_to_prompt,
-            completion_to_prompt=completion_to_prompt,
-            api_base=api_base,
-            default_headers=headers,
+            default_headers=(build_auth_headers()),
         )
+    return NVIDIA(
+        api_key=get_caii_access_token(),
+        base_url=api_base,
+        model=model
+    )


 def get_embedding_model(model_name: str) -> BaseEmbedding:
     endpoint_name = model_name
     endpoint = describe_endpoint(endpoint_name=endpoint_name)
+    # todo: figure out if the Nvidia library can be made to work for embeddings as well.
     return CaiiEmbeddingModel(endpoint=endpoint)


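The effect of the new code path is roughly the following sketch; the endpoint URL, token, and model name below are placeholders, since in caii.py they come from describe_endpoint() and get_caii_access_token():

from llama_index.llms.nvidia import NVIDIA

# Placeholder values for illustration only; the real service derives these
# from the CAII endpoint description and the CAII access token.
api_base = "https://caii.example.cloudera.site/namespaces/serving-default/endpoints/my-llm/v1"
token = "fake-caii-access-token"
model = "mistral-7b-instruct"  # whatever model name the endpoint reports

llm = NVIDIA(
    api_key=token,      # the CAII access token is passed as the API key
    base_url=api_base,  # point the client at the CAII endpoint instead of NVIDIA's hosted API
    model=model,
)

print(llm.complete("Say hello in one short sentence.").text)

Since CAII endpoints are served by NVIDIA NIMs (as the removed comment in CaiiModel.py notes), the generic NVIDIA client can replace both the CaiiModel and CaiiModelMistral wrappers for everything except the deepseek path.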
llm-service/app/services/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -65,6 +65,15 @@
 from ..llama_utils import completion_to_prompt, messages_to_prompt
 from ..query.simple_reranker import SimpleReranker

+__all__ = [
+    'CAIIModelProvider',
+    'ModelType',
+    'Embedding',
+    'LLM',
+    'Reranking',
+    'ModelSource',
+    'BedrockModelProvider'
+]

 T = TypeVar("T", bound=BaseComponent)


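The explicit __all__ declares which re-exported names form the package's public API, so consumers (like the conftest.py change below) can import them from app.services.models rather than from private modules such as _bedrock. A hypothetical consumer, tying it back to the summarizer change above:

# Hypothetical consumer; it depends only on names re-exported above.
from app.services.models import CAIIModelProvider


def summarization_context_window() -> int | None:
    # Same check summary_indexer.py uses to decide whether to shrink the window.
    if CAIIModelProvider.is_enabled():
        return 3000
    return None  # other providers keep the library defaults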
llm-service/app/tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@
 from app.services.metadata_apis import data_sources_metadata_api
 from app.services import models
 from app.services.metadata_apis.data_sources_metadata_api import RagDataSource
-from app.services.models._bedrock import BedrockModelProvider
+from app.services.models import BedrockModelProvider


 @dataclass

llm-service/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,6 @@ dependencies = [
     "llama-index-embeddings-bedrock>=0.2.1",
     "llama-index-llms-bedrock>=0.1.13",
     "llama-index-llms-openai>=0.1.31",
-    "llama-index-llms-mistralai>=0.1.20",
     "llama-index-embeddings-openai>=0.1.11",
     "llama-index-vector-stores-qdrant>=0.2.17",
     "docx2txt>=0.8",
@@ -37,6 +36,7 @@ dependencies = [
     "mlflow>=2.20.1",
     "llama-index-llms-azure-openai>=0.3.0",
     "llama-index-embeddings-azure-openai>=0.3.0",
+    "llama-index-llms-nvidia>=0.3.2",
 ]
 requires-python = ">=3.10,<=3.12"
 readme = "README.md"

llm-service/uv.lock

Lines changed: 21 additions & 41 deletions
Some generated files are not rendered by default.
