Commit c1ff680
Add support for Bedrock ARNs for regional support (#149)

* fix issue with async call deep within the mistral library
* simplify the async stuff to encapsulate in the evaluators
* get model arn by suffix
* merge main and resolve conflicts
* fix pytest
* fix mypy

Co-authored-by: jwatson <[email protected]>

1 parent fb08ae1 commit c1ff680

3 files changed: +67 −13 lines


llm-service/app/services/evaluators.py

Lines changed: 24 additions & 7 deletions
```diff
@@ -35,36 +35,53 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 # ##############################################################################
+import asyncio
 
 from llama_index.core.base.response.schema import Response
 from llama_index.core.chat_engine.types import AgentChatResponse
-from llama_index.core.evaluation import FaithfulnessEvaluator, RelevancyEvaluator
+from llama_index.core.evaluation import (
+    FaithfulnessEvaluator,
+    RelevancyEvaluator,
+    EvaluationResult,
+)
+from llama_index.core.llms import LLM
 
 from ..services import models
 
 
 def evaluate_response(
-    query: str, chat_response: AgentChatResponse, model_name: str
+    query: str, chat_response: AgentChatResponse, model_name: str
 ) -> tuple[float, float]:
     # todo: pass in the correct llm model and use it, rather than requiring querying for it like this.
     evaluator_llm = models.LLM.get(model_name)
+    return asyncio.run(_async_evaluate_response(query, chat_response, evaluator_llm))
 
-    relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm)
-    relevance = relevancy_evaluator.evaluate_response(
+
+async def _async_evaluate_response(query: str, chat_response: AgentChatResponse, evaluator_llm: LLM) -> tuple[float, float]:
+    relevance = await _evaluate_relevancy(chat_response, evaluator_llm, query)
+    faithfulness = await _evaluate_faithfulness(chat_response, evaluator_llm, query)
+    return relevance.score or 0, faithfulness.score or 0
+
+
+async def _evaluate_faithfulness(chat_response: AgentChatResponse, evaluator_llm: LLM, query: str) -> EvaluationResult:
+    faithfulness_evaluator = FaithfulnessEvaluator(llm=evaluator_llm)
+    return await faithfulness_evaluator.aevaluate_response(
         query=query,
         response=Response(
             response=chat_response.response,
             source_nodes=chat_response.source_nodes,
             metadata=chat_response.metadata,
         ),
     )
-    faithfulness_evaluator = FaithfulnessEvaluator(llm=evaluator_llm)
-    faithfulness = faithfulness_evaluator.evaluate_response(
+
+
+async def _evaluate_relevancy(chat_response: AgentChatResponse, evaluator_llm: LLM, query: str) -> EvaluationResult:
+    relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm)
+    return await relevancy_evaluator.aevaluate_response(
         query=query,
         response=Response(
             response=chat_response.response,
             source_nodes=chat_response.source_nodes,
             metadata=chat_response.metadata,
         ),
     )
-    return relevance.score or 0, faithfulness.score or 0
```
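The refactor keeps the public entry point synchronous: `evaluate_response` now wraps the async helpers in `asyncio.run`, so callers are unchanged while the two evaluations go through LlamaIndex's `aevaluate_response` API. A minimal usage sketch under assumptions: the model id, query, and the hand-built `AgentChatResponse` below are illustrative, not taken from this repo.

```python
# Hypothetical usage sketch of the synchronous facade over the async
# evaluators. Model id and response contents are illustrative.
from llama_index.core.chat_engine.types import AgentChatResponse

from app.services.evaluators import evaluate_response

chat_response = AgentChatResponse(
    response="Paris is the capital of France.",
    source_nodes=[],  # normally populated by retrieval
)

# Runs both evaluators via asyncio.run() under the hood.
relevance, faithfulness = evaluate_response(
    query="What is the capital of France?",
    chat_response=chat_response,
    model_name="meta.llama3-1-8b-instruct-v1:0",  # example Bedrock model id
)
print(relevance, faithfulness)
```

Note that `asyncio.run` creates a fresh event loop on each call, so this wrapper should only be invoked from synchronous contexts, never from inside an already-running loop.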

llm-service/app/services/models/_bedrock.py

Lines changed: 32 additions & 6 deletions
```diff
@@ -35,7 +35,10 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
-from typing import List
+import os
+from typing import List, Optional
+
+import boto3
 
 from app.services.caii.types import ModelResponse
 from ._model_provider import ModelProvider
@@ -51,20 +54,43 @@ def get_env_var_names() -> set[str]:
 
     @staticmethod
     def get_llm_models() -> List[ModelResponse]:
-        return [
+        models = [
             ModelResponse(
-                model_id=DEFAULT_BEDROCK_LLM_MODEL,
-                name="Llama3.1 8B Instruct v1",
+                model_id=DEFAULT_BEDROCK_LLM_MODEL, name="Llama3.1 8B Instruct v1"
             ),
             ModelResponse(
                 model_id="meta.llama3-1-70b-instruct-v1:0",
                 name="Llama3.1 70B Instruct v1",
             ),
             ModelResponse(
-                model_id="cohere.command-r-plus-v1:0",
-                name="Cohere Command R Plus v1",
+                model_id="cohere.command-r-plus-v1:0", name="Cohere Command R Plus v1"
             ),
         ]
+        llama323b = BedrockModelProvider._get_model_arn_by_suffix(
+            "meta.llama3-2-3b-instruct-v1:0"
+        )
+        if llama323b:
+            models.append(llama323b)
+        llama321b = BedrockModelProvider._get_model_arn_by_suffix(
+            "meta.llama3-2-1b-instruct-v1:0"
+        )
+        if llama321b:
+            models.append(llama321b)
+
+        return models
+
+    @staticmethod
+    def _get_model_arn_by_suffix(suffix: str) -> Optional[ModelResponse]:
+        default_region = os.environ.get("AWS_DEFAULT_REGION") or None
+        bedrock_client = boto3.client("bedrock", region_name=default_region)
+        profiles = bedrock_client.list_inference_profiles()["inferenceProfileSummaries"]
+        for profile in profiles:
+            if profile["inferenceProfileId"].endswith(suffix):
+                return ModelResponse(
+                    model_id=profile["inferenceProfileId"],
+                    name=profile["inferenceProfileName"],
+                )
+        return None
 
     @staticmethod
     def get_embedding_models() -> List[ModelResponse]:
```
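`_get_model_arn_by_suffix` resolves a base model id to a regional inference profile: Bedrock's cross-region profiles prefix the model id with a geography (for example `us.meta.llama3-2-3b-instruct-v1:0`), so matching by suffix finds the profile whatever the deployment region. A small sketch of what the underlying `boto3` call returns; the region and printed ids are example values, not from this commit:

```python
# Sketch: inspect the inference profiles the lookup iterates over.
# Requires AWS credentials; the region here is an example.
import boto3

client = boto3.client("bedrock", region_name="us-east-1")
response = client.list_inference_profiles()
for summary in response["inferenceProfileSummaries"]:
    # e.g. "us.meta.llama3-2-3b-instruct-v1:0" -> "US Meta Llama 3.2 3B Instruct"
    print(summary["inferenceProfileId"], "->", summary["inferenceProfileName"])
```

One caveat worth noting: `list_inference_profiles` supports pagination via `nextToken`, and the helper as written only reads the first page, which is fine while the account's profile count stays small.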

llm-service/app/tests/conftest.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -52,9 +52,11 @@
 
 from app.ai.vector_stores.qdrant import QdrantVectorStore
 from app.main import app
+from app.services.caii.types import ModelResponse
 from app.services.metadata_apis import data_sources_metadata_api
 from app.services import models
 from app.services.metadata_apis.data_sources_metadata_api import RagDataSource
+from app.services.models._bedrock import BedrockModelProvider
 
 
 @dataclass
@@ -215,3 +217,12 @@ def client() -> Iterator[TestClient]:
     """
     with TestClient(app) as test_client:
         yield test_client
+
+
+@pytest.fixture(autouse=True)
+def _get_model_arn_by_suffix(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(
+        BedrockModelProvider,
+        "_get_model_arn_by_suffix",
+        lambda name: ModelResponse(model_id=f"us.{name}", name=name),
+    )
```
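Because `_get_model_arn_by_suffix` makes a live `boto3` call, this autouse fixture swaps it for a deterministic stub in every test, returning the requested suffix prefixed with `us.`. A sketch of the kind of test this enables; the test name and assertions are illustrative, not from this commit:

```python
# Hypothetical test: with the autouse fixture active, get_llm_models()
# needs no AWS access and the stubbed model ids are predictable.
from app.services.models._bedrock import BedrockModelProvider

def test_llm_models_include_stubbed_llama32_profiles() -> None:
    model_ids = [m.model_id for m in BedrockModelProvider.get_llm_models()]
    assert "us.meta.llama3-2-3b-instruct-v1:0" in model_ids
    assert "us.meta.llama3-2-1b-instruct-v1:0" in model_ids
```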
