16
16
Dict ,
17
17
List ,
18
18
Optional ,
19
+ Sequence ,
19
20
Type ,
20
21
Union ,
21
22
cast ,
24
25
TypedDict ,
25
26
overload ,
26
27
)
27
- from collections .abc import AsyncIterator , Iterator , Sequence
28
+ from collections .abc import AsyncIterator , Iterator
28
29
29
30
import proto # type: ignore[import-untyped]
30
31
74
75
)
75
76
from langchain_core .utils .pydantic import is_basemodel_subclass
76
77
from langchain_core .utils .utils import _build_model_kwargs
77
- from vertexai .generative_models import ( # type: ignore
78
+ from vertexai .generative_models import (
78
79
Tool as VertexTool ,
80
+ Candidate as VertexCandidate ,
79
81
)
80
- from vertexai .generative_models ._generative_models import ( # type: ignore
82
+ from vertexai .generative_models ._generative_models import (
81
83
ToolConfig ,
82
84
SafetySettingsType ,
83
85
GenerationConfigType ,
84
86
GenerationResponse ,
85
87
_convert_schema_dict_to_gapic ,
86
88
)
87
- from vertexai .language_models import ( # type: ignore
89
+ from vertexai .language_models import (
88
90
ChatMessage ,
89
91
InputOutputTextPair ,
90
92
)
@@ -227,10 +229,10 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
227
229
if i == 0 and isinstance (message , SystemMessage ):
228
230
context = content
229
231
elif isinstance (message , AIMessage ):
230
- vertex_message = ChatMessage (content = message . content , author = "bot" )
232
+ vertex_message = ChatMessage (content = content , author = "bot" )
231
233
vertex_messages .append (vertex_message )
232
234
elif isinstance (message , HumanMessage ):
233
- vertex_message = ChatMessage (content = message . content , author = "user" )
235
+ vertex_message = ChatMessage (content = content , author = "user" )
234
236
vertex_messages .append (vertex_message )
235
237
else :
236
238
msg = f"Unexpected message with type { type (message )} at the position { i } ."
@@ -559,16 +561,18 @@ def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
559
561
f"{ type (example )} for the { i } th message."
560
562
)
561
563
raise ValueError (msg )
562
- input_text = example .content
564
+ input_text = cast ( "str" , example .content )
563
565
if i % 2 == 1 :
564
566
if not isinstance (example , AIMessage ):
565
567
msg = (
566
568
f"Expected the second message in a part to be from AI, got "
567
569
f"{ type (example )} for the { i } th message."
568
570
)
569
571
raise ValueError (msg )
572
+ # input_text is guaranteed to be set in the previous iteration
573
+ assert input_text is not None
570
574
pair = InputOutputTextPair (
571
- input_text = input_text , output_text = example .content
575
+ input_text = input_text , output_text = cast ( "str" , example .content )
572
576
)
573
577
example_pairs .append (pair )
574
578
return example_pairs
@@ -608,18 +612,19 @@ def _append_to_content(
608
612
609
613
@overload
610
614
def _parse_response_candidate (
611
- response_candidate : Candidate , streaming : Literal [False ] = False
615
+ response_candidate : Union [Candidate , VertexCandidate ],
616
+ streaming : Literal [False ] = False ,
612
617
) -> AIMessage : ...
613
618
614
619
615
620
@overload
616
621
def _parse_response_candidate (
617
- response_candidate : Candidate , streaming : Literal [True ]
622
+ response_candidate : Union [ Candidate , VertexCandidate ] , streaming : Literal [True ]
618
623
) -> AIMessageChunk : ...
619
624
620
625
621
626
def _parse_response_candidate (
622
- response_candidate : Candidate , streaming : bool = False
627
+ response_candidate : Union [ Candidate , VertexCandidate ] , streaming : bool = False
623
628
) -> AIMessage :
624
629
content : Union [None , str , List [Union [str , dict [str , Any ]]]] = None
625
630
additional_kwargs = {}
@@ -635,7 +640,7 @@ def _parse_response_candidate(
635
640
except AttributeError :
636
641
pass
637
642
638
- if part .thought :
643
+ if hasattr ( part , "thought" ) and part .thought :
639
644
thinking_message = {
640
645
"type" : "thinking" ,
641
646
"thinking" : part .text ,
@@ -694,7 +699,9 @@ def _parse_response_candidate(
694
699
695
700
if getattr (part , "thought_signature" , None ):
696
701
# store dict of {tool_call_id: thought_signature}
697
- if isinstance (part .thought_signature , bytes ):
702
+ if hasattr (part , "thought_signature" ) and isinstance (
703
+ part .thought_signature , bytes
704
+ ):
698
705
thought_signature = _bytes_to_base64 (part .thought_signature )
699
706
if (
700
707
_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY
@@ -1990,7 +1997,7 @@ def _safety_settings_gemini(
1990
1997
return self ._safety_settings_gemini (self .safety_settings )
1991
1998
return None
1992
1999
if isinstance (safety_settings , list ):
1993
- return safety_settings
2000
+ return cast ( "Sequence[SafetySetting]" , safety_settings )
1994
2001
if isinstance (safety_settings , dict ):
1995
2002
formatted_safety_settings = []
1996
2003
for category , threshold in safety_settings .items ():
@@ -2006,8 +2013,8 @@ def _safety_settings_gemini(
2006
2013
)
2007
2014
)
2008
2015
return formatted_safety_settings
2009
- msg = "safety_settings should be either"
2010
- raise ValueError (msg )
2016
+ # This should be unreachable as all cases are handled above
2017
+ raise ValueError ("Unexpected safety_settings type" )
2011
2018
2012
2019
def _prepare_request_gemini (
2013
2020
self ,
@@ -2041,7 +2048,7 @@ def _prepare_request_gemini(
2041
2048
tool_config = _tool_choice_to_tool_config (tool_choice , all_names )
2042
2049
else :
2043
2050
pass
2044
- safety_settings = self ._safety_settings_gemini (safety_settings )
2051
+ formatted_safety_settings = self ._safety_settings_gemini (safety_settings )
2045
2052
logprobs = logprobs if logprobs is not None else self .logprobs
2046
2053
logprobs = logprobs if isinstance (logprobs , (int , bool )) else False
2047
2054
generation_config = self ._generation_config_gemini (
@@ -2100,18 +2107,22 @@ def _content_to_v1(contents: list[Content]) -> list[v1Content]:
2100
2107
v1_tools = [v1Tool (** proto .Message .to_dict (t )) for t in formatted_tools ]
2101
2108
2102
2109
if tool_config :
2103
- v1_tool_config = v1ToolConfig (
2104
- function_calling_config = v1FunctionCallingConfig (
2105
- ** proto .Message .to_dict (tool_config .function_calling_config )
2110
+ v1_tool_config = (
2111
+ v1ToolConfig (
2112
+ function_calling_config = v1FunctionCallingConfig (
2113
+ ** proto .Message .to_dict (tool_config .function_calling_config )
2114
+ )
2106
2115
)
2116
+ if hasattr (tool_config , "function_calling_config" )
2117
+ else v1ToolConfig ()
2107
2118
)
2108
2119
2109
- if safety_settings :
2120
+ if formatted_safety_settings :
2110
2121
v1_safety_settings = [
2111
2122
v1SafetySetting (
2112
2123
category = s .category , method = s .method , threshold = s .threshold
2113
2124
)
2114
- for s in safety_settings
2125
+ for s in formatted_safety_settings
2115
2126
]
2116
2127
2117
2128
if (self .cached_content is not None ) or (cached_content is not None ):
@@ -2267,7 +2278,7 @@ def _tool_config_gemini(
2267
2278
self , tool_config : Optional [Union [_ToolConfigDict , ToolConfig ]] = None
2268
2279
) -> Optional [GapicToolConfig ]:
2269
2280
if tool_config and not isinstance (tool_config , ToolConfig ):
2270
- return _format_tool_config (cast ( "_ToolConfigDict" , tool_config ) )
2281
+ return _format_tool_config (tool_config )
2271
2282
return None
2272
2283
2273
2284
async def _agenerate (
0 commit comments