Skip to content

Commit 8b8e18e

Browse files
committed
chore: linting and updating typing as a result of upgrade
1 parent 3351d8e commit 8b8e18e

File tree

2 files changed

+81
-52
lines changed

2 files changed

+81
-52
lines changed

libs/vertexai/langchain_google_vertexai/_utils.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
AsyncCallbackManagerForLLMRun,
1515
CallbackManagerForLLMRun,
1616
)
17-
from vertexai.generative_models import ( # type: ignore[import-untyped]
17+
from vertexai.generative_models import (
1818
Candidate,
1919
Image,
2020
)
21-
from vertexai.language_models import ( # type: ignore[import-untyped]
21+
from vertexai.language_models import (
2222
TextGenerationResponse,
2323
)
2424

@@ -167,33 +167,51 @@ def get_generation_info(
167167
logprobs: Union[bool, int] = False,
168168
) -> Dict[str, Any]:
169169
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
170-
info = {
171-
"is_blocked": any(rating.blocked for rating in candidate.safety_ratings),
172-
"safety_ratings": [
173-
{
174-
"category": rating.category.name,
175-
"probability_label": rating.probability.name,
176-
"probability_score": rating.probability_score,
177-
"blocked": rating.blocked,
178-
"severity": rating.severity.name,
179-
"severity_score": rating.severity_score,
180-
}
181-
        # Image generation models sometimes return ratings that are not
182-
# included in the proto.
183-
for rating in candidate.safety_ratings
184-
if hasattr(rating.category, "name")
185-
],
186-
"citation_metadata": (
187-
proto.Message.to_dict(candidate.citation_metadata)
188-
if candidate.citation_metadata
189-
else None
190-
),
191-
"usage_metadata": usage_metadata,
192-
"finish_reason": _get_finish_reason_string(candidate.finish_reason),
193-
"finish_message": (
194-
candidate.finish_message if candidate.finish_message else None
195-
),
196-
}
170+
# Handle different response types - Candidate has different attributes than
171+
# TextGenerationResponse. Check for attributes rather than isinstance to
172+
# support mocks in tests.
173+
if hasattr(candidate, "safety_ratings") and hasattr(candidate, "citation_metadata"):
174+
info = {
175+
"is_blocked": any(rating.blocked for rating in candidate.safety_ratings),
176+
"safety_ratings": [
177+
{
178+
"category": rating.category.name,
179+
"probability_label": rating.probability.name,
180+
"probability_score": rating.probability_score,
181+
"blocked": rating.blocked,
182+
"severity": rating.severity.name,
183+
"severity_score": rating.severity_score,
184+
}
185+
            # Image generation models sometimes return ratings that are not
186+
# included in the proto.
187+
for rating in candidate.safety_ratings
188+
if hasattr(rating.category, "name")
189+
],
190+
"citation_metadata": (
191+
proto.Message.to_dict(candidate.citation_metadata)
192+
if candidate.citation_metadata
193+
else None
194+
),
195+
"usage_metadata": usage_metadata,
196+
"finish_reason": _get_finish_reason_string(candidate.finish_reason)
197+
if hasattr(candidate, "finish_reason")
198+
else None,
199+
"finish_message": (
200+
candidate.finish_message
201+
if hasattr(candidate, "finish_message") and candidate.finish_message
202+
else None
203+
),
204+
}
205+
else: # TextGenerationResponse
206+
# TextGenerationResponse doesn't have the same attributes as Candidate
207+
info = {
208+
"is_blocked": False, # TextGenerationResponse doesn't have safety_ratings
209+
"safety_ratings": [],
210+
"citation_metadata": None,
211+
"usage_metadata": usage_metadata,
212+
"finish_reason": None, # TextGenerationResponse doesn't have finish_reason
213+
"finish_message": None,
214+
}
197215
if hasattr(candidate, "avg_logprobs") and candidate.avg_logprobs is not None:
198216
if (
199217
isinstance(candidate.avg_logprobs, float)

libs/vertexai/langchain_google_vertexai/chat_models.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Dict,
1717
List,
1818
Optional,
19+
Sequence,
1920
Type,
2021
Union,
2122
cast,
@@ -24,7 +25,7 @@
2425
TypedDict,
2526
overload,
2627
)
27-
from collections.abc import AsyncIterator, Iterator, Sequence
28+
from collections.abc import AsyncIterator, Iterator
2829

2930
import proto # type: ignore[import-untyped]
3031

@@ -74,17 +75,18 @@
7475
)
7576
from langchain_core.utils.pydantic import is_basemodel_subclass
7677
from langchain_core.utils.utils import _build_model_kwargs
77-
from vertexai.generative_models import ( # type: ignore
78+
from vertexai.generative_models import (
7879
Tool as VertexTool,
80+
Candidate as VertexCandidate,
7981
)
80-
from vertexai.generative_models._generative_models import ( # type: ignore
82+
from vertexai.generative_models._generative_models import (
8183
ToolConfig,
8284
SafetySettingsType,
8385
GenerationConfigType,
8486
GenerationResponse,
8587
_convert_schema_dict_to_gapic,
8688
)
87-
from vertexai.language_models import ( # type: ignore
89+
from vertexai.language_models import (
8890
ChatMessage,
8991
InputOutputTextPair,
9092
)
@@ -227,10 +229,10 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
227229
if i == 0 and isinstance(message, SystemMessage):
228230
context = content
229231
elif isinstance(message, AIMessage):
230-
vertex_message = ChatMessage(content=message.content, author="bot")
232+
vertex_message = ChatMessage(content=content, author="bot")
231233
vertex_messages.append(vertex_message)
232234
elif isinstance(message, HumanMessage):
233-
vertex_message = ChatMessage(content=message.content, author="user")
235+
vertex_message = ChatMessage(content=content, author="user")
234236
vertex_messages.append(vertex_message)
235237
else:
236238
msg = f"Unexpected message with type {type(message)} at the position {i}."
@@ -559,16 +561,18 @@ def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
559561
f"{type(example)} for the {i}th message."
560562
)
561563
raise ValueError(msg)
562-
input_text = example.content
564+
input_text = cast("str", example.content)
563565
if i % 2 == 1:
564566
if not isinstance(example, AIMessage):
565567
msg = (
566568
f"Expected the second message in a part to be from AI, got "
567569
f"{type(example)} for the {i}th message."
568570
)
569571
raise ValueError(msg)
572+
# input_text is guaranteed to be set in the previous iteration
573+
assert input_text is not None
570574
pair = InputOutputTextPair(
571-
input_text=input_text, output_text=example.content
575+
input_text=input_text, output_text=cast("str", example.content)
572576
)
573577
example_pairs.append(pair)
574578
return example_pairs
@@ -608,18 +612,19 @@ def _append_to_content(
608612

609613
@overload
610614
def _parse_response_candidate(
611-
response_candidate: Candidate, streaming: Literal[False] = False
615+
response_candidate: Union[Candidate, VertexCandidate],
616+
streaming: Literal[False] = False,
612617
) -> AIMessage: ...
613618

614619

615620
@overload
616621
def _parse_response_candidate(
617-
response_candidate: Candidate, streaming: Literal[True]
622+
response_candidate: Union[Candidate, VertexCandidate], streaming: Literal[True]
618623
) -> AIMessageChunk: ...
619624

620625

621626
def _parse_response_candidate(
622-
response_candidate: Candidate, streaming: bool = False
627+
response_candidate: Union[Candidate, VertexCandidate], streaming: bool = False
623628
) -> AIMessage:
624629
content: Union[None, str, List[Union[str, dict[str, Any]]]] = None
625630
additional_kwargs = {}
@@ -635,7 +640,7 @@ def _parse_response_candidate(
635640
except AttributeError:
636641
pass
637642

638-
if part.thought:
643+
if hasattr(part, "thought") and part.thought:
639644
thinking_message = {
640645
"type": "thinking",
641646
"thinking": part.text,
@@ -694,7 +699,9 @@ def _parse_response_candidate(
694699

695700
if getattr(part, "thought_signature", None):
696701
# store dict of {tool_call_id: thought_signature}
697-
if isinstance(part.thought_signature, bytes):
702+
if hasattr(part, "thought_signature") and isinstance(
703+
part.thought_signature, bytes
704+
):
698705
thought_signature = _bytes_to_base64(part.thought_signature)
699706
if (
700707
_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY
@@ -1990,7 +1997,7 @@ def _safety_settings_gemini(
19901997
return self._safety_settings_gemini(self.safety_settings)
19911998
return None
19921999
if isinstance(safety_settings, list):
1993-
return safety_settings
2000+
return cast("Sequence[SafetySetting]", safety_settings)
19942001
if isinstance(safety_settings, dict):
19952002
formatted_safety_settings = []
19962003
for category, threshold in safety_settings.items():
@@ -2006,8 +2013,8 @@ def _safety_settings_gemini(
20062013
)
20072014
)
20082015
return formatted_safety_settings
2009-
msg = "safety_settings should be either"
2010-
raise ValueError(msg)
2016+
# This should be unreachable as all cases are handled above
2017+
raise ValueError("Unexpected safety_settings type")
20112018

20122019
def _prepare_request_gemini(
20132020
self,
@@ -2041,7 +2048,7 @@ def _prepare_request_gemini(
20412048
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
20422049
else:
20432050
pass
2044-
safety_settings = self._safety_settings_gemini(safety_settings)
2051+
formatted_safety_settings = self._safety_settings_gemini(safety_settings)
20452052
logprobs = logprobs if logprobs is not None else self.logprobs
20462053
logprobs = logprobs if isinstance(logprobs, (int, bool)) else False
20472054
generation_config = self._generation_config_gemini(
@@ -2100,18 +2107,22 @@ def _content_to_v1(contents: list[Content]) -> list[v1Content]:
21002107
v1_tools = [v1Tool(**proto.Message.to_dict(t)) for t in formatted_tools]
21012108

21022109
if tool_config:
2103-
v1_tool_config = v1ToolConfig(
2104-
function_calling_config=v1FunctionCallingConfig(
2105-
**proto.Message.to_dict(tool_config.function_calling_config)
2110+
v1_tool_config = (
2111+
v1ToolConfig(
2112+
function_calling_config=v1FunctionCallingConfig(
2113+
**proto.Message.to_dict(tool_config.function_calling_config)
2114+
)
21062115
)
2116+
if hasattr(tool_config, "function_calling_config")
2117+
else v1ToolConfig()
21072118
)
21082119

2109-
if safety_settings:
2120+
if formatted_safety_settings:
21102121
v1_safety_settings = [
21112122
v1SafetySetting(
21122123
category=s.category, method=s.method, threshold=s.threshold
21132124
)
2114-
for s in safety_settings
2125+
for s in formatted_safety_settings
21152126
]
21162127

21172128
if (self.cached_content is not None) or (cached_content is not None):
@@ -2267,7 +2278,7 @@ def _tool_config_gemini(
22672278
self, tool_config: Optional[Union[_ToolConfigDict, ToolConfig]] = None
22682279
) -> Optional[GapicToolConfig]:
22692280
if tool_config and not isinstance(tool_config, ToolConfig):
2270-
return _format_tool_config(cast("_ToolConfigDict", tool_config))
2281+
return _format_tool_config(tool_config)
22712282
return None
22722283

22732284
async def _agenerate(

0 commit comments

Comments
 (0)