diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 177775172..f4260beec 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -399,19 +399,13 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, ToolMessage): - role = "function" - # message.name can be null for ToolMessage name = message.name if name is None: if prev_ai_message: tool_call_id = message.tool_call_id tool_call: ToolCall | None = next( - ( - t - for t in prev_ai_message.tool_calls - if t["id"] == tool_call_id - ), + (t for t in prev_ai_message.tool_calls if t["id"] == tool_call_id), None, ) @@ -424,57 +418,79 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: ) name = tool_call["name"] - def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]: - if isinstance(raw_content, dict): - return raw_content - if isinstance(raw_content, str): - try: - content = json.loads(raw_content) - # json.loads("2") returns 2 since it's a valid json - if isinstance(content, dict): - return content - except json.JSONDecodeError: - pass - return {"content": raw_content} + if perform_literal_eval_on_string_raw_content and isinstance(message.content, str): + message.content = ast.literal_eval(message.content) + message_contains_image = False if isinstance(message.content, list): - parsed_content = [_parse_content(c) for c in message.content] - if len(parsed_content) > 1: - merged_content: Dict[Any, Any] = {} - for content_piece in parsed_content: - for key, value in content_piece.items(): - if key not in merged_content: - merged_content[key] = [] - merged_content[key].append(value) - logger.warning( - "Expected content to be a str, got a list with > 1 element." - "Merging values together" - ) - content = {k: "".join(v) for k, v in merged_content.items()} + message_contains_image = any( + isinstance(c, dict) and c.get("type") == "image_url" for c in message.content + ) + + if not message_contains_image: + # Old code, I don't touch + role = "function" + + def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]: + if isinstance(raw_content, dict): + return raw_content + if isinstance(raw_content, str): + try: + content = json.loads(raw_content) + # json.loads("2") returns 2 since it's a valid json + if isinstance(content, dict): + return content + except json.JSONDecodeError: + pass + return {"content": raw_content} + + if isinstance(message.content, list): + parsed_content = [_parse_content(c) for c in message.content] + if len(parsed_content) > 1: + merged_content: Dict[Any, Any] = {} + for content_piece in parsed_content: + for key, value in content_piece.items(): + if key not in merged_content: + merged_content[key] = [] + merged_content[key].append(value) + logger.warning( + "Expected content to be a str, got a list with > 1 element.Merging values together" + ) + content = {k: "".join(v) for k, v in merged_content.items()} + else: + content = parsed_content[0] else: - content = parsed_content[0] - else: - content = _parse_content(message.content) + content = _parse_content(message.content) - part = Part( - function_response=FunctionResponse( - name=name, - response=content, + part = Part( + function_response=FunctionResponse( + name=name, + response=content, + ) ) - ) - parts = [part] + parts = [part] - prev_content = vertex_messages[-1] - prev_content_is_function = prev_content and prev_content.role == "function" + prev_content = vertex_messages[-1] + prev_content_is_function = prev_content and prev_content.role == "function" - if prev_content_is_function: - prev_parts = list(prev_content.parts) - prev_parts.extend(parts) - # replacing last message - vertex_messages[-1] = Content(role=role, parts=prev_parts) - continue + if prev_content_is_function: + prev_parts = list(prev_content.parts) + prev_parts.extend(parts) + # replacing last message + vertex_messages[-1] = Content(role=role, parts=prev_parts) + continue - vertex_messages.append(Content(role=role, parts=parts)) + vertex_messages.append(Content(role=role, parts=parts)) + else: # Image branch + role = "user" + message.content = ( + [{"type": "text", "text": f""}] + + message.content + + [{"type": "text", "text": f""}] + ) + + parts = _convert_to_parts(message) + vertex_messages.append(Content(role=role, parts=parts)) else: raise ValueError( f"Unexpected message with type {type(message)} at the position {i}."