     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import ModelNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -177,6 +177,15 @@ async def _count_tokens(
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
+    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
+        """Look up a model by id, ensuring it is accessible and of the expected type."""
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ModelNotFoundError(model_id)
+        if model.model_type != expected_model_type:
+            raise ModelTypeError(model_id, model.model_type, expected_model_type)
+        return model
+
     async def chat_completion(
         self,
         model_id: str,
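The new helper raises `ModelTypeError`, which this diff imports but never defines. A minimal sketch of what such an exception could look like, assuming it follows `ModelNotFoundError` in subclassing `ValueError` (the real definition lives in `llama_stack.apis.common.errors` and may be worded differently):

```python
# Hypothetical sketch of ModelTypeError; the actual class in
# llama_stack.apis.common.errors may differ in wording and base class.
class ModelTypeError(ValueError):
    """Raised when a model has the wrong type for the requested operation."""

    def __init__(self, model_id: str, model_type: str, expected_model_type: str) -> None:
        super().__init__(
            f"Model '{model_id}' has type '{model_type}', but this operation "
            f"requires a '{expected_model_type}' model"
        )
```

Keeping `ValueError` as the base would preserve compatibility with callers that still catch the `ValueError` raised by the old inline checks.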
@@ -195,11 +204,7 @@ async def chat_completion(
         )
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
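One subtlety in the call above: `_get_model` annotates `expected_model_type` as `str`, yet callers pass enum members such as `ModelType.llm`. That only works if `ModelType` is a string-valued enum, which its usage here suggests. An illustrative stand-in (the member values are assumptions, not taken from the diff):

```python
from enum import Enum

# Stand-in for llama_stack.apis.models.ModelType with assumed member values.
# A str-subclassing enum compares equal to plain strings, so _get_model's
# `model.model_type != expected_model_type` check behaves the same whether
# callers pass the enum member or its raw string value.
class ModelType(str, Enum):
    llm = "llm"
    embedding = "embedding"

assert ModelType.llm == "llm"
assert ModelType.embedding != ModelType.llm
```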
@@ -301,11 +306,7 @@ async def completion(
         logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -355,11 +356,7 @@ async def embeddings(
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+        await self._get_model(model_id, ModelType.embedding)
         provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.embeddings(
             model_id=model_id,
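Note that `embeddings` discards the helper's return value: only the validation side effect is needed here, whereas the OpenAI-compatible methods below keep the result so they can forward `model_obj.identifier` to the provider.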
@@ -395,12 +392,7 @@ async def openai_completion(
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
-
+        model_obj = await self._get_model(model, ModelType.llm)
         params = dict(
             model=model_obj.identifier,
             prompt=prompt,
@@ -476,11 +468,7 @@ async def openai_chat_completion(
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+        model_obj = await self._get_model(model, ModelType.llm)
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -567,12 +555,7 @@ async def openai_embeddings(
         logger.debug(
             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model '{model}' is not an embedding model")
-
+        model_obj = await self._get_model(model, ModelType.embedding)
         params = dict(
             model=model_obj.identifier,
             input=input,
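With every inference path funneled through `_get_model`, a mismatched model now fails uniformly. A hypothetical pytest-style sketch of the expected behavior; the `router` fixture, the model ids, and the use of `pytest-asyncio` are all illustrative assumptions, not part of this PR:

```python
import pytest

from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_get_model_validation(router):
    # Unknown ids surface as ModelNotFoundError on every path.
    with pytest.raises(ModelNotFoundError):
        await router._get_model("no-such-model", "llm")

    # Requesting an LLM operation with an embedding model is now a typed
    # ModelTypeError instead of the previous ad-hoc ValueError messages.
    with pytest.raises(ModelTypeError):
        await router._get_model("my-embedding-model", "llm")
```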