|
45 | 45 | from llama_index.core.base.llms.types import ChatMessage
|
46 | 46 | from llama_index.core.llms import LLM
|
47 | 47 | from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
| 48 | +from llama_index.llms.nvidia import NVIDIA |
48 | 49 |
|
49 | 50 | from .CaiiEmbeddingModel import CaiiEmbeddingModel
|
50 |
| -from .CaiiModel import CaiiModel, CaiiModelMistral, DeepseekModel |
| 51 | +from .CaiiModel import DeepseekModel |
51 | 52 | from .caii_reranking import CaiiRerankingModel
|
52 | 53 | from .types import Endpoint, ListEndpointEntry, ModelResponse
|
53 | 54 | from .utils import build_auth_headers, get_caii_access_token
|
@@ -103,43 +104,29 @@ def get_llm(
|
103 | 104 | ) -> LLM:
|
104 | 105 | endpoint = describe_endpoint(endpoint_name=endpoint_name)
|
105 | 106 | api_base = endpoint.url.removesuffix("/chat/completions")
|
106 |
| - headers = build_auth_headers() |
107 | 107 |
|
108 | 108 | model = endpoint.model_name
|
| 109 | + # todo: test if the NVIDIA impl works with deepseek, too |
109 | 110 | if "deepseek" in endpoint_name.lower():
|
110 | 111 | return DeepseekModel(
|
111 | 112 | model=model,
|
112 | 113 | context=128000,
|
113 | 114 | messages_to_prompt=messages_to_prompt,
|
114 | 115 | completion_to_prompt=completion_to_prompt,
|
115 | 116 | api_base=api_base,
|
116 |
| - default_headers=headers, |
117 |
| - ) |
118 |
| - |
119 |
| - if "mistral" in endpoint_name.lower(): |
120 |
| - return CaiiModelMistral( |
121 |
| - model=model, |
122 |
| - messages_to_prompt=messages_to_prompt, |
123 |
| - completion_to_prompt=completion_to_prompt, |
124 |
| - api_base=api_base, |
125 |
| - context=128000, |
126 |
| - default_headers=headers, |
127 |
| - ) |
128 |
| - |
129 |
| - else: |
130 |
| - return CaiiModel( |
131 |
| - model=model, |
132 |
| - context=128000, |
133 |
| - messages_to_prompt=messages_to_prompt, |
134 |
| - completion_to_prompt=completion_to_prompt, |
135 |
| - api_base=api_base, |
136 |
| - default_headers=headers, |
| 117 | + default_headers=(build_auth_headers()), |
137 | 118 | )
|
| 119 | + return NVIDIA( |
| 120 | + api_key=get_caii_access_token(), |
| 121 | + base_url=api_base, |
| 122 | + model=model |
| 123 | + ) |
138 | 124 |
|
139 | 125 |
|
def get_embedding_model(model_name: str) -> BaseEmbedding:
    """Return a CAII embedding model backed by the endpoint named *model_name*.

    The model name is used directly as the CAII endpoint name; the endpoint
    metadata is resolved via ``describe_endpoint`` and handed to
    ``CaiiEmbeddingModel``.
    """
    # todo: figure out if the Nvidia library can be made to work for embeddings as well.
    caii_endpoint = describe_endpoint(endpoint_name=model_name)
    return CaiiEmbeddingModel(endpoint=caii_endpoint)
|
144 | 131 |
|
145 | 132 |
|
|
0 commit comments