Skip to content

Commit 40c8306

Browse files
committed
Made changes in response to comments on langchain-ai#30
1 parent c9785c1 commit 40c8306

File tree

1 file changed

+58
-19
lines changed

1 file changed

+58
-19
lines changed

libs/aws/langchain_aws/llms/q_business.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional
1+
from typing import Any, Dict, List, Optional
22
from langchain_core.callbacks.manager import (
33
AsyncCallbackManagerForLLMRun,
44
CallbackManagerForLLMRun
55
)
6-
from langchain_core.outputs import GenerationChunk
6+
import logging
77
from langchain_core.language_models import LLM
8-
from pydantic import ConfigDict
8+
from pydantic import ConfigDict, model_validator
99
import json
1010
import asyncio
1111
import boto3
12+
from typing_extensions import Self
1213

1314
class AmazonQ(LLM):
1415
"""Amazon Q LLM wrapper.
@@ -45,6 +46,13 @@ class AmazonQ(LLM):
4546
chat_mode: str = "RETRIEVAL_MODE"
4647
"""AWS region name. If not provided, will be extracted from environment."""
4748

49+
credentials_profile_name: Optional[str] = None
50+
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
51+
has either access keys or role information specified.
52+
If not specified, the default credential profile or, if on an EC2 instance,
53+
credentials from IMDS will be used.
54+
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
55+
"""
4856
model_config = ConfigDict(
4957
extra="forbid",
5058
)
@@ -85,24 +93,15 @@ def _call(
8593

8694
# Prepare the request
8795
request = {
88-
'applicationId': "130f4ea4-855f-4ddf-b2a5-1e40923692d4",
96+
'applicationId': self.application_id,
8997
'userMessage': prompt,
9098
'chatMode':self.chat_mode,
9199
}
92-
if not self.conversation_id:
93-
request = {
94-
'applicationId': self.application_id,
95-
'userMessage': prompt,
96-
'chatMode':self.chat_mode,
97-
}
98-
else:
99-
request = {
100-
'applicationId': self.application_id,
101-
'userMessage': prompt,
102-
'chatMode':self.chat_mode,
103-
'conversationId':self.conversation_id,
104-
'parentMessageId':self.parent_message_id,
105-
}
100+
if self.conversation_id:
101+
request.update({
102+
'conversationId': self.conversation_id,
103+
'parentMessageId': self.parent_message_id,
104+
})
106105

107106
# Call Amazon Q
108107
response = self.client.chat_sync(**request)
@@ -115,6 +114,12 @@ def _call(
115114
raise ValueError("Unexpected response format from Amazon Q")
116115

117116
except Exception as e:
117+
if "Prompt Length" in str(e):
118+
logging.info(f"Prompt Length: {len(prompt)}")
119+
print(f"""Prompt:
120+
{prompt}""")
121+
raise ValueError(f"Error raised by Amazon Q service: {e}")
122+
118123
raise ValueError(f"Error raised by Amazon Q service: {e}")
119124

120125
def get_last_response(self) -> Dict:
@@ -143,4 +148,38 @@ def _identifying_params(self) -> Dict[str, Any]:
143148
"""Get the identifying parameters."""
144149
return {
145150
"region_name": self.region_name,
146-
}
151+
}
152+
@model_validator(mode="after")
153+
def validate_environment(self) -> Self:
154+
"""Dont do anything if client provided externally"""
155+
if self.client is not None:
156+
return self
157+
158+
"""Validate that AWS credentials to and python package exists in environment."""
159+
try:
160+
import boto3
161+
162+
try:
163+
if self.credentials_profile_name is not None:
164+
session = boto3.Session(profile_name=self.credentials_profile_name)
165+
else:
166+
# use default credentials
167+
session = boto3.Session()
168+
169+
self.client = session.client(
170+
"qbusiness", region_name=self.region_name
171+
)
172+
173+
except Exception as e:
174+
raise ValueError(
175+
"Could not load credentials to authenticate with AWS client. "
176+
"Please check that credentials in the specified "
177+
"profile name are valid."
178+
) from e
179+
180+
except ImportError:
181+
raise ImportError(
182+
"Could not import boto3 python package. "
183+
"Please install it with `pip install boto3`."
184+
)
185+
return self

0 commit comments

Comments
 (0)