Skip to content

Commit dffd484

Browse files
authored
Fix VertexAILLM (#342)
* Fix VertexAILLM
* Ruff
* Remove full generation config in case of tool calling
* Ruff
* Update example
* Update example
* Review comment part 1
* mypy
* Better deal with call options
* Ruff
* mypy
* Remove print
1 parent 9a2ee95 commit dffd484

File tree

4 files changed: +145 additions, -78 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
### Fixed
1212

1313
- Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
14+
- Fixed a bug where `VertexAILLM.(a)invoke_with_tools` called with multiple tools would raise an error.
1415

1516
### Changed
1617

examples/customize/llms/vertexai_tool_calls.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import asyncio
7+
from typing import Optional
78

89
from dotenv import load_dotenv
910
from vertexai.generative_models import GenerationConfig
@@ -17,7 +18,7 @@
1718

1819

1920
# Create a custom Tool implementation for person info extraction
20-
parameters = ObjectParameter(
21+
person_tool_parameters = ObjectParameter(
2122
description="Parameters for extracting person information",
2223
properties={
2324
"name": StringParameter(description="The person's full name"),
@@ -29,20 +30,50 @@
2930
)
3031

3132

32-
def run_tool(name: str, age: int, occupation: str) -> str:
33+
def run_person_tool(
    name: str, age: Optional[int] = None, occupation: Optional[str] = None
) -> str:
    """Summarize extracted person attributes as a single sentence.

    Args:
        name: The person's full name (required).
        age: The person's age, if known.
        occupation: The person's occupation, if known.

    Returns:
        A one-line human-readable summary of the person.
    """
    template = "Found person {} with age {} and occupation {}"
    return template.format(name, age, occupation)
3538

3639

3740
person_info_tool = Tool(
3841
name="extract_person_info",
3942
description="Extract information about a person from text",
40-
parameters=parameters,
41-
execute_func=run_tool,
43+
parameters=person_tool_parameters,
44+
execute_func=run_person_tool,
45+
)
46+
47+
company_tool_parameters = ObjectParameter(
48+
description="Parameters for extracting company information",
49+
properties={
50+
"name": StringParameter(description="The company's full name"),
51+
"industry": StringParameter(description="The company's industry"),
52+
"creation_year": IntegerParameter(description="The company's creation year"),
53+
},
54+
required_properties=["name"],
55+
additional_properties=False,
56+
)
57+
58+
59+
def run_company_tool(
    name: str, industry: Optional[str] = None, creation_year: Optional[int] = None
) -> str:
    """Summarize extracted company attributes as a single sentence.

    Args:
        name: The company's full name (required).
        industry: The company's industry, if known.
        creation_year: The company's creation year, if known.

    Returns:
        A one-line human-readable summary of the company.
    """
    template = "Found company {} operating in industry {} since {}"
    return template.format(name, industry, creation_year)
66+
67+
68+
company_info_tool = Tool(
69+
name="extract_company_info",
70+
description="Extract information about a company from text",
71+
parameters=company_tool_parameters,
72+
execute_func=run_company_tool,
4273
)
4374

4475
# Create the tool instance
45-
TOOLS = [person_info_tool]
76+
TOOLS = [person_info_tool, company_info_tool]
4677

4778

4879
def process_tool_call(response: ToolCallResponse) -> str:
@@ -54,32 +85,42 @@ def process_tool_call(response: ToolCallResponse) -> str:
5485
print(f"\nTool called: {tool_call.name}")
5586
print(f"Arguments: {tool_call.arguments}")
5687
print(f"Additional content: {response.content or 'None'}")
57-
return person_info_tool.execute(**tool_call.arguments) # type: ignore[no-any-return]
88+
if tool_call.name == "extract_person_info":
89+
return person_info_tool.execute(**tool_call.arguments) # type: ignore[no-any-return]
90+
elif tool_call.name == "extract_company_info":
91+
return str(company_info_tool.execute(**tool_call.arguments))
92+
else:
93+
raise ValueError("Unknown tool call")
5894

5995

6096
async def main() -> None:
6197
# Initialize the VertexAI LLM
6298
generation_config = GenerationConfig(temperature=0.0)
6399
llm = VertexAILLM(
64-
model_name="gemini-1.5-flash-001",
100+
model_name="gemini-2.0-flash-001",
65101
generation_config=generation_config,
102+
# tool_config=ToolConfig(
103+
# function_calling_config=ToolConfig.FunctionCallingConfig(
104+
# mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
105+
# # allowed_function_names=["extract_person_info"],
106+
# ))
66107
)
67108

68-
# Example text containing information about a person
69-
text = "Stella Hane is a 35-year-old software engineer who loves coding."
109+
# Example text containing information about a company
110+
text1 = "Neo4j is a software company created in 2007"
70111

71112
print("\n=== Synchronous Tool Call ===")
72113
# Make a synchronous tool call
73114
sync_response = llm.invoke_with_tools(
74-
input=f"Extract information about the person from this text: {text}",
115+
input=f"Extract information about the person from this text: {text1}",
75116
tools=TOOLS,
76117
)
77118
sync_result = process_tool_call(sync_response)
78119
print("\n=== Synchronous Tool Call Result ===")
79120
print(sync_result)
80121

81122
print("\n=== Asynchronous Tool Call ===")
82-
# Make an asynchronous tool call with a different text
123+
# Make an asynchronous tool call with a different text about a person
83124
text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
84125
async_response = await llm.ainvoke_with_tools(
85126
input=f"Extract information about the person from this text: {text2}",

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
Part,
4141
ResponseValidationError,
4242
Tool as VertexAITool,
43+
ToolConfig,
4344
)
4445
except ImportError:
4546
GenerativeModel = None
@@ -137,20 +138,17 @@ def invoke(
137138
Returns:
138139
LLMResponse: The response from the LLM.
139140
"""
140-
system_message = [system_instruction] if system_instruction is not None else []
141-
self.model = GenerativeModel(
142-
model_name=self.model_name,
143-
system_instruction=system_message,
144-
**self.options,
141+
model = self._get_model(
142+
system_instruction=system_instruction,
145143
)
146144
try:
147145
if isinstance(message_history, MessageHistory):
148146
message_history = message_history.messages
149-
messages = self.get_messages(input, message_history)
150-
response = self.model.generate_content(messages, **self.model_params)
151-
return LLMResponse(content=response.text)
147+
options = self._get_call_params(input, message_history, tools=None)
148+
response = model.generate_content(**options)
149+
return self._parse_content_response(response)
152150
except ResponseValidationError as e:
153-
raise LLMGenerationError(e)
151+
raise LLMGenerationError("Error calling VertexAILLM") from e
154152

155153
async def ainvoke(
156154
self,
@@ -172,65 +170,81 @@ async def ainvoke(
172170
try:
173171
if isinstance(message_history, MessageHistory):
174172
message_history = message_history.messages
175-
system_message = (
176-
[system_instruction] if system_instruction is not None else []
177-
)
178-
self.model = GenerativeModel(
179-
model_name=self.model_name,
180-
system_instruction=system_message,
181-
**self.options,
173+
model = self._get_model(
174+
system_instruction=system_instruction,
182175
)
183-
messages = self.get_messages(input, message_history)
184-
response = await self.model.generate_content_async(
185-
messages, **self.model_params
186-
)
187-
return LLMResponse(content=response.text)
176+
options = self._get_call_params(input, message_history, tools=None)
177+
response = await model.generate_content_async(**options)
178+
return self._parse_content_response(response)
188179
except ResponseValidationError as e:
189-
raise LLMGenerationError(e)
190-
191-
def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
192-
return VertexAITool(
193-
function_declarations=[
194-
FunctionDeclaration(
195-
name=tool.get_name(),
196-
description=tool.get_description(),
197-
parameters=tool.get_parameters(exclude=["additional_properties"]),
198-
)
199-
]
180+
raise LLMGenerationError("Error calling VertexAILLM") from e
181+
182+
def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration:
    """Translate a neo4j-graphrag Tool into a VertexAI FunctionDeclaration.

    `additional_properties` is excluded because the VertexAI function-calling
    schema does not accept that key.
    """
    params = tool.get_parameters(exclude=["additional_properties"])
    return FunctionDeclaration(
        name=tool.get_name(),
        description=tool.get_description(),
        parameters=params,
    )
201188

202189
def _get_llm_tools(
    self, tools: Optional[Sequence[Tool]]
) -> Optional[list[VertexAITool]]:
    """Bundle all tools into a single VertexAITool, or return None if no tools.

    VertexAI expects one Tool object holding every function declaration
    (multiple Tool objects would raise an error — the bug fixed by this
    commit), so all declarations are collected into one wrapper.
    """
    if not tools:
        return None
    declarations = [self._to_vertexai_function_declaration(t) for t in tools]
    return [VertexAITool(function_declarations=declarations)]
208201

209202
def _get_model(
    self,
    system_instruction: Optional[str] = None,
) -> GenerativeModel:
    """Build a GenerativeModel for this LLM's model name and system prompt.

    Per-call options (generation_config, tools, tool_config) are NOT bound
    here; they are assembled separately by `_get_call_params` and passed to
    generate_content(_async) instead.
    """
    if system_instruction is None:
        system_message = []
    else:
        system_message = [system_instruction]
    return GenerativeModel(
        model_name=self.model_name,
        system_instruction=system_message,
    )
223212

213+
def _get_call_params(
    self,
    input: str,
    message_history: Optional[Union[List[LLMMessage], MessageHistory]],
    tools: Optional[Sequence[Tool]],
) -> dict[str, Any]:
    """Assemble the keyword arguments for generate_content(_async).

    Starts from a copy of `self.options` so the instance-level options are
    never mutated, then adjusts them depending on whether tools are used.

    Args:
        input: The user prompt for this call.
        message_history: Prior conversation, if any.
        tools: Tools the model may call, or None for plain text generation.

    Returns:
        dict[str, Any]: kwargs including "contents" and, when tools are
        given, "tools" and "tool_config".
    """
    call_kwargs = dict(self.options)
    if tools:
        # A tool call is wanted back: generation_config conflicts with
        # tool calling, so drop it for this request.
        call_kwargs.pop("generation_config", None)
        call_kwargs["tools"] = self._get_llm_tools(tools)
        if "tool_config" not in call_kwargs:
            # Default to ANY mode so the model must answer with a
            # function call rather than free text.
            call_kwargs["tool_config"] = ToolConfig(
                function_calling_config=ToolConfig.FunctionCallingConfig(
                    mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
                )
            )
    else:
        # Plain generation: a leftover tool_config is not applicable.
        call_kwargs.pop("tool_config", None)

    call_kwargs["contents"] = self.get_messages(input, message_history)
    return call_kwargs
237+
224238
async def _acall_llm(
    self,
    input: str,
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence[Tool]] = None,
) -> GenerationResponse:
    """Asynchronously invoke the VertexAI model and return the raw response.

    The model is built without per-call options; those (including tools and
    tool_config, when tools are given) come from `_get_call_params`.
    """
    model = self._get_model(system_instruction=system_instruction)
    call_kwargs = self._get_call_params(input, message_history, tools)
    return await model.generate_content_async(**call_kwargs)
235249

236250
def _call_llm(
@@ -240,9 +254,9 @@ def _call_llm(
240254
system_instruction: Optional[str] = None,
241255
tools: Optional[Sequence[Tool]] = None,
242256
) -> GenerationResponse:
243-
model = self._get_model(system_instruction=system_instruction, tools=tools)
244-
messages = self.get_messages(input, message_history)
245-
response = model.generate_content(messages, **self.model_params)
257+
model = self._get_model(system_instruction=system_instruction)
258+
options = self._get_call_params(input, message_history, tools)
259+
response = model.generate_content(**options)
246260
return response
247261

248262
def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:

Comments (0)