Commit 0183608

theTechie and zxkane authored
feat: add support to include application inference profiles as models (#131)
--------- Co-authored-by: Mengxin Zhu <[email protected]>
1 parent dd191d7 commit 0183608

5 files changed (+139 additions, -19 deletions)

README.md

Lines changed: 42 additions & 0 deletions
@@ -26,6 +26,7 @@ If you find this GitHub repository useful, please consider giving it a free star
 - [x] Support Embedding API
 - [x] Support Multimodal API
 - [x] Support Cross-Region Inference
+- [x] Support Application Inference Profiles (**new**)
 - [x] Support Reasoning (**new**)

 Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -148,7 +149,48 @@ print(completion.choices[0].message.content)

 Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call.

+### Application Inference Profiles

+This proxy now supports **Application Inference Profiles**, which allow you to track usage and costs for your model invocations. You can use application inference profiles created in your AWS account for cost tracking and monitoring purposes.
+
+**Using Application Inference Profiles:**
+
+```bash
+# Use an application inference profile ARN as the model ID
+curl $OPENAI_BASE_URL/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $OPENAI_API_KEY" \
+  -d '{
+    "model": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello!"
+      }
+    ]
+  }'
+```
+
+**SDK Usage with Application Inference Profiles:**
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+completion = client.chat.completions.create(
+    model="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+
+print(completion.choices[0].message.content)
+```
+
+**Benefits of Application Inference Profiles:**
+- **Cost Tracking**: Track usage and costs for specific applications or use cases
+- **Usage Monitoring**: Monitor model invocation metrics through CloudWatch
+- **Tag-based Cost Allocation**: Use AWS cost allocation tags for detailed billing analysis
+
+For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).

 ## Other Examples
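Once deployed with the feature enabled, the profiles should also surface through the proxy's OpenAI-compatible model listing, since the `bedrock.py` change below adds each profile to the model list keyed by its full ARN. A minimal sketch to verify, assuming `OPENAI_BASE_URL` and `OPENAI_API_KEY` point at the deployed proxy:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_BASE_URL / OPENAI_API_KEY from the environment

# Application inference profiles are listed by their full ARN,
# alongside foundation model IDs and cross-region profile IDs.
for model in client.models.list():
    if "application-inference-profile" in model.id:
        print(model.id)
```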

deployment/BedrockProxy.template

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,7 @@ Resources:
           Resource:
             - arn:aws:bedrock:*::foundation-model/*
             - arn:aws:bedrock:*:*:inference-profile/*
+            - arn:aws:bedrock:*:*:application-inference-profile/*
         - Action:
             - secretsmanager:GetSecretValue
             - secretsmanager:DescribeSecret
@@ -185,6 +186,7 @@ Resources:
             Ref: DefaultModelId
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
+          ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
       MemorySize: 1024
       PackageType: Image
       Role:

deployment/BedrockProxyFargate.template

Lines changed: 3 additions & 0 deletions
@@ -193,6 +193,7 @@ Resources:
           Resource:
             - arn:aws:bedrock:*::foundation-model/*
             - arn:aws:bedrock:*:*:inference-profile/*
+            - arn:aws:bedrock:*:*:application-inference-profile/*
         Version: "2012-10-17"
       PolicyName: ProxyTaskRoleDefaultPolicy933321B8
       Roles:
@@ -222,6 +223,8 @@ Resources:
               Value: cohere.embed-multilingual-v3
             - Name: ENABLE_CROSS_REGION_INFERENCE
               Value: "true"
+            - Name: ENABLE_APPLICATION_INFERENCE_PROFILES
+              Value: "true"
           Essential: true
           Image:
             Fn::Join:
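Both templates make the same two changes: the function/task role may now invoke `application-inference-profile` ARNs, and the feature flag defaults to on. The IAM addition matters because the proxy passes the profile ARN straight through as the Bedrock `modelId`, so that ARN is the resource IAM evaluates. A minimal boto3 sketch of the underlying call (region, account, and profile ID are placeholders):

```python
import boto3

bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-west-2")

# The profile ARN goes where a model ID normally would; IAM authorizes the
# call against this ARN, hence the extra Resource entry in both policies.
response = bedrock_runtime.converse(
    modelId="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
    messages=[{"role": "user", "content": [{"text": "Hello!"}]}],
)
print(response["output"]["message"]["content"][0]["text"])
```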

src/api/models/bedrock.py

Lines changed: 91 additions & 19 deletions
@@ -38,7 +38,13 @@
     Usage,
     UserMessage,
 )
-from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
+from api.setting import (
+    AWS_REGION,
+    DEBUG,
+    DEFAULT_MODEL,
+    ENABLE_CROSS_REGION_INFERENCE,
+    ENABLE_APPLICATION_INFERENCE_PROFILES,
+)

 logger = logging.getLogger(__name__)

@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
     Returns a model list combines:
     - ON_DEMAND models.
     - Cross-Region Inference Profiles (if enabled via Env)
+    - Application Inference Profiles (if enabled via Env)
     """
     model_list = {}
     try:
         profile_list = []
+        app_profile_dict = {}
+
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
             response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
             profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]

+        if ENABLE_APPLICATION_INFERENCE_PROFILES:
+            # List application defined inference profile IDs and create mapping
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
+
+            for profile in response["inferenceProfileSummaries"]:
+                try:
+                    profile_arn = profile.get("inferenceProfileArn")
+                    if not profile_arn:
+                        continue
+
+                    # Process all models in the profile
+                    models = profile.get("models", [])
+                    for model in models:
+                        model_arn = model.get("modelArn", "")
+                        if model_arn:
+                            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+                            if model_id:
+                                app_profile_dict[model_id] = profile_arn
+                except Exception as e:
+                    logger.warning(f"Error processing application profile: {e}")
+                    continue
+
         # List foundation models, only cares about text outputs here.
         response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
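To make the new mapping concrete, here is a worked sketch of what the loop above builds from one hypothetical `inferenceProfileSummaries` entry (the ARNs are placeholders; field names follow the Bedrock `ListInferenceProfiles` response shape):

```python
profile = {
    "inferenceProfileArn": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/abc123",
    "models": [
        {"modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"}
    ],
}

app_profile_dict = {}
for model in profile["models"]:
    # Keep only the trailing model ID, e.g. "anthropic.claude-3-sonnet-20240229-v1:0"
    model_id = model["modelArn"].split("/")[-1]
    app_profile_dict[model_id] = profile["inferenceProfileArn"]

print(app_profile_dict)
# {'anthropic.claude-3-sonnet-20240229-v1:0': 'arn:aws:bedrock:...:application-inference-profile/abc123'}
```

The next hunk then consults this dict while iterating over foundation models, so each matching profile gets listed as a model under its full ARN.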

@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
             if profile_id in profile_list:
                 model_list[profile_id] = {"modalities": input_modalities}

+            # Add application inference profiles
+            if model_id in app_profile_dict:
+                model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
+
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")

@@ -162,7 +197,9 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
-                response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
+                response = await run_in_threadpool(
+                    bedrock_runtime.converse_stream, **args
+                )
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
@@ -274,7 +311,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
                 if message.tool_calls:
@@ -363,7 +404,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
             # If the next role is different from the previous message, add the previous role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({"role": current_role, "content": current_content})
+                    reformatted_messages.append(
+                        {"role": current_role, "content": current_content}
+                    )
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -376,7 +419,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:

         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({"role": current_role, "content": current_content})
+            reformatted_messages.append(
+                {"role": current_role, "content": current_content}
+            )

         return reformatted_messages

@@ -414,9 +459,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             # Use max_completion_tokens if provided.

             max_tokens = (
-                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+                chat_request.max_completion_tokens
+                if chat_request.max_completion_tokens
+                else chat_request.max_tokens
+            )
+            budget_tokens = self._calc_budget_tokens(
+                max_tokens, chat_request.reasoning_effort
             )
-            budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
@@ -428,7 +477,9 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}

-            if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
+            if chat_request.tool_choice and not chat_request.model.startswith(
+                "meta.llama3-1-"
+            ):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto" : {}}
                     # required is mapped to {"any" : {}}
@@ -477,11 +528,15 @@ def _create_response(
         message.content = ""
         for c in content:
             if "reasoningContent" in c:
-                message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
+                message.reasoning_content = c["reasoningContent"][
+                    "reasoningText"
+                ].get("text", "")
             elif "text" in c:
                 message.content = c["text"]
             else:
-                logger.warning("Unknown tag in message content " + ",".join(c.keys()))
+                logger.warning(
+                    "Unknown tag in message content " + ",".join(c.keys())
+                )

         response = ChatResponse(
             id=message_id,
@@ -505,7 +560,9 @@ def _create_response(
         response.created = int(time.time())
         return response

-    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
+    def _create_response_stream(
+        self, model_id: str, message_id: str, chunk: dict
+    ) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.

         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )

     def _parse_content_parts(
         self,
@@ -687,7 +746,9 @@ def _convert_tool_spec(self, func: Function) -> dict:
             }
         }

-    def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
+    def _calc_budget_tokens(
+        self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
+    ) -> int:
         # Helper function to calculate budget_tokens based on the max_tokens.
         # Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
         # Note that The minimum budget_tokens is 1,024 tokens so far.
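The budget calculation itself is unchanged by this commit (only the signature is rewrapped), but the comments pin down the rule. A sketch of that rule, as an illustration rather than the actual implementation:

```python
from typing import Literal

def calc_budget_tokens_sketch(
    max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
) -> int:
    # Ratios from the comments: low -> 30%, medium -> 60%, high -> max_tokens - 1.
    if reasoning_effort == "low":
        budget = int(max_tokens * 0.3)
    elif reasoning_effort == "medium":
        budget = int(max_tokens * 0.6)
    else:
        budget = max_tokens - 1
    # Per the comment, the minimum budget_tokens is 1,024 tokens so far.
    return max(budget, 1024)
```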
@@ -718,7 +779,9 @@ def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
             "complete": "stop",
             "content_filtered": "content_filter",
         }
-        return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+        return finish_reason_mapping.get(
+            finish_reason.lower(), finish_reason.lower()
+        )
         return None


@@ -809,7 +872,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
@@ -842,7 +912,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))

src/api/setting.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
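Note the toggle's semantics: the feature stays enabled unless the variable is set to exactly `false` (case-insensitive), so values like `1`, `yes`, or an unset variable all leave it on. A quick sketch mirroring the expression above:

```python
import os

# Mirrors the setting above: anything other than "false" (case-insensitive)
# leaves application inference profiles enabled.
for value in ("true", "false", "FALSE", "yes", "0"):
    os.environ["ENABLE_APPLICATION_INFERENCE_PROFILES"] = value
    enabled = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
    print(value, "->", enabled)  # only "false"/"FALSE" disable it
```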
