     Usage,
     UserMessage,
 )
-from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
+from api.setting import (
+    AWS_REGION,
+    DEBUG,
+    DEFAULT_MODEL,
+    ENABLE_CROSS_REGION_INFERENCE,
+    ENABLE_APPLICATION_INFERENCE_PROFILES,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
     Returns a model list that combines:
     - ON_DEMAND models.
     - Cross-Region Inference Profiles (if enabled via Env)
+    - Application Inference Profiles (if enabled via Env)
     """
     model_list = {}
     try:
         profile_list = []
+        app_profile_dict = {}
+
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
             profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
+        if ENABLE_APPLICATION_INFERENCE_PROFILES:
+            # List application defined inference profile IDs and create mapping
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
+
+            for profile in response["inferenceProfileSummaries"]:
+                try:
+                    profile_arn = profile.get("inferenceProfileArn")
+                    if not profile_arn:
+                        continue
+
+                    # Process all models in the profile
+                    models = profile.get("models", [])
+                    for model in models:
+                        model_arn = model.get("modelArn", "")
+                        if model_arn:
+                            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+                            if model_id:
+                                app_profile_dict[model_id] = profile_arn
+                except Exception as e:
+                    logger.warning(f"Error processing application profile: {e}")
+                    continue
+
         # List foundation models; we only care about text outputs here.
         response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
 
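A note on the mapping the new `APPLICATION` branch builds: each profile summary carries the profile ARN plus the model ARNs it fronts, and the model ID is taken as the last path segment of the model ARN. A minimal sketch of that loop in isolation, with a fabricated profile summary (all ARNs and IDs below are illustrative, not real):

```python
# Fabricated ListInferenceProfiles summary, shaped like the entries the
# loop above consumes; the ARNs are made up for illustration.
profile = {
    "inferenceProfileArn": "arn:aws:bedrock:us-west-2:111122223333:application-inference-profile/abc123",
    "models": [
        {"modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"},
    ],
}

app_profile_dict = {}
profile_arn = profile.get("inferenceProfileArn")
for model in profile.get("models", []):
    model_arn = model.get("modelArn", "")
    if model_arn:
        # The model ID is the last path segment of the model ARN.
        model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
        app_profile_dict[model_id] = profile_arn

print(app_profile_dict)
# {'anthropic.claude-3-sonnet-20240229-v1:0': 'arn:aws:bedrock:...:application-inference-profile/abc123'}
```

The next hunk then publishes the profile ARN itself as a listed model whenever the underlying foundation model appears in this mapping.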
@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
             if profile_id in profile_list:
                 model_list[profile_id] = {"modalities": input_modalities}
 
+            # Add application inference profiles
+            if model_id in app_profile_dict:
+                model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
+
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
 
@@ -162,7 +197,9 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
-                response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
+                response = await run_in_threadpool(
+                    bedrock_runtime.converse_stream, **args
+                )
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
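The `run_in_threadpool` wrapping (only reflowed in this hunk) is worth a note: boto3 clients are synchronous, so the call is pushed onto Starlette's thread pool to avoid blocking the event loop. A minimal standalone sketch of the pattern, assuming a configured `bedrock-runtime` client:

```python
import boto3
from starlette.concurrency import run_in_threadpool

bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-west-2")

async def converse(args: dict) -> dict:
    # boto3 is blocking; running it in a worker thread keeps the
    # asyncio event loop free to serve other requests.
    return await run_in_threadpool(bedrock_runtime.converse, **args)
```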
@@ -274,7 +311,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
                 if message.tool_calls:
@@ -363,7 +404,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
             # If the next role differs from the current one, flush the current role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({"role": current_role, "content": current_content})
+                    reformatted_messages.append(
+                        {"role": current_role, "content": current_content}
+                    )
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -376,7 +419,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
 
         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({"role": current_role, "content": current_content})
+            reformatted_messages.append(
+                {"role": current_role, "content": current_content}
+            )
 
         return reformatted_messages
 
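To make the reframing logic concrete: consecutive messages from the same role are merged into one multi-part message, and a role switch flushes the accumulated parts. A simplified, self-contained sketch of that behavior (not the exact helper; content parts are reduced to plain text blocks):

```python
# Toy illustration of the merge behavior described in the comments above.
def reframe(messages: list[dict]) -> list[dict]:
    reformatted, current_role, current_content = [], None, []
    for message in messages:
        if message["role"] != current_role:
            # Role switch: flush the previous role's accumulated parts.
            if current_content:
                reformatted.append({"role": current_role, "content": current_content})
            current_role, current_content = message["role"], []
        current_content.append({"text": message["content"]})
    # Flush the last role's parts.
    if current_content:
        reformatted.append({"role": current_role, "content": current_content})
    return reformatted

print(reframe([
    {"role": "user", "content": "a"},
    {"role": "user", "content": "b"},
    {"role": "assistant", "content": "c"},
]))
# [{'role': 'user', 'content': [{'text': 'a'}, {'text': 'b'}]},
#  {'role': 'assistant', 'content': [{'text': 'c'}]}]
```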
@@ -414,9 +459,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             # Use max_completion_tokens if provided.
 
             max_tokens = (
-                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+                chat_request.max_completion_tokens
+                if chat_request.max_completion_tokens
+                else chat_request.max_tokens
+            )
+            budget_tokens = self._calc_budget_tokens(
+                max_tokens, chat_request.reasoning_effort
             )
-            budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
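The reflowed conditional encodes a simple precedence rule: `max_completion_tokens` wins when set, otherwise `max_tokens` is used. A hypothetical distilled version:

```python
def effective_max_tokens(max_completion_tokens: int | None, max_tokens: int | None) -> int | None:
    # max_completion_tokens takes precedence when provided.
    return max_completion_tokens if max_completion_tokens else max_tokens

assert effective_max_tokens(4096, 1024) == 4096
assert effective_max_tokens(None, 1024) == 1024
```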
@@ -428,7 +477,9 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
 
-            if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
+            if chat_request.tool_choice and not chat_request.model.startswith(
+                "meta.llama3-1-"
+            ):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto": {}}
                     # required is mapped to {"any": {}}
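Per the comments above, a string `tool_choice` maps onto Bedrock's `toolChoice` shape. A distilled sketch of just that string mapping (the real method also handles the object form, which this hunk does not show):

```python
def map_tool_choice(tool_choice: str) -> dict:
    # "required" becomes Bedrock's {"any": {}}; "auto" (the default)
    # becomes {"auto": {}}.
    return {"any": {}} if tool_choice == "required" else {"auto": {}}

assert map_tool_choice("auto") == {"auto": {}}
assert map_tool_choice("required") == {"any": {}}
```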
@@ -477,11 +528,15 @@ def _create_response(
             message.content = ""
             for c in content:
                 if "reasoningContent" in c:
-                    message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
+                    message.reasoning_content = c["reasoningContent"][
+                        "reasoningText"
+                    ].get("text", "")
                 elif "text" in c:
                     message.content = c["text"]
                 else:
-                    logger.warning("Unknown tag in message content " + ",".join(c.keys()))
+                    logger.warning(
+                        "Unknown tag in message content " + ",".join(c.keys())
+                    )
 
         response = ChatResponse(
             id=message_id,
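For context on the reasoning-content split above: a `converse()` message content list can interleave reasoning and text blocks. The sketch below assumes only the field names actually read in this hunk:

```python
# Assumed shape of a converse() message content list, per the keys read above.
content = [
    {"reasoningContent": {"reasoningText": {"text": "step-by-step reasoning..."}}},
    {"text": "final answer"},
]

reasoning_content, text = "", ""
for c in content:
    if "reasoningContent" in c:
        reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
    elif "text" in c:
        text = c["text"]

assert reasoning_content.startswith("step-by-step")
assert text == "final answer"
```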
@@ -505,7 +560,9 @@ def _create_response(
         response.created = int(time.time())
         return response
 
-    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
+    def _create_response_stream(
+        self, model_id: str, message_id: str, chunk: dict
+    ) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.
 
         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )
 
     def _parse_content_parts(
         self,
@@ -687,7 +746,9 @@ def _convert_tool_spec(self, func: Function) -> dict:
             }
         }
 
-    def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
+    def _calc_budget_tokens(
+        self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
+    ) -> int:
         # Helper function to calculate budget_tokens based on max_tokens.
         # Ratios per effort: low - 30%, medium - 60%, high - max_tokens - 1.
         # Note that the minimum budget_tokens is currently 1,024 tokens.
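A sketch of the ratio rule those comments describe, as a hypothetical standalone function: low takes 30% of `max_tokens`, medium 60%, high `max_tokens - 1`, all clamped to the stated 1,024-token floor:

```python
def calc_budget_tokens(max_tokens: int, reasoning_effort: str) -> int:
    # Hypothetical re-statement of the documented ratios; the real
    # method lives on the model class and may differ in detail.
    if reasoning_effort == "high":
        budget = max_tokens - 1
    else:
        budget = int(max_tokens * {"low": 0.3, "medium": 0.6}[reasoning_effort])
    return max(budget, 1024)  # minimum budget_tokens is 1,024

assert calc_budget_tokens(10_000, "low") == 3_000
assert calc_budget_tokens(10_000, "high") == 9_999
assert calc_budget_tokens(2_000, "low") == 1_024  # clamped to the floor
```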
@@ -718,7 +779,9 @@ def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
                 "complete": "stop",
                 "content_filtered": "content_filter",
             }
-            return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+            return finish_reason_mapping.get(
+                finish_reason.lower(), finish_reason.lower()
+            )
         return None
 
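Only two entries of the mapping are visible in this hunk; a behavior sketch assuming just those (Bedrock stop reasons are lowercased and translated to OpenAI-style equivalents, with unknown reasons passing through lowercased):

```python
finish_reason_mapping = {
    "complete": "stop",
    "content_filtered": "content_filter",
}

def convert_finish_reason(finish_reason: str | None) -> str | None:
    if finish_reason:
        # Unknown reasons fall through as their lowercased selves.
        return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
    return None

assert convert_finish_reason("COMPLETE") == "stop"
assert convert_finish_reason("tool_use") == "tool_use"
assert convert_finish_reason(None) is None
```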
@@ -809,7 +872,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
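The reflowed Titan branch accepts either a bare string or a one-element list and rejects everything else. A standalone restatement of that normalization, for illustration:

```python
def normalize_titan_input(value: str | list) -> str:
    # Titan embeddings accept a single string; a one-element list is unwrapped.
    if isinstance(value, str):
        return value
    if isinstance(value, list) and len(value) == 1:
        return value[0]
    raise ValueError("Amazon Titan Embeddings models support only single strings as input.")

assert normalize_titan_input("hello") == "hello"
assert normalize_titan_input(["hello"]) == "hello"
```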
@@ -842,7 +912,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))