40
40
Part ,
41
41
ResponseValidationError ,
42
42
Tool as VertexAITool ,
43
+ ToolConfig ,
43
44
)
44
45
except ImportError :
45
46
GenerativeModel = None
@@ -137,20 +138,17 @@ def invoke(
137
138
Returns:
138
139
LLMResponse: The response from the LLM.
139
140
"""
140
- system_message = [system_instruction ] if system_instruction is not None else []
141
- self .model = GenerativeModel (
142
- model_name = self .model_name ,
143
- system_instruction = system_message ,
144
- ** self .options ,
141
+ model = self ._get_model (
142
+ system_instruction = system_instruction ,
145
143
)
146
144
try :
147
145
if isinstance (message_history , MessageHistory ):
148
146
message_history = message_history .messages
149
- messages = self .get_messages (input , message_history )
150
- response = self . model .generate_content (messages , ** self . model_params )
151
- return LLMResponse ( content = response . text )
147
+ options = self ._get_call_params (input , message_history , tools = None )
148
+ response = model .generate_content (** options )
149
+ return self . _parse_content_response ( response )
152
150
except ResponseValidationError as e :
153
- raise LLMGenerationError (e )
151
+ raise LLMGenerationError ("Error calling VertexAILLM" ) from e
154
152
155
153
async def ainvoke (
156
154
self ,
@@ -172,65 +170,81 @@ async def ainvoke(
172
170
try :
173
171
if isinstance (message_history , MessageHistory ):
174
172
message_history = message_history .messages
175
- system_message = (
176
- [system_instruction ] if system_instruction is not None else []
177
- )
178
- self .model = GenerativeModel (
179
- model_name = self .model_name ,
180
- system_instruction = system_message ,
181
- ** self .options ,
173
+ model = self ._get_model (
174
+ system_instruction = system_instruction ,
182
175
)
183
- messages = self .get_messages (input , message_history )
184
- response = await self .model .generate_content_async (
185
- messages , ** self .model_params
186
- )
187
- return LLMResponse (content = response .text )
176
+ options = self ._get_call_params (input , message_history , tools = None )
177
+ response = await model .generate_content_async (** options )
178
+ return self ._parse_content_response (response )
188
179
except ResponseValidationError as e :
189
- raise LLMGenerationError (e )
190
-
191
- def _to_vertexai_tool (self , tool : Tool ) -> VertexAITool :
192
- return VertexAITool (
193
- function_declarations = [
194
- FunctionDeclaration (
195
- name = tool .get_name (),
196
- description = tool .get_description (),
197
- parameters = tool .get_parameters (exclude = ["additional_properties" ]),
198
- )
199
- ]
180
+ raise LLMGenerationError ("Error calling VertexAILLM" ) from e
181
+
182
+ def _to_vertexai_function_declaration (self , tool : Tool ) -> FunctionDeclaration :
183
+ return FunctionDeclaration (
184
+ name = tool .get_name (),
185
+ description = tool .get_description (),
186
+ parameters = tool .get_parameters (exclude = ["additional_properties" ]),
200
187
)
201
188
202
189
def _get_llm_tools (
203
190
self , tools : Optional [Sequence [Tool ]]
204
191
) -> Optional [list [VertexAITool ]]:
205
192
if not tools :
206
193
return None
207
- return [self ._to_vertexai_tool (tool ) for tool in tools ]
194
+ return [
195
+ VertexAITool (
196
+ function_declarations = [
197
+ self ._to_vertexai_function_declaration (tool ) for tool in tools
198
+ ]
199
+ )
200
+ ]
208
201
209
202
def _get_model (
210
203
self ,
211
204
system_instruction : Optional [str ] = None ,
212
- tools : Optional [Sequence [Tool ]] = None ,
213
205
) -> GenerativeModel :
214
206
system_message = [system_instruction ] if system_instruction is not None else []
215
- vertex_ai_tools = self ._get_llm_tools (tools )
216
207
model = GenerativeModel (
217
208
model_name = self .model_name ,
218
209
system_instruction = system_message ,
219
- tools = vertex_ai_tools ,
220
- ** self .options ,
221
210
)
222
211
return model
223
212
213
+ def _get_call_params (
214
+ self ,
215
+ input : str ,
216
+ message_history : Optional [Union [List [LLMMessage ], MessageHistory ]],
217
+ tools : Optional [Sequence [Tool ]],
218
+ ) -> dict [str , Any ]:
219
+ options = dict (self .options )
220
+ if tools :
221
+ # we want a tool back, remove generation_config if defined
222
+ options .pop ("generation_config" , None )
223
+ options ["tools" ] = self ._get_llm_tools (tools )
224
+ if "tool_config" not in options :
225
+ options ["tool_config" ] = ToolConfig (
226
+ function_calling_config = ToolConfig .FunctionCallingConfig (
227
+ mode = ToolConfig .FunctionCallingConfig .Mode .ANY ,
228
+ )
229
+ )
230
+ else :
231
+ # no tools, remove tool_config if defined
232
+ options .pop ("tool_config" , None )
233
+
234
+ messages = self .get_messages (input , message_history )
235
+ options ["contents" ] = messages
236
+ return options
237
+
224
238
async def _acall_llm (
225
239
self ,
226
240
input : str ,
227
241
message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
228
242
system_instruction : Optional [str ] = None ,
229
243
tools : Optional [Sequence [Tool ]] = None ,
230
244
) -> GenerationResponse :
231
- model = self ._get_model (system_instruction = system_instruction , tools = tools )
232
- messages = self .get_messages (input , message_history )
233
- response = await model .generate_content_async (messages , ** self . model_params )
245
+ model = self ._get_model (system_instruction = system_instruction )
246
+ options = self ._get_call_params (input , message_history , tools )
247
+ response = await model .generate_content_async (** options )
234
248
return response
235
249
236
250
def _call_llm (
@@ -240,9 +254,9 @@ def _call_llm(
240
254
system_instruction : Optional [str ] = None ,
241
255
tools : Optional [Sequence [Tool ]] = None ,
242
256
) -> GenerationResponse :
243
- model = self ._get_model (system_instruction = system_instruction , tools = tools )
244
- messages = self .get_messages (input , message_history )
245
- response = model .generate_content (messages , ** self . model_params )
257
+ model = self ._get_model (system_instruction = system_instruction )
258
+ options = self ._get_call_params (input , message_history , tools )
259
+ response = model .generate_content (** options )
246
260
return response
247
261
248
262
def _to_tool_call (self , function_call : FunctionCall ) -> ToolCall :
0 commit comments