diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index b391f8457..2b93ccc1d 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -119,6 +119,7 @@ Tool as GapicTool, ToolConfig as GapicToolConfig, VideoMetadata, + Schema, ) from langchain_google_vertexai._base import _VertexAICommon from langchain_google_vertexai._image_utils import ( @@ -1570,9 +1571,11 @@ def _prepare_params( f"`response_mime_type` is set to one of {allowed_mime_types}" ) raise ValueError(error_message) - - gapic_response_schema = _convert_schema_dict_to_gapic(response_schema) - params["response_schema"] = gapic_response_schema + if isinstance(response_schema, Schema): + params["response_schema"] = response_schema + else: + gapic_response_schema = _convert_schema_dict_to_gapic(response_schema) + params["response_schema"] = gapic_response_schema audio_timestamp = kwargs.get("audio_timestamp", self.audio_timestamp) if audio_timestamp is not None: @@ -2088,7 +2091,7 @@ async def _astream( def with_structured_output( self, - schema: Union[Dict, Type[BaseModel], Type], + schema: Union[Dict, Type[BaseModel], Type, Schema], *, include_raw: bool = False, method: Optional[Literal["json_mode"]] = None, @@ -2223,36 +2226,47 @@ class Explanation(BaseModel): parser: OutputParserLike if method == "json_mode": - if isinstance(schema, type) and is_basemodel_subclass(schema): - if issubclass(schema, BaseModelV1): - schema_json = schema.schema() - else: - schema_json = schema.model_json_schema() - parser = PydanticOutputParser(pydantic_object=schema) + if isinstance(schema, Schema): + llm = self.bind( + response_mime_type="application/json", + response_schema=schema, + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": schema, + }, + ) + parser = JsonOutputParser() else: - if is_typeddict(schema): - schema_json = convert_to_json_schema(schema) - elif isinstance(schema, dict): - schema_json = schema + if isinstance(schema, type) and is_basemodel_subclass(schema): + if issubclass(schema, BaseModelV1): + schema_json = schema.schema() + else: + schema_json = schema.model_json_schema() + parser = PydanticOutputParser(pydantic_object=schema) else: - raise ValueError(f"Unsupported schema type {type(schema)}") - parser = JsonOutputParser() + if is_typeddict(schema): + schema_json = convert_to_json_schema(schema) + elif isinstance(schema, dict): + schema_json = schema + else: + raise ValueError(f"Unsupported schema type {type(schema)}") + parser = JsonOutputParser() - # Resolve refs in schema because they are not supported - # by the Gemini API. - schema_json = replace_defs_in_schema(schema_json) + # Resolve refs in schema because they are not supported + # by the Gemini API. + schema_json = replace_defs_in_schema(schema_json) - # API does not support anyOf. - schema_json = _strip_nullable_anyof(schema_json) + # API does not support anyOf. + schema_json = _strip_nullable_anyof(schema_json) - llm = self.bind( - response_mime_type="application/json", - response_schema=schema_json, - ls_structured_output_format={ - "kwargs": {"method": method}, - "schema": schema_json, - }, - ) + llm = self.bind( + response_mime_type="application/json", + response_schema=schema_json, + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": schema_json, + }, + ) else: tool_name = _get_tool_name(schema) if isinstance(schema, type) and is_basemodel_subclass(schema): diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index 3555a70b3..61f9cabaf 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -2,8 +2,9 @@ import logging from difflib import get_close_matches -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from google.cloud.aiplatform_v1beta1.types import Schema from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -38,7 +39,7 @@ class VertexAI(_VertexAICommon, BaseLLM): The model also needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature. """ - response_schema: Optional[Dict[str, Any]] = None + response_schema: Optional[Union[Dict[str, Any], Schema]] = None """ Optional. Enforce an schema to the output. The format of the dictionary should follow Open API schema. """ diff --git a/libs/vertexai/tests/integration_tests/test_llms.py b/libs/vertexai/tests/integration_tests/test_llms.py index eb7ea1883..654e1a284 100644 --- a/libs/vertexai/tests/integration_tests/test_llms.py +++ b/libs/vertexai/tests/integration_tests/test_llms.py @@ -7,6 +7,7 @@ import json import pytest +from google.cloud.aiplatform_v1beta1.types import Schema, Type from langchain_core.outputs import LLMResult from langchain_core.rate_limiters import InMemoryRateLimiter @@ -119,3 +120,41 @@ def test_structured_output_schema_json(): assert isinstance(parsed_response, list) assert len(parsed_response) > 0 assert "recipe_name" in parsed_response[0] + + +@pytest.mark.extended +def test_structured_output_schema_json_with_openapi_schema_object(): + model = VertexAI( + rate_limiter=rate_limiter, + model_name="gemini-2.0-flash-001", + response_mime_type="application/json", + response_schema=Schema( + type_=Type.ARRAY, + items=Schema( + type_=Type.OBJECT, + properties={ + "recipe_name": Schema(type_=Type.STRING), + "level": Schema(type_=Type.ENUM, values=["easy", "medium", "hard"]), + }, + required=["recipe_name", "level"], + property_ordering=["level", "recipe_name"], + ), + min_items=3, + max_items=4, + ), + ) + + response = model.invoke("List a few popular cookie recipes") + + assert isinstance(response, str) + parsed_response = json.loads(response) + assert isinstance(parsed_response, list) + assert len(parsed_response) >= 3 and len(parsed_response) <= 4 + for recipe in parsed_response: + assert isinstance(recipe, dict) + assert "recipe_name" in recipe + assert "level" in recipe + assert isinstance(recipe["recipe_name"], str) + assert recipe["level"] in ["easy", "medium", "hard"] + keys = list(recipe.keys()) + assert keys.index("level") < keys.index("recipe_name")