diff --git a/backend/windmill-api/src/ai.rs b/backend/windmill-api/src/ai.rs index 1422e9e52db65..e3bda50cc00ce 100644 --- a/backend/windmill-api/src/ai.rs +++ b/backend/windmill-api/src/ai.rs @@ -132,10 +132,11 @@ struct AIOAuthResource { user: Option, } -/// Platform for Anthropic API +/// Platform for AI providers (Anthropic, Google AI) +/// Both Anthropic and Google AI can run on standard endpoints or Google Vertex AI #[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "snake_case")] -enum AnthropicPlatform { +enum AIPlatform { #[default] Standard, GoogleVertexAi, @@ -155,9 +156,10 @@ struct AIStandardResource { aws_access_key_id: Option, #[serde(alias = "awsSecretAccessKey", default, deserialize_with = "empty_string_as_none")] aws_secret_access_key: Option, - /// Platform for Anthropic API (standard or google_vertex_ai) + /// Platform for AI providers (standard or google_vertex_ai) + /// Used by both Anthropic and Google AI providers to indicate Vertex AI usage #[serde(default)] - platform: AnthropicPlatform, + platform: AIPlatform, } #[derive(Deserialize, Debug)] @@ -185,7 +187,8 @@ struct AIRequestConfig { pub aws_access_key_id: Option, #[allow(dead_code)] pub aws_secret_access_key: Option, - pub platform: AnthropicPlatform, + /// Platform for AI providers (standard or google_vertex_ai) + pub platform: AIPlatform, } impl AIRequestConfig { @@ -213,9 +216,7 @@ impl AIRequestConfig { let base_url = if matches!(provider, AIProvider::AWSBedrock) { String::new() } else { - provider - .get_base_url(resource.base_url, db) - .await? + provider.get_base_url(resource.base_url, db).await? }; let api_key = if let Some(api_key) = resource.api_key { Some(get_variable_or_self(api_key, db, w_id).await?) @@ -269,7 +270,7 @@ impl AIRequestConfig { None, None, None, - AnthropicPlatform::Standard, + AIPlatform::Standard, ) } }; @@ -338,12 +339,14 @@ impl AIRequestConfig { let is_azure = provider.is_azure_openai(base_url); let is_anthropic = matches!(provider, AIProvider::Anthropic); - let is_anthropic_vertex = is_anthropic && self.platform == AnthropicPlatform::GoogleVertexAi; + let is_anthropic_vertex = is_anthropic && self.platform == AIPlatform::GoogleVertexAi; let is_anthropic_sdk = headers.get("X-Anthropic-SDK").is_some(); let is_google_ai = matches!(provider, AIProvider::GoogleAI); + let is_google_ai_vertex = is_google_ai && self.platform == AIPlatform::GoogleVertexAi; // GoogleAI uses OpenAI-compatible endpoint in the proxy (for the chat), but not for the ai agent - let base_url = if is_google_ai { + // For Vertex AI, the base_url is already properly configured, no need to add /openai + let base_url = if is_google_ai && !is_google_ai_vertex { format!("{}/openai", base_url) } else { base_url.to_string() @@ -388,7 +391,11 @@ impl AIRequestConfig { if let Some(api_key) = self.api_key { if is_azure { request = request.header("api-key", api_key.clone()) + } else if is_google_ai && !is_google_ai_vertex { + // Standard Google AI uses x-goog-api-key header + request = request.header("x-goog-api-key", api_key.clone()) } else { + // Vertex AI (both Google and Anthropic) and most other providers use Bearer token request = request.header("authorization", format!("Bearer {}", api_key.clone())) } // For standard Anthropic API, also add X-API-Key header (but not for Vertex AI) diff --git a/backend/windmill-worker/src/ai/providers/anthropic.rs b/backend/windmill-worker/src/ai/providers/anthropic.rs index 726a8e27e25ed..b6ce05b42549d 100644 --- a/backend/windmill-worker/src/ai/providers/anthropic.rs +++ b/backend/windmill-worker/src/ai/providers/anthropic.rs @@ -324,16 +324,16 @@ pub struct AnthropicResponse { pub struct AnthropicQueryBuilder { #[allow(dead_code)] provider_kind: AIProvider, - platform: AnthropicPlatform, + platform: AIPlatform, } impl AnthropicQueryBuilder { - pub fn new(provider_kind: AIProvider, platform: AnthropicPlatform) -> Self { + pub fn new(provider_kind: AIProvider, platform: AIPlatform) -> Self { Self { provider_kind, platform } } fn is_vertex(&self) -> bool { - self.platform == AnthropicPlatform::GoogleVertexAi + self.platform == AIPlatform::GoogleVertexAi } async fn build_text_request( diff --git a/backend/windmill-worker/src/ai/providers/google_ai.rs b/backend/windmill-worker/src/ai/providers/google_ai.rs index 62e6afff75dcb..975c13860a00c 100644 --- a/backend/windmill-worker/src/ai/providers/google_ai.rs +++ b/backend/windmill-worker/src/ai/providers/google_ai.rs @@ -6,7 +6,7 @@ use crate::ai::{ image_handler::download_and_encode_s3_image, query_builder::{BuildRequestArgs, ParsedResponse, QueryBuilder, StreamEventProcessor}, sse::{GeminiSSEParser, SSEParser}, - types::*, + types::{AIPlatform, *}, utils::parse_data_url, }; @@ -214,11 +214,17 @@ pub struct GeminiPredictCandidate { // Query Builder Implementation // ============================================================================ -pub struct GoogleAIQueryBuilder; +pub struct GoogleAIQueryBuilder { + platform: AIPlatform, +} impl GoogleAIQueryBuilder { - pub fn new() -> Self { - Self + pub fn new(platform: AIPlatform) -> Self { + Self { platform } + } + + fn is_vertex(&self) -> bool { + self.platform == AIPlatform::GoogleVertexAi } /// Build a text request using the native Gemini API format @@ -661,20 +667,41 @@ impl QueryBuilder for GoogleAIQueryBuilder { } fn get_endpoint(&self, base_url: &str, model: &str, output_type: &OutputType) -> String { - match output_type { - OutputType::Text => { - format!( - "{}/models/{}:streamGenerateContent?alt=sse", - base_url, model - ) + if self.is_vertex() { + // For Vertex AI, the base_url should be in format: + // https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models + // We append the model and the appropriate action + let base_url = base_url.trim_end_matches('/'); + match output_type { + OutputType::Text => { + format!("{}/{}:streamGenerateContent?alt=sse", base_url, model) + } + OutputType::Image => { + let url_suffix = if model.contains("imagen") { + "predict" + } else { + "generateContent" + }; + format!("{}/{}:{}", base_url, model, url_suffix) + } } - OutputType::Image => { - let url_suffix = if model.contains("imagen") { - "predict" - } else { - "generateContent" - }; - format!("{}/models/{}:{}", base_url, model, url_suffix) + } else { + // Standard Google AI endpoint + match output_type { + OutputType::Text => { + format!( + "{}/models/{}:streamGenerateContent?alt=sse", + base_url, model + ) + } + OutputType::Image => { + let url_suffix = if model.contains("imagen") { + "predict" + } else { + "generateContent" + }; + format!("{}/models/{}:{}", base_url, model, url_suffix) + } } } } @@ -685,7 +712,13 @@ impl QueryBuilder for GoogleAIQueryBuilder { _base_url: &str, _output_type: &OutputType, ) -> Vec<(&'static str, String)> { - // Native Gemini API always uses x-goog-api-key - vec![("x-goog-api-key", api_key.to_string())] + if self.is_vertex() { + // For Vertex AI, use Bearer token authentication + // The api_key should be an OAuth2 access token (from gcloud auth print-access-token) + vec![("Authorization", format!("Bearer {}", api_key))] + } else { + // Native Gemini API uses x-goog-api-key + vec![("x-goog-api-key", api_key.to_string())] + } } } diff --git a/backend/windmill-worker/src/ai/query_builder.rs b/backend/windmill-worker/src/ai/query_builder.rs index 2bbf93c69af32..c17a2be99e209 100644 --- a/backend/windmill-worker/src/ai/query_builder.rs +++ b/backend/windmill-worker/src/ai/query_builder.rs @@ -113,8 +113,10 @@ pub fn create_query_builder(provider: &ProviderWithResource) -> Box Box::new(GoogleAIQueryBuilder::new()), + // Google AI uses the Gemini API (with platform-specific handling for Vertex AI) + AIProvider::GoogleAI => { + Box::new(GoogleAIQueryBuilder::new(provider.get_platform().clone())) + } // OpenAI use the Responses API AIProvider::OpenAI => Box::new(OpenAIQueryBuilder::new(provider.kind.clone())), // Anthropic uses its own API format (with platform-specific handling for Vertex AI) diff --git a/backend/windmill-worker/src/ai/types.rs b/backend/windmill-worker/src/ai/types.rs index 0955e7c7a11ef..9952c5fd8458d 100644 --- a/backend/windmill-worker/src/ai/types.rs +++ b/backend/windmill-worker/src/ai/types.rs @@ -156,9 +156,11 @@ impl From for AIAgentArgs { } } +/// Platform for AI providers (Anthropic, Google AI) +/// Both Anthropic and Google AI can run on standard endpoints or Google Vertex AI #[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "snake_case")] -pub enum AnthropicPlatform { +pub enum AIPlatform { #[default] Standard, GoogleVertexAi, @@ -179,9 +181,10 @@ pub struct ProviderResource { #[allow(dead_code)] #[serde(alias = "awsSecretAccessKey", default, deserialize_with = "empty_string_as_none")] pub aws_secret_access_key: Option, - /// Platform for Anthropic API (standard or google_vertex_ai) + /// Platform for AI providers (standard or google_vertex_ai) + /// Used by both Anthropic and Google AI providers to indicate Vertex AI usage #[serde(default)] - pub platform: AnthropicPlatform, + pub platform: AIPlatform, } #[derive(Deserialize, Debug)] @@ -224,7 +227,7 @@ impl ProviderWithResource { self.resource.aws_secret_access_key.as_deref() } - pub fn get_platform(&self) -> &AnthropicPlatform { + pub fn get_platform(&self) -> &AIPlatform { &self.resource.platform } }