
Commit 19123ca

refactor: standardize InferenceRouter model handling (#2965)
1 parent: 8031141

File tree

4 files changed: +28, -38 lines

llama_stack/apis/common/errors.py

Lines changed: 10 additions & 0 deletions

@@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError):
     def __init__(self, session_name: str) -> None:
         message = f"Session '{session_name}' not found or access denied."
         super().__init__(message)
+
+
+class ModelTypeError(TypeError):
+    """raised when a model is present but not the correct type"""
+
+    def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
+        message = (
+            f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
+        )
+        super().__init__(message)
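
For illustration only (not part of the commit), a minimal sketch of how the new exception behaves; the model identifier is made up:

    from llama_stack.apis.common.errors import ModelTypeError

    try:
        raise ModelTypeError("my-embed-model", "embedding", "llm")
    except TypeError as exc:  # ModelTypeError subclasses TypeError, so this catch works
        print(exc)
        # -> Model 'my-embed-model' is of type 'embedding' rather than the expected type 'llm'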

llama_stack/core/routers/inference.py

Lines changed: 16 additions & 33 deletions

@@ -18,7 +18,7 @@
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import ModelNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,

@@ -177,6 +177,15 @@ async def _count_tokens(
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0

+    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
+        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ModelNotFoundError(model_id)
+        if model.model_type != expected_model_type:
+            raise ModelTypeError(model_id, model.model_type, expected_model_type)
+        return model
+
     async def chat_completion(
         self,
         model_id: str,
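
A hedged sketch of what the new helper does at call sites; `router` stands in for an already-initialized InferenceRouter, and the model ids are hypothetical:

    from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
    from llama_stack.apis.models import ModelType

    async def demo(router):
        # Resolves and returns the Model when the id exists and the type matches.
        model = await router._get_model("my-embed-model", ModelType.embedding)
        try:
            # Same id requested as an LLM: rejected with the structured error.
            await router._get_model("my-embed-model", ModelType.llm)
        except ModelTypeError as err:
            print(err)
        try:
            # Unknown or inaccessible id: still the familiar not-found error.
            await router._get_model("no-such-model", ModelType.llm)
        except ModelNotFoundError as err:
            print(err)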
@@ -195,11 +204,7 @@ async def chat_completion(
         )
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")

@@ -301,11 +306,7 @@ async def completion(
         logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,

@@ -355,11 +356,7 @@ async def embeddings(
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+        await self._get_model(model_id, ModelType.embedding)
         provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.embeddings(
             model_id=model_id,

@@ -395,12 +392,7 @@ async def openai_completion(
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
-
+        model_obj = await self._get_model(model, ModelType.llm)
         params = dict(
             model=model_obj.identifier,
             prompt=prompt,

@@ -476,11 +468,7 @@ async def openai_chat_completion(
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+        model_obj = await self._get_model(model, ModelType.llm)

         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface

@@ -567,12 +555,7 @@ async def openai_embeddings(
         logger.debug(
             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model '{model}' is not an embedding model")
-
+        model_obj = await self._get_model(model, ModelType.embedding)
         params = dict(
             model=model_obj.identifier,
             input=input,
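
One observable effect of the consolidation: a type mismatch that previously raised a plain ValueError now raises ModelTypeError, a TypeError subclass. A hedged pytest-style sketch of the new failure mode; the `router` fixture and model id are assumptions:

    import pytest
    from llama_stack.apis.common.errors import ModelTypeError

    async def test_chat_completion_rejects_embedding_model(router):
        # Assumes an async-aware pytest setup (e.g. pytest-asyncio);
        # "my-embed-model" stands in for any registered embedding model.
        with pytest.raises(ModelTypeError):
            await router.chat_completion(model_id="my-embed-model", messages=[])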

llama_stack/core/routing_tables/vector_dbs.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@

 from pydantic import TypeAdapter

-from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs

@@ -66,7 +66,7 @@ async def register_vector_db(
         if model is None:
             raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
-            raise ValueError(f"Model {embedding_model} is not an embedding model")
+            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
         if "embedding_dimension" not in model.metadata:
             raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
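
Vector DB registration now surfaces the same structured error. A sketch of what a caller might observe when the configured embedding model is actually an LLM; the identifiers and the exact call shape are assumptions:

    from llama_stack.apis.common.errors import ModelTypeError

    async def register(routing_table):
        try:
            # "my-llm" is hypothetical: an LLM id passed where an embedding model is required.
            await routing_table.register_vector_db(
                vector_db_id="my-docs",
                embedding_model="my-llm",
            )
        except ModelTypeError as err:
            print(err)  # rejected with the structured type error instead of a bare ValueError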

llama_stack/providers/remote/inference/ollama/ollama.py

Lines changed: 0 additions & 3 deletions

@@ -457,9 +457,6 @@ async def openai_embeddings(
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
         model_obj = await self._get_model(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model {model} is not an embedding model")
-
         if model_obj.provider_resource_id is None:
             raise ValueError(f"Model {model} has no provider_resource_id set")