From a9f36464d3e33587a72d4a324f809e07ac6a9441 Mon Sep 17 00:00:00 2001
From: cangozpi
Date: Sat, 16 Aug 2025 17:27:24 +0300
Subject: [PATCH] GenAIModel implemented.

---
 pyproject.toml           |   5 +-
 src/smolagents/models.py | 927 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 930 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 48f3cd0e0..02b3df5bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,9 @@ gradio = [
 litellm = [
     "litellm>=1.60.2",
 ]
+genai = [
+    "google-genai>=1.29.0",
+]
 mcp = [
     "mcpadapt>=0.1.11",  # Support Image and Audio content
     "mcp",
@@ -81,7 +84,7 @@ vllm = [
     "torch"
 ]
 all = [
-    "smolagents[audio,docker,e2b,gradio,litellm,mcp,mlx-lm,openai,telemetry,toolkit,transformers,vision,bedrock]",
+    "smolagents[audio,docker,e2b,gradio,litellm,genai,mcp,mlx-lm,openai,telemetry,toolkit,transformers,vision,bedrock]",
 ]
 quality = [
     "ruff>=0.9.0",
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 4e2ed2454..e70070e4d 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -22,7 +22,7 @@
 from dataclasses import asdict, dataclass
 from enum import Enum
 from threading import Thread
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
 from .monitoring import TokenUsage
 from .tools import Tool
@@ -1907,6 +1907,930 @@ def generate(
 AmazonBedrockModel = AmazonBedrockServerModel
 
 
+
+class GenAIModel(ApiModel):
+    """
+    Model that uses the [GenAI Python SDK](https://googleapis.github.io/python-genai/genai.html#genai.caches.Caches.create) to access Google's Gemini models through Vertex AI.
+    Supports explicit caching of system prompts and generating content against the cached prompts for models that support it (e.g. gemini-2.5-flash-lite).
+    """
+
+    def __init__(
+        self,
+        vertex_project: str,
+        vertex_location: str,
+        credentials: dict,
+        model_id: str = "gemini-2.5-flash-lite",
+        custom_role_conversions: dict[str, str] | None = None,
+        flatten_messages_as_text: bool | None = None,
+        thinking_budget: Optional[int] = None,
+        cache_system_prompts: bool = True,
+        cache_ttl: str = "300s",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            vertex_project (str): The Google Cloud project ID where the Vertex AI service is hosted.
+            vertex_location (str): The location of the Vertex AI service (e.g. "europe-west1").
+            credentials (dict): The Google service account info used for authentication, passed as a dictionary (i.e. the parsed service account JSON).
+            model_id (`str`):
+                Name of the GenAI-supported model to use for generation (e.g. "gemini-2.5-flash-lite").
+            custom_role_conversions (Optional[dict[str, str]]):
+                Custom role conversion mapping to convert message roles into others.
+                By default, roles are converted for the GenAI model as follows:
+                    - `MessageRole.USER` -> "user"
+                    - `MessageRole.ASSISTANT` -> "model"
+                    - `MessageRole.SYSTEM` -> "model"
+                    - `MessageRole.TOOL_CALL` -> "model"
+                    - `MessageRole.TOOL_RESPONSE` -> "user"
+                You can override this by providing a custom mapping.
+            flatten_messages_as_text (Optional[bool]): Whether to flatten messages as text.
+                Defaults to `True` for model ids that start with "ollama", "groq", or "cerebras".
+            thinking_budget (Optional[int]): The thinking budget for the model. If None, it is set to 0 (thinking disabled) for models that think by default (e.g. gemini-2.5-flash and gemini-2.5-flash-lite); otherwise it is left unset.
+                The value is used for the google.genai.types.ThinkingConfig's "thinking_budget" field.
+            cache_system_prompts (bool): Whether to automatically cache the agent's system prompt and reuse the cached copy for generation.
+                It uses the explicit context caching feature of the GenAI Python SDK to cache the system prompts (see: https://ai.google.dev/gemini-api/docs/caching?lang=python).
+            cache_ttl (str): The time-to-live (TTL) for the cached system prompts. Defaults to "300s" (5 minutes).
+            **kwargs:
+                Additional keyword arguments to pass to the GenAI API.
+        """
+        self.vertex_project = vertex_project
+        self.vertex_location = vertex_location
+        self.credentials = credentials
+        self.cache_system_prompts = cache_system_prompts
+        self.cache_ttl = cache_ttl
+
+        flatten_messages_as_text = (
+            flatten_messages_as_text
+            if flatten_messages_as_text is not None
+            else model_id.startswith(("ollama", "groq", "cerebras"))
+        )
+
+        if custom_role_conversions is None:
+            custom_role_conversions = {
+                MessageRole.USER: "user",
+                MessageRole.ASSISTANT: "model",
+                MessageRole.SYSTEM: "model",
+                MessageRole.TOOL_CALL: "model",
+                MessageRole.TOOL_RESPONSE: "user",
+            }
+        super().__init__(
+            model_id=model_id,
+            custom_role_conversions=custom_role_conversions,
+            flatten_messages_as_text=flatten_messages_as_text,
+            **kwargs,
+        )
+
+        if thinking_budget is None and model_id in ["gemini-2.5-flash", "gemini-2.5-flash-lite"]:
+            # Disable thinking by default for models that enable it unless told otherwise.
+            self.thinking_budget = 0
+        else:
+            self.thinking_budget = thinking_budget  # Thinking budget for models that support thinking (see: https://ai.google.dev/gemini-api/docs/thinking).
+
+    def create_client(self):
+        """Create the GenAI client."""
+        return self.get_GenAI_Client(self.vertex_project, self.vertex_location, self.credentials)
+
+    def generate(
+        self,
+        messages: list[dict[str, str | list[dict]] | ChatMessage],
+        stop_sequences: list[str] | None = None,
+        response_format: dict[str, str] | None = None,
+        tools_to_call_from: list[Tool] | None = None,
+        **kwargs,
+    ) -> ChatMessage:
+        """
+        Process the input messages and return the model's response.
+        If `self.cache_system_prompts` is True, the system prompt (extracted from the passed-in messages) is cached automatically if it is not already cached, and the cached copy is used for generation.
+
+        Parameters:
+            messages (`list[dict[str, str | list[dict]]] | list[ChatMessage]`):
+                A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
+            stop_sequences (`List[str]`, *optional*):
+                A list of strings that will stop the generation if encountered in the model's output.
+            response_format (`dict[str, str]`, *optional*):
+                The response format to use in the model's response.
+            tools_to_call_from (`List[Tool]`, *optional*):
+                A list of tools that the model can use to generate responses.
+            **kwargs:
+                Additional keyword arguments to be passed to the underlying model.
+
+        Returns:
+            `ChatMessage`: A chat message object containing the model's response.
+ """ + try: + from google.genai import errors, types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + completion_kwargs = self._prepare_completion_kwargs( + messages=messages, + stop_sequences=stop_sequences, + response_format=response_format, + tools_to_call_from=tools_to_call_from, + custom_role_conversions=self.custom_role_conversions, + convert_images_to_image_urls=True, + **kwargs, + ) + + # Convert the messages to GenAI compatible format: + message_contents: list[types.ContentDict] = [] + for message in completion_kwargs["messages"]: + role = message["role"] + m_content = message["content"] + + # convert the messages to parts: + parts = [] + for p in m_content: + p_type = p["type"] + if p_type == "text": + p_text = p["text"] + part = {"text": p_text} + parts.append(part) + else: + raise ValueError( + f"Unsupported content type '{p_type}' in the message: {message}. Only 'text' content type is supported for now." + ) # TODO: support other content types like images which the smolagents agents supports + + # create the ContentDict corresponding to the current message: + cur_message_content: types.ContentDict = { + "role": role, # Note: GenAI only supports "user" and "model" roles + "parts": parts, + } + message_contents.append(cur_message_content) + + # Convert response_format to GenAI compatible format: + if response_format is not None: + completion_kwargs["response_format"] = response_format["json_schema"]["schema"] + + # Convert tools_to_call_from to GenAI compatible format: + tools = None + if tools_to_call_from is not None: + tools = completion_kwargs["tools"] + tools = [ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name=t["function"]["name"], + description=t["function"]["description"], + parameters=types.Schema( + type=t["function"]["parameters"]["type"], + properties=t["function"]["parameters"]["properties"], + required=t["function"]["parameters"]["required"], + ), + ) + for t in tools + ] + ) + ] + + # extract the system prompt from the messages (it must always be the first message): + assert (messages[0]["role"] == MessageRole.SYSTEM) and all( + [m["role"] != MessageRole.SYSTEM for m in messages[1:]] + ), ( + "Only the first message can be a system prompt message. The rest of the messages should not have the system role." + ) + system_prompt = messages[0]["content"][0]["text"] + del message_contents[ + 0 + ] # Remove the system prompt message from the messages list to avoid sending it again to the model as the system prompt will come from the cached_content and not the contents. + + if self.cache_system_prompts: + # generate a unique display name from the cached system prompt: + display_name = self.get_unique_string_hash( + system_prompt + ) # This will be used to identify the cached system prompt from the list of active GenAI caches. + + # Check if the system prompt of the agent is still cached: + content_cache = None + for cache in self.client.caches.list(): + cur_cache_display_name = cache.display_name + cur_model_id = cache.model.split( + "/" + )[ + -1 + ] # caches are model specific so we cannot use a prompt that is cached for one model with another model. 
+ if (cur_model_id == self.model_id) and (cur_cache_display_name == display_name): + content_cache = cache + break + + # If not cached, then cache the system prompt: + try: + if content_cache is None: + content_cache = self.cache_GenAI_system_prompt( + system_prompt=system_prompt, + client=self.client, + display_name=display_name, + ttl=self.cache_ttl, + model=self.model_id, + ) + + # Use the resolved content cache to generate the content: + response = self.generate_content_with_cached_system_prompt_GenAI( + client=self.client, + content_cache=content_cache, + model=self.model_id, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ) + except errors.ClientError as e: + # if the system prompt is too small for caching (less than 1024 tokens) then generate without caching: + if "The minimum token count to start caching is 1024." in e.message: + response = self.generate_content_without_cached_system_prompt_GenAI( + client=self.client, + model=self.model_id, + system_prompt=system_prompt, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ) + else: + raise e # if the error is not related to the system prompt size, then raise the error: + else: + response = self.generate_content_without_cached_system_prompt_GenAI( + client=self.client, + model=self.model_id, + system_prompt=system_prompt, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ) + + # Extract the generated response and token usage from the LLM response: + self._last_input_token_count = response.usage_metadata.prompt_token_count + self._last_output_token_count = response.usage_metadata.candidates_token_count + + assert len(response.model_dump()["candidates"]) == 1, ( + 'Current implementation exepcts only a single "candidate" in the response returned by the GenAI model.' + ) + assert len(response.model_dump()["candidates"][0]["content"]["parts"]) == 1, ( + 'Current implementation exepcts only a single "part" in the response returned by the GenAI model.' + ) + + return ChatMessage.from_dict( + { + "role": response.model_dump()["candidates"][0]["content"]["role"], + "content": response.model_dump()["candidates"][0]["content"]["parts"][0]["text"], + # Note that we do not set the "tool_calls" and leave it to smolagents->Model->"parse_tool_calls()" fn to parse the tool calls from the text response. + }, + raw=response, # raw response can be used to access the cache-hit/miss statistics and other metadata that are missing in the token_usage field. + token_usage=TokenUsage( + input_tokens=response.usage_metadata.prompt_token_count, + output_tokens=response.usage_metadata.candidates_token_count, + ), + ) + + def generate_stream( + self, + messages: list[dict[str, str | list[dict]]], + stop_sequences: list[str] | None = None, + response_format: dict[str, str] | None = None, + tools_to_call_from: list[Tool] | None = None, + **kwargs, + ) -> Generator[ChatMessageStreamDelta]: + """ + Returns a generator that yields ChatMessageStreamDelta objects as the model generates content. 
+ If self.cache_system_prompts is True then automatically caches the system prompt (extracts it from the passed in messages) if it is not already cached and uses the cached system prompt to generate the content. + + Parameters: + messages (`list[dict[str, str | list[dict]]] | list[ChatMessage]`): + A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`. + stop_sequences (`List[str]`, *optional*): + A list of strings that will stop the generation if encountered in the model's output. + response_format (`dict[str, str]`, *optional*): + The response format to use in the model's response. + tools_to_call_from (`List[Tool]`, *optional*): + A list of tools that the model can use to generate responses. + **kwargs: + Additional keyword arguments to be passed to the underlying model. + + Returns: + `Generator[ChatMessageStreamDelta]`: A generator that yields ChatMessageStreamDelta objects as the model generates content. + """ + try: + from google.genai import errors, types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + completion_kwargs = self._prepare_completion_kwargs( + messages=messages, + stop_sequences=stop_sequences, + response_format=response_format, + tools_to_call_from=tools_to_call_from, + custom_role_conversions=self.custom_role_conversions, + convert_images_to_image_urls=True, + **kwargs, + ) + + # Convert the messages to GenAI compatible format: + message_contents: list[types.ContentDict] = [] + for message in completion_kwargs["messages"]: + role = message["role"] + m_content = message["content"] + + # convert the messages to parts: + parts = [] + for p in m_content: + p_type = p["type"] + if p_type == "text": + p_text = p["text"] + part = {"text": p_text} + parts.append(part) + else: + raise ValueError( + f"Unsupported content type '{p_type}' in the message: {message}. Only 'text' content type is supported for now." + ) # TODO: support other content types like images which the smolagents agents supports + + # create the ContentDict corresponding to the current message: + cur_message_content: types.ContentDict = { + "role": role, # Note: GenAI only supports "user" and "model" roles + "parts": parts, + } + message_contents.append(cur_message_content) + + # Convert response_format to GenAI compatible format: + if response_format is not None: + completion_kwargs["response_format"] = response_format["json_schema"]["schema"] + + # Convert tools_to_call_from to GenAI compatible format: + tools = None + if tools_to_call_from is not None: + tools = completion_kwargs["tools"] + tools = [ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name=t["function"]["name"], + description=t["function"]["description"], + parameters=types.Schema( + type=t["function"]["parameters"]["type"], + properties=t["function"]["parameters"]["properties"], + required=t["function"]["parameters"]["required"], + ), + ) + for t in tools + ] + ) + ] + + # extract the system prompt from the messages (it must always be the first message): + assert (messages[0]["role"] == MessageRole.SYSTEM) and all( + [m["role"] != MessageRole.SYSTEM for m in messages[1:]] + ), ( + "Only the first message can be a system prompt message. The rest of the messages should not have the system role." 
+ ) + system_prompt = messages[0]["content"][0]["text"] + del message_contents[ + 0 + ] # Remove the system prompt message from the messages list to avoid sending it again to the model as the system prompt will come from the cached_content and not the contents. + + if self.cache_system_prompts: + # generate a unique display name from the cached system prompt: + display_name = self.get_unique_string_hash( + system_prompt + ) # This will be used to identify the cached system prompt from the list of active GenAI caches. + + # Check if the system prompt of the agent is still cached: + content_cache = None + for cache in self.client.caches.list(): + cur_cache_display_name = cache.display_name + cur_model_id = cache.model.split( + "/" + )[ + -1 + ] # caches are model specific so we cannot use a prompt that is cached for one model with another model. + if (cur_model_id == self.model_id) and (cur_cache_display_name == display_name): + content_cache = cache + break + + # If not cached, then cache the system prompt: + try: + if content_cache is None: + content_cache = self.cache_GenAI_system_prompt( + system_prompt=system_prompt, + client=self.client, + display_name=display_name, + ttl=self.cache_ttl, + model=self.model_id, + ) + + # Use the resolved content cache to generate the content in a streaming manner: + for event in self.generate_streaming_content_with_cached_system_prompt_GenAI( + client=self.client, + content_cache=content_cache, + model=self.model_id, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ): + if ( + (hasattr(event, "candidates")) + and (len(event.candidates) > 0) + and (event.candidates[0].content.parts) + ): + if event.candidates[0].content.parts[0].text: + yield ChatMessageStreamDelta( + content=event.candidates[0].content.parts[0].text, + ) + if getattr(event.usage_metadata, "total_token_count", None): + self._last_input_token_count = event.usage_metadata.prompt_token_count + self._last_output_token_count = event.usage_metadata.candidates_token_count + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage_metadata.prompt_token_count, + output_tokens=event.usage_metadata.candidates_token_count, + ), + ) + except errors.ClientError as e: + # if the system prompt is too small for caching (less than 1024 tokens) then generate without caching: + if "The minimum token count to start caching is 1024." 
in e.message: + for event in self.generate_streaming_content_without_cached_system_prompt_GenAI( + client=self.client, + model=self.model_id, + system_prompt=system_prompt, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ): + if ( + (hasattr(event, "candidates")) + and (len(event.candidates) > 0) + and (event.candidates[0].content.parts) + ): + if event.candidates[0].content.parts[0].text: + yield ChatMessageStreamDelta( + content=event.candidates[0].content.parts[0].text, + ) + if getattr(event.usage_metadata, "total_token_count", None): + self._last_input_token_count = event.usage_metadata.prompt_token_count + self._last_output_token_count = event.usage_metadata.candidates_token_count + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage_metadata.prompt_token_count, + output_tokens=event.usage_metadata.candidates_token_count, + ), + ) + else: + raise e # if the error is not related to the system prompt size, then raise the error: + else: + for event in self.generate_streaming_content_without_cached_system_prompt_GenAI( + client=self.client, + model=self.model_id, + system_prompt=system_prompt, + contents=message_contents, + stop_sequences=completion_kwargs["stop"], + response_format=completion_kwargs["response_format"] if response_format is not None else None, + tools=tools, + thinking_budget=self.thinking_budget, + ): + if ( + (hasattr(event, "candidates")) + and (len(event.candidates) > 0) + and (event.candidates[0].content.parts) + ): + if event.candidates[0].content.parts[0].text: + yield ChatMessageStreamDelta( + content=event.candidates[0].content.parts[0].text, + ) + if getattr(event.usage_metadata, "total_token_count", None): + self._last_input_token_count = event.usage_metadata.prompt_token_count + self._last_output_token_count = event.usage_metadata.candidates_token_count + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage_metadata.prompt_token_count, + output_tokens=event.usage_metadata.candidates_token_count, + ), + ) + + def get_GenAI_Client(self, vertex_project, vertex_location, credentials, http_options=None): + """ + Returns a GenAI Client instance configured for the Gemini model. + It authenticates using the Vertex AI authentication system using the given credentials. + The returned client can then be used to interact with any of the available models including the Gemini models. + + Inputs: + vertex_project (str): The Google Cloud project ID where the Vertex AI service is hosted. + vertex_location (str): The location of the Vertex AI service (e.g. "europe-west1"). + credentials (dict): The Google credentials to use for authentication (i.e. The service account info in Google format). This is expected to be a dictionary containing the service account credentials. + http_options (google.genai.types.HttpOptions): HTTP options to configure the client. If None it is set to google.genai.types.HttpOptions(api_version="v1"). + + Outputs: + client (google.genai.Client): Configured and authenticated GenAI Client. 
+ """ + try: + from google import genai + from google.genai import types + from google.oauth2.service_account import Credentials + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + # Authenticating to the Vertex AI API (refer to: https://pgaleone.eu/cloud/2025/06/29/vertex-ai-to-genai-sdk-service-account-auth-python-go/): + scopes = [ + "https://www.googleapis.com/auth/cloud-platform", + ] + google_credentials = Credentials.from_service_account_info(credentials, scopes=scopes) + + client = genai.Client( + vertexai=True, + project=vertex_project, + location=vertex_location, + credentials=google_credentials, + http_options=http_options if http_options is not None else types.HttpOptions(api_version="v1"), + ) + + return client + + def cache_GenAI_system_prompt( + self, system_prompt: str, client, display_name: str, ttl: str = "300s", model: str = "gemini-2.5-flash-lite" + ): + """ + Caches a system prompt for the Gemini model using the GenAI client. + + Inputs: + system_prompt (str): The system prompt to cache. + client (google.genai.Client): The GenAI client instance. + display_name (str): The display name for the cached content. + ttl (str, optional): Time to live for the cache. Defaults to "62s". + model (str, optional): The model to use for caching. Defaults to "gemini-2.5-flash-lite". + + Outputs: + content_cache (google.genai.types.CachedContent): The cached content object. + """ + try: + from google.genai import types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + # Cache the system prompt with the specified arguments: + content_cache = client.caches.create( + model=model, + config=types.CreateCachedContentConfig( + system_instruction=system_prompt, + display_name=display_name, + ttl=ttl, + ), + ) + + return content_cache + + def generate_content_with_cached_system_prompt_GenAI( + self, + client, + content_cache, + model: str = "gemini-2.5-flash-lite", + contents="What is the capital of France?", + stop_sequences: list[str] | None = None, + response_format: dict[str, str] | None = None, + tools=None, + thinking_budget: Optional[int] = None, + ): + """ + Generates content using a specified cached system prompt with the GenAI client. + + Inputs: + client (google.genai.Client): The GenAI client instance. + content_cache (types.CachedContent): The cached content object containing the system prompt. This is returned by the cache_GenAI_system_prompt function. + model (str): The model to use for generation. Defaults to "gemini-2.5-flash-lite". + contents (google.genai.types.ContentListUnionDict): The content to generate using the cached system prompt. In other words the user input to the model. + stop_sequences (list[str]): A list of strings that will stop the generation if encountered in the model's output. Defaults to None. + response_format (`dict[str, str]`, *optional*): The response format to use in the model's response. It should correspond to the + google.genai.types.GenerateContentConfig.response_json_schema field. + This correpsonds to the response_format["json_schema"]["schema"] if the response_format is in the smolagents format. 
+ Sample response_format: + { + "additionalProperties": False, + "properties": { + "thought": { + "description": "A free form text description of the thought process.", + "title": "Thought", + "type": "string" + }, + "code": { + "description": "Valid Python code snippet implementing the thought.", + "title": "Code", + "type": "string" + } + }, + "required": [ + "thought", + "code" + ], + "title": "ThoughtAndCodeAnswer", + "type": "object" + } + tools (google.genai.types.ToolListUnion): A list of tools that the model can use to generate responses. If None, no tools will be used. + thinking_budget (Optional[int]): The thinking budget for the model. If not provided, then it will not be set. + For example, set it to 0 to disable thinking for the gemini-2.5-flash-lite model which supports thinking (https://ai.google.dev/gemini-api/docs/thinking). + + Outputs: + response (google.genai.types.GenerateContentResponse): The response from the model containing the generated content. + """ + try: + from google.genai import types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + optional_args = {} + if thinking_budget is not None: # if provided then set the thinking budget + optional_args["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget) + + # Use the cached system prompt for generation: + response = client.models.generate_content( + model=model, + contents=contents, + config=types.GenerateContentConfig( + cached_content=content_cache.name, # this contains the cached system_prompt so do not pass it explicitly again + stop_sequences=stop_sequences, + response_json_schema=response_format, + response_mime_type="application/json" if response_format is not None else None, + tools=tools, + **optional_args, + ), + ) + + return response + + def generate_content_without_cached_system_prompt_GenAI( + self, + client, + model: str = "gemini-2.5-flash-lite", + system_prompt=None, + contents="What is the capital of France?", + stop_sequences: list[str] | None = None, + response_format: dict[str, str] | None = None, + tools=None, + thinking_budget: Optional[int] = None, + ): + """ + Generates content using the GenAI client without using a cached system prompt. + + Inputs: + client (google.genai.Client): The GenAI client instance. + model (str): The model to use for generation. Defaults to "gemini-2.5-flash-lite". + system_prompt (google.genai.types.ContentUnion): The system prompt to use for generation. If not provided, it will not be used. + contents (google.genai.types.ContentListUnionDict): The content to generate using the system prompt. In other words the user input to the model. + stop_sequences (list[str]): A list of strings that will stop the generation if encountered in the model's output. Defaults to None. + response_format (`dict[str, str]`, *optional*): The response format to use in the model's response. It should correspond to the + google.genai.types.GenerateContentConfig.response_json_schema field. + This correpsonds to the response_format["json_schema"]["schema"] if the response_format is in the smolagents format. 
+ Sample response_format: + { + "additionalProperties": False, + "properties": { + "thought": { + "description": "A free form text description of the thought process.", + "title": "Thought", + "type": "string" + }, + "code": { + "description": "Valid Python code snippet implementing the thought.", + "title": "Code", + "type": "string" + } + }, + "required": [ + "thought", + "code" + ], + "title": "ThoughtAndCodeAnswer", + "type": "object" + } + tools (google.genai.types.ToolListUnion): A list of tools that the model can use to generate responses. If None, no tools will be used. + thinking_budget (Optional[int]): The thinking budget for the model. If not provided, then it will not be set. + For example, set it to 0 to disable thinking for the gemini-2.5-flash-lite model which supports thinking (https://ai.google.dev/gemini-api/docs/thinking). + + Outputs: + response (google.genai.types.GenerateContentResponse): The response from the model containing the generated content. + """ + try: + from google.genai import types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + optional_args = {} + if thinking_budget is not None: # if provided then set the thinking budget + optional_args["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget) + + # Generate content without using a cached system prompt: + response = client.models.generate_content( + model=model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + stop_sequences=stop_sequences, + response_json_schema=response_format, + response_mime_type="application/json" if response_format is not None else None, + tools=tools, + **optional_args, + ), + ) + + return response + + def generate_streaming_content_with_cached_system_prompt_GenAI( + self, + client, + content_cache, + model: str = "gemini-2.5-flash-lite", + contents="What is the capital of France?", + stop_sequences: list[str] | None = None, + response_format: dict[str, str] | None = None, + tools=None, + thinking_budget: Optional[int] = None, + ): + """ + Generates content using a specified cached system prompt with the GenAI client in a streaming manner. + + Inputs: + client (google.genai.Client): The GenAI client instance. + content_cache (google.genai.types.CachedContent): The cached content object containing the system prompt. This is returned by the cache_GenAI_system_prompt function. + model (str): The model to use for generation. Defaults to "gemini-2.5-flash-lite". + contents (google.genai.types.ContentListUnionDict): The content to generate using the cached system prompt. In other words the user input to the model. + stop_sequences (list[str]): A list of strings that will stop the generation if encountered in the model's output. Defaults to None. + response_format (`dict[str, str]`, *optional*): The response format to use in the model's response. It should correspond to the + google.genai.types.GenerateContentConfig.response_json_schema field. + This correpsonds to the response_format["json_schema"]["schema"] if the response_format is in the smolagents format. 
+ Sample response_format: + { + "additionalProperties": False, + "properties": { + "thought": { + "description": "A free form text description of the thought process.", + "title": "Thought", + "type": "string" + }, + "code": { + "description": "Valid Python code snippet implementing the thought.", + "title": "Code", + "type": "string" + } + }, + "required": [ + "thought", + "code" + ], + "title": "ThoughtAndCodeAnswer", + "type": "object" + } + tools (google.genai.types.ToolListUnion): A list of tools that the model can use to generate responses. If None, no tools will be used. + thinking_budget (Optional[int]): The thinking budget for the model. If not provided, then it will not be set. + For example, set it to 0 to disable thinking for the gemini-2.5-flash-lite model which supports thinking (https://ai.google.dev/gemini-api/docs/thinking). + + Outputs: + generator (Iterator[google.genai.types.GenerateContentResponse]): A generator that yields responses from the model as it generates content. + """ + try: + from google.genai import types + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`" + ) from e + + optional_args = {} + if thinking_budget is not None: # if provided then set the thinking budget + optional_args["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget) + + # Use the cached system prompt for generation: + generator = client.models.generate_content_stream( + model=model, + contents=contents, + config=types.GenerateContentConfig( + cached_content=content_cache.name, # this contains the cached system_prompt so do not pass it explicitly again + stop_sequences=stop_sequences, + response_json_schema=response_format, + response_mime_type="application/json" if response_format is not None else None, + tools=tools, + **optional_args, + ), + ) + + return generator + + def generate_streaming_content_without_cached_system_prompt_GenAI( + self, + client, + model: str = "gemini-2.5-flash-lite", + system_prompt=None, + contents="What is the capital of France?", + stop_sequences: list[str] | None = None, + response_format: dict[str, str] | None = None, + tools=None, + thinking_budget: Optional[int] = None, + ): + """ + Generates content using the GenAI client in a streaming manner without using a cached system prompt. + + Inputs: + client (google.genai.Client): The GenAI client instance. + model (str): The model to use for generation. Defaults to "gemini-2.5-flash-lite". + system_prompt (google.genai.types.ContentUnion): The system prompt to use for generation. If not provided, it will not be used. + contents (google.genai.types.ContentListUnionDict): The content to generate using the system prompt. In other words the user input to the model. + stop_sequences (list[str]): A list of strings that will stop the generation if encountered in the model's output. Defaults to None. + response_format (`dict[str, str]`, *optional*): The response format to use in the model's response. It should correspond to the + google.genai.types.GenerateContentConfig.response_json_schema field. + This correpsonds to the response_format["json_schema"]["schema"] if the response_format is in the smolagents format. 
+            Sample response_format:
+                {
+                    "additionalProperties": False,
+                    "properties": {
+                        "thought": {
+                            "description": "A free form text description of the thought process.",
+                            "title": "Thought",
+                            "type": "string"
+                        },
+                        "code": {
+                            "description": "Valid Python code snippet implementing the thought.",
+                            "title": "Code",
+                            "type": "string"
+                        }
+                    },
+                    "required": [
+                        "thought",
+                        "code"
+                    ],
+                    "title": "ThoughtAndCodeAnswer",
+                    "type": "object"
+                }
+            tools (google.genai.types.ToolListUnion): A list of tools that the model can use to generate responses. If None, no tools will be used.
+            thinking_budget (Optional[int]): The thinking budget for the model. If not provided, then it will not be set.
+                For example, set it to 0 to disable thinking for the gemini-2.5-flash-lite model which supports thinking (https://ai.google.dev/gemini-api/docs/thinking).
+
+        Outputs:
+            generator (Iterator[google.genai.types.GenerateContentResponse]): A generator that yields responses from the model as it generates content.
+        """
+        try:
+            from google.genai import types
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                "Please install 'genai' extra to use GenAIModel: `pip install 'smolagents[genai]'`"
+            ) from e
+
+        optional_args = {}
+        if thinking_budget is not None:  # if provided then set the thinking budget
+            optional_args["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget)
+
+        # Generate content without using a cached system prompt:
+        generator = client.models.generate_content_stream(
+            model=model,
+            contents=contents,
+            config=types.GenerateContentConfig(
+                system_instruction=system_prompt,
+                stop_sequences=stop_sequences,
+                response_json_schema=response_format,
+                response_mime_type="application/json" if response_format is not None else None,
+                tools=tools,
+                **optional_args,
+            ),
+        )
+
+        return generator
+
+    def get_unique_string_hash(self, txt):
+        """
+        Generates a deterministic hash for a given string, which can be used as a cheap equality check.
+        Creates the hash using the MD5 algorithm.
+        This is used for identifying the cached system prompts in the GenAI caches.
+
+        Inputs:
+            txt (str): The input string to hash.
+
+        Outputs:
+            str: The unique hash of the input string.
+        """
+        # hashlib is part of the Python standard library, so it is always available and
+        # does not require the 'genai' extra (or any other optional dependency).
+        # MD5 is used here purely as a cheap, deterministic fingerprint (not for
+        # security): it gives each distinct system prompt a stable display name
+        # that can be looked up in the list of active GenAI caches.
+        import hashlib
+
+        return hashlib.md5(txt.encode()).hexdigest()
+
+
 __all__ = [
     "MessageRole",
     "tool_role_conversions",
@@ -1926,4 +2850,5 @@ def generate(
     "AmazonBedrockServerModel",
     "AmazonBedrockModel",
     "ChatMessage",
+    "GenAIModel",
 ]
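
A minimal usage sketch for reviewers trying the patch locally (not part of the diff). The import path smolagents.models.GenAIModel follows from where the class is added above; the project ID, location, and the "service_account.json" path are placeholders for your own Vertex AI setup.

import json

from smolagents import CodeAgent
from smolagents.models import GenAIModel

# Parsed service-account JSON, as expected by the `credentials` parameter.
with open("service_account.json") as f:
    credentials = json.load(f)

model = GenAIModel(
    vertex_project="my-gcp-project",  # placeholder Google Cloud project ID
    vertex_location="europe-west1",
    credentials=credentials,
    model_id="gemini-2.5-flash-lite",
    cache_system_prompts=True,  # cache the agent's system prompt via explicit context caching
    cache_ttl="300s",
)

# The model plugs into an agent like any other smolagents model class.
agent = CodeAgent(tools=[], model=model)
print(agent.run("Briefly summarize what explicit context caching changes about repeated calls."))

With cache_system_prompts=True, the first step of a run caches the (typically long) agent system prompt and later steps reuse it via the cache lookup by display name; prompts under 1024 tokens automatically fall back to uncached generation, as handled in generate() and generate_stream() above.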