Skip to content

Commit 5b5115c

Browse files
authored
1 parent a989f82 commit 5b5115c

File tree

3 files changed, with 44 additions and 11 deletions.

libs/partners/google-vertexai/langchain_google_vertexai/_utils.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
9797

9898

9999
def get_generation_info(
100-
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
100+
candidate: Union[TextGenerationResponse, Candidate],
101+
is_gemini: bool,
102+
*,
103+
stream: bool = False,
101104
) -> Dict[str, Any]:
102105
if is_gemini:
103106
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
104-
return {
107+
info = {
105108
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
106109
"safety_ratings": [
107110
{
108111
"category": rating.category.name,
109112
"probability_label": rating.probability.name,
113+
"blocked": rating.blocked,
110114
}
111115
for rating in candidate.safety_ratings
112116
],
113117
"citation_metadata": candidate.citation_metadata,
114118
}
115119
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
116-
candidate_dc = dataclasses.asdict(candidate)
117-
candidate_dc.pop("text")
118-
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
120+
else:
121+
info = dataclasses.asdict(candidate)
122+
info.pop("text")
123+
info = {k: v for k, v in info.items() if not k.startswith("_")}
124+
if stream:
125+
# Remove non-streamable types, like bools.
126+
info.pop("is_blocked")
127+
return info

libs/partners/google-vertexai/langchain_google_vertexai/llms.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -315,10 +315,12 @@ def get_num_tokens(self, text: str) -> int:
315315
return result.total_tokens
316316

317317
def _response_to_generation(
318-
self, response: TextGenerationResponse
318+
self, response: TextGenerationResponse, *, stream: bool = False
319319
) -> GenerationChunk:
320320
"""Converts a stream response to a generation chunk."""
321-
generation_info = get_generation_info(response, self._is_gemini_model)
321+
generation_info = get_generation_info(
322+
response, self._is_gemini_model, stream=stream
323+
)
322324
try:
323325
text = response.text
324326
except AttributeError:
@@ -401,7 +403,14 @@ def _stream(
401403
run_manager=run_manager,
402404
**params,
403405
):
404-
chunk = self._response_to_generation(stream_resp)
406+
# Gemini models return GenerationResponse even when streaming, which has a
407+
# candidates field.
408+
stream_resp = (
409+
stream_resp
410+
if isinstance(stream_resp, TextGenerationResponse)
411+
else stream_resp.candidates[0]
412+
)
413+
chunk = self._response_to_generation(stream_resp, stream=True)
405414
yield chunk
406415
if run_manager:
407416
run_manager.on_llm_new_token(

libs/partners/google-vertexai/tests/integration_tests/test_llms.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
3232
"model_name",
3333
model_names_to_test_with_default,
3434
)
35-
def test_vertex_call(model_name: str) -> None:
35+
def test_vertex_invoke(model_name: str) -> None:
3636
llm = (
3737
VertexAI(model_name=model_name, temperature=0)
3838
if model_name
3939
else VertexAI(temperature=0.0)
4040
)
41-
output = llm("Say foo:")
41+
output = llm.invoke("Say foo:")
4242
assert isinstance(output, str)
4343

4444

45+
@pytest.mark.parametrize(
46+
"model_name",
47+
model_names_to_test_with_default,
48+
)
49+
def test_vertex_generate(model_name: str) -> None:
50+
llm = (
51+
VertexAI(model_name=model_name, temperature=0)
52+
if model_name
53+
else VertexAI(temperature=0.0)
54+
)
55+
output = llm.generate(["Say foo:"])
56+
assert isinstance(output, LLMResult)
57+
assert len(output.generations) == 1
58+
59+
4560
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
46-
def test_vertex_generate() -> None:
61+
def test_vertex_generate_multiple_candidates() -> None:
4762
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
4863
output = llm.generate(["Say foo:"])
4964
assert isinstance(output, LLMResult)

0 commit comments

Comments
 (0)