1
- from typing import Any , AsyncGenerator , Dict , Iterator , List , Optional
1
+ from typing import Any , Dict , List , Optional
2
2
from langchain_core .callbacks .manager import (
3
3
AsyncCallbackManagerForLLMRun ,
4
4
CallbackManagerForLLMRun
5
5
)
6
- from langchain_core . outputs import GenerationChunk
6
+ import logging
7
7
from langchain_core .language_models import LLM
8
- from pydantic import ConfigDict
8
+ from pydantic import ConfigDict , model_validator
9
9
import json
10
10
import asyncio
11
11
import boto3
12
+ from typing_extensions import Self
12
13
13
14
class AmazonQ (LLM ):
14
15
"""Amazon Q LLM wrapper.
@@ -45,6 +46,13 @@ class AmazonQ(LLM):
45
46
chat_mode : str = "RETRIEVAL_MODE"
46
47
"""AWS region name. If not provided, will be extracted from environment."""
47
48
49
+ credentials_profile_name : Optional [str ] = None
50
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
51
+ has either access keys or role information specified.
52
+ If not specified, the default credential profile or, if on an EC2 instance,
53
+ credentials from IMDS will be used.
54
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
55
+ """
48
56
model_config = ConfigDict (
49
57
extra = "forbid" ,
50
58
)
@@ -85,24 +93,15 @@ def _call(
85
93
86
94
# Prepare the request
87
95
request = {
88
- 'applicationId' : "130f4ea4-855f-4ddf-b2a5-1e40923692d4" ,
96
+ 'applicationId' : self . application_id ,
89
97
'userMessage' : prompt ,
90
98
'chatMode' :self .chat_mode ,
91
99
}
92
- if not self .conversation_id :
93
- request = {
94
- 'applicationId' : self .application_id ,
95
- 'userMessage' : prompt ,
96
- 'chatMode' :self .chat_mode ,
97
- }
98
- else :
99
- request = {
100
- 'applicationId' : self .application_id ,
101
- 'userMessage' : prompt ,
102
- 'chatMode' :self .chat_mode ,
103
- 'conversationId' :self .conversation_id ,
104
- 'parentMessageId' :self .parent_message_id ,
105
- }
100
+ if self .conversation_id :
101
+ request .update ({
102
+ 'conversationId' : self .conversation_id ,
103
+ 'parentMessageId' : self .parent_message_id ,
104
+ })
106
105
107
106
# Call Amazon Q
108
107
response = self .client .chat_sync (** request )
@@ -115,6 +114,12 @@ def _call(
115
114
raise ValueError ("Unexpected response format from Amazon Q" )
116
115
117
116
except Exception as e :
117
+ if "Prompt Length" in str (e ):
118
+ logging .info (f"Prompt Length: { len (prompt )} " )
119
+ print (f"""Prompt:
120
+ { prompt } """ )
121
+ raise ValueError (f"Error raised by Amazon Q service: { e } " )
122
+
118
123
raise ValueError (f"Error raised by Amazon Q service: { e } " )
119
124
120
125
def get_last_response (self ) -> Dict :
@@ -143,4 +148,38 @@ def _identifying_params(self) -> Dict[str, Any]:
143
148
"""Get the identifying parameters."""
144
149
return {
145
150
"region_name" : self .region_name ,
146
- }
151
+ }
152
+ @model_validator (mode = "after" )
153
+ def validate_environment (self ) -> Self :
154
+ """Dont do anything if client provided externally"""
155
+ if self .client is not None :
156
+ return self
157
+
158
+ """Validate that AWS credentials to and python package exists in environment."""
159
+ try :
160
+ import boto3
161
+
162
+ try :
163
+ if self .credentials_profile_name is not None :
164
+ session = boto3 .Session (profile_name = self .credentials_profile_name )
165
+ else :
166
+ # use default credentials
167
+ session = boto3 .Session ()
168
+
169
+ self .client = session .client (
170
+ "qbusiness" , region_name = self .region_name
171
+ )
172
+
173
+ except Exception as e :
174
+ raise ValueError (
175
+ "Could not load credentials to authenticate with AWS client. "
176
+ "Please check that credentials in the specified "
177
+ "profile name are valid."
178
+ ) from e
179
+
180
+ except ImportError :
181
+ raise ImportError (
182
+ "Could not import boto3 python package. "
183
+ "Please install it with `pip install boto3`."
184
+ )
185
+ return self
0 commit comments