Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions backend/windmill-api/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ struct AIOAuthResource {
user: Option<String>,
}

/// Platform for Anthropic API
/// Platform for AI providers (Anthropic, Google AI)
/// Both Anthropic and Google AI can run on standard endpoints or Google Vertex AI
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
enum AnthropicPlatform {
enum AIPlatform {
#[default]
Standard,
GoogleVertexAi,
Expand All @@ -155,9 +156,10 @@ struct AIStandardResource {
aws_access_key_id: Option<String>,
#[serde(alias = "awsSecretAccessKey", default, deserialize_with = "empty_string_as_none")]
aws_secret_access_key: Option<String>,
/// Platform for Anthropic API (standard or google_vertex_ai)
/// Platform for AI providers (standard or google_vertex_ai)
/// Used by both Anthropic and Google AI providers to indicate Vertex AI usage
#[serde(default)]
platform: AnthropicPlatform,
platform: AIPlatform,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -185,7 +187,8 @@ struct AIRequestConfig {
pub aws_access_key_id: Option<String>,
#[allow(dead_code)]
pub aws_secret_access_key: Option<String>,
pub platform: AnthropicPlatform,
/// Platform for AI providers (standard or google_vertex_ai)
pub platform: AIPlatform,
}

impl AIRequestConfig {
Expand Down Expand Up @@ -213,9 +216,7 @@ impl AIRequestConfig {
let base_url = if matches!(provider, AIProvider::AWSBedrock) {
String::new()
} else {
provider
.get_base_url(resource.base_url, db)
.await?
provider.get_base_url(resource.base_url, db).await?
};
let api_key = if let Some(api_key) = resource.api_key {
Some(get_variable_or_self(api_key, db, w_id).await?)
Expand Down Expand Up @@ -269,7 +270,7 @@ impl AIRequestConfig {
None,
None,
None,
AnthropicPlatform::Standard,
AIPlatform::Standard,
)
}
};
Expand Down Expand Up @@ -338,12 +339,14 @@ impl AIRequestConfig {

let is_azure = provider.is_azure_openai(base_url);
let is_anthropic = matches!(provider, AIProvider::Anthropic);
let is_anthropic_vertex = is_anthropic && self.platform == AnthropicPlatform::GoogleVertexAi;
let is_anthropic_vertex = is_anthropic && self.platform == AIPlatform::GoogleVertexAi;
let is_anthropic_sdk = headers.get("X-Anthropic-SDK").is_some();
let is_google_ai = matches!(provider, AIProvider::GoogleAI);
let is_google_ai_vertex = is_google_ai && self.platform == AIPlatform::GoogleVertexAi;

// GoogleAI uses the OpenAI-compatible endpoint in the proxy (for chat), but not for the AI agent
let base_url = if is_google_ai {
// For Vertex AI, the base_url is already properly configured, no need to add /openai
let base_url = if is_google_ai && !is_google_ai_vertex {
format!("{}/openai", base_url)
} else {
base_url.to_string()
Expand Down Expand Up @@ -388,7 +391,11 @@ impl AIRequestConfig {
if let Some(api_key) = self.api_key {
if is_azure {
request = request.header("api-key", api_key.clone())
} else if is_google_ai && !is_google_ai_vertex {
// Standard Google AI uses x-goog-api-key header
request = request.header("x-goog-api-key", api_key.clone())
} else {
// Vertex AI (both Google and Anthropic) and most other providers use Bearer token
request = request.header("authorization", format!("Bearer {}", api_key.clone()))
}
// For standard Anthropic API, also add X-API-Key header (but not for Vertex AI)
Expand Down
6 changes: 3 additions & 3 deletions backend/windmill-worker/src/ai/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,16 @@ pub struct AnthropicResponse {
pub struct AnthropicQueryBuilder {
#[allow(dead_code)]
provider_kind: AIProvider,
platform: AnthropicPlatform,
platform: AIPlatform,
}

impl AnthropicQueryBuilder {
pub fn new(provider_kind: AIProvider, platform: AnthropicPlatform) -> Self {
pub fn new(provider_kind: AIProvider, platform: AIPlatform) -> Self {
Self { provider_kind, platform }
}

fn is_vertex(&self) -> bool {
self.platform == AnthropicPlatform::GoogleVertexAi
self.platform == AIPlatform::GoogleVertexAi
}

async fn build_text_request(
Expand Down
71 changes: 52 additions & 19 deletions backend/windmill-worker/src/ai/providers/google_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::ai::{
image_handler::download_and_encode_s3_image,
query_builder::{BuildRequestArgs, ParsedResponse, QueryBuilder, StreamEventProcessor},
sse::{GeminiSSEParser, SSEParser},
types::*,
types::{AIPlatform, *},
utils::parse_data_url,
};

Expand Down Expand Up @@ -214,11 +214,17 @@ pub struct GeminiPredictCandidate {
// Query Builder Implementation
// ============================================================================

pub struct GoogleAIQueryBuilder;
pub struct GoogleAIQueryBuilder {
platform: AIPlatform,
}

impl GoogleAIQueryBuilder {
pub fn new() -> Self {
Self
pub fn new(platform: AIPlatform) -> Self {
Self { platform }
}

/// Returns `true` when this builder's configured platform is
/// `AIPlatform::GoogleVertexAi`.
fn is_vertex(&self) -> bool {
    matches!(self.platform, AIPlatform::GoogleVertexAi)
}

/// Build a text request using the native Gemini API format
Expand Down Expand Up @@ -661,20 +667,41 @@ impl QueryBuilder for GoogleAIQueryBuilder {
}

fn get_endpoint(&self, base_url: &str, model: &str, output_type: &OutputType) -> String {
match output_type {
OutputType::Text => {
format!(
"{}/models/{}:streamGenerateContent?alt=sse",
base_url, model
)
if self.is_vertex() {
// For Vertex AI, the base_url should be in the format:
// https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models
// We append the model and the appropriate action
let base_url = base_url.trim_end_matches('/');
match output_type {
OutputType::Text => {
format!("{}/{}:streamGenerateContent?alt=sse", base_url, model)
}
OutputType::Image => {
let url_suffix = if model.contains("imagen") {
"predict"
} else {
"generateContent"
};
format!("{}/{}:{}", base_url, model, url_suffix)
}
}
OutputType::Image => {
let url_suffix = if model.contains("imagen") {
"predict"
} else {
"generateContent"
};
format!("{}/models/{}:{}", base_url, model, url_suffix)
} else {
// Standard Google AI endpoint
match output_type {
OutputType::Text => {
format!(
"{}/models/{}:streamGenerateContent?alt=sse",
base_url, model
)
}
OutputType::Image => {
let url_suffix = if model.contains("imagen") {
"predict"
} else {
"generateContent"
};
format!("{}/models/{}:{}", base_url, model, url_suffix)
}
}
}
}
Expand All @@ -685,7 +712,13 @@ impl QueryBuilder for GoogleAIQueryBuilder {
_base_url: &str,
_output_type: &OutputType,
) -> Vec<(&'static str, String)> {
// Native Gemini API always uses x-goog-api-key
vec![("x-goog-api-key", api_key.to_string())]
if self.is_vertex() {
// For Vertex AI, use Bearer token authentication
// The api_key should be an OAuth2 access token (from gcloud auth print-access-token)
vec![("Authorization", format!("Bearer {}", api_key))]
} else {
// Native Gemini API uses x-goog-api-key
vec![("x-goog-api-key", api_key.to_string())]
}
}
}
6 changes: 4 additions & 2 deletions backend/windmill-worker/src/ai/query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ pub fn create_query_builder(provider: &ProviderWithResource) -> Box<dyn QueryBui
use windmill_common::ai_providers::AIProvider;

match provider.kind {
// Google AI uses the Gemini API
AIProvider::GoogleAI => Box::new(GoogleAIQueryBuilder::new()),
// Google AI uses the Gemini API (with platform-specific handling for Vertex AI)
AIProvider::GoogleAI => {
Box::new(GoogleAIQueryBuilder::new(provider.get_platform().clone()))
}
// OpenAI uses the Responses API
AIProvider::OpenAI => Box::new(OpenAIQueryBuilder::new(provider.kind.clone())),
// Anthropic uses its own API format (with platform-specific handling for Vertex AI)
Expand Down
11 changes: 7 additions & 4 deletions backend/windmill-worker/src/ai/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,11 @@ impl From<AIAgentArgsRaw> for AIAgentArgs {
}
}

/// Platform for AI providers (Anthropic, Google AI)
/// Both Anthropic and Google AI can run on standard endpoints or Google Vertex AI
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AnthropicPlatform {
pub enum AIPlatform {
#[default]
Standard,
GoogleVertexAi,
Expand All @@ -179,9 +181,10 @@ pub struct ProviderResource {
#[allow(dead_code)]
#[serde(alias = "awsSecretAccessKey", default, deserialize_with = "empty_string_as_none")]
pub aws_secret_access_key: Option<String>,
/// Platform for Anthropic API (standard or google_vertex_ai)
/// Platform for AI providers (standard or google_vertex_ai)
/// Used by both Anthropic and Google AI providers to indicate Vertex AI usage
#[serde(default)]
pub platform: AnthropicPlatform,
pub platform: AIPlatform,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -224,7 +227,7 @@ impl ProviderWithResource {
self.resource.aws_secret_access_key.as_deref()
}

pub fn get_platform(&self) -> &AnthropicPlatform {
pub fn get_platform(&self) -> &AIPlatform {
&self.resource.platform
}
}
Expand Down
Loading