diff --git a/backend/migrations/20251105100125_legacy_sql_result_flag.down.sql b/backend/migrations/20251105100125_legacy_sql_result_flag.down.sql new file mode 100644 index 0000000000000..d2f607c5b8bd6 --- /dev/null +++ b/backend/migrations/20251105100125_legacy_sql_result_flag.down.sql @@ -0,0 +1 @@ +-- Add down migration script here diff --git a/backend/migrations/20251105100125_legacy_sql_result_flag.up.sql b/backend/migrations/20251105100125_legacy_sql_result_flag.up.sql new file mode 100644 index 0000000000000..c0fcfb482755e --- /dev/null +++ b/backend/migrations/20251105100125_legacy_sql_result_flag.up.sql @@ -0,0 +1,89 @@ +CREATE OR REPLACE FUNCTION update_string(s text) +RETURNS text +LANGUAGE plpgsql +AS $$ +DECLARE + prefix TEXT := '-- https://www.windmill.dev/docs/getting_started/scripts_quickstart/sql#result-collection +-- result_collection=legacy + +'; +BEGIN + RETURN prefix || s; +END; +$$; + +CREATE OR REPLACE FUNCTION update_all_modules(obj jsonb) +RETURNS jsonb +LANGUAGE plpgsql +AS $$ +DECLARE + result jsonb; + k text; + v jsonb; +BEGIN + IF jsonb_typeof(obj) = 'object' THEN + result := '{}'::jsonb; + + FOR k, v IN SELECT * FROM jsonb_each(obj) + LOOP + IF k = 'content' and jsonb_typeof(v) = 'string' AND obj->>'language' IN ('bigquery', 'postgresql', 'duckdb', 'mssql', 'oracledb', 'snowflake', 'mysql') THEN + result := result || jsonb_build_object('content', update_string(obj->>'content')); + ELSE + result := result || jsonb_build_object(k, update_all_modules(v)); + END IF; + END LOOP; + + RETURN result; + ELSIF jsonb_typeof(obj) = 'array' AND jsonb_array_length(obj) > 0 THEN + SELECT jsonb_agg(update_all_modules(elem)) + INTO result + FROM jsonb_array_elements(obj) elem; + + RETURN result; + ELSE + RETURN obj; + END IF; +END; +$$; + +-- Run on a flow_version_lite jsonb value. Returns an array of flow_node ids whose languages are SQL. 
+CREATE OR REPLACE FUNCTION find_sql_flow_nodes_ids(obj jsonb) +RETURNS BIGINT[] +LANGUAGE plpgsql +AS $$ +DECLARE + result BIGINT[] := '{}'; + k text; + v jsonb; +BEGIN + IF jsonb_typeof(obj) = 'object' THEN + IF obj->>'language' IN ('bigquery', 'postgresql', 'duckdb', 'mssql', 'oracledb', 'snowflake', 'mysql') AND jsonb_typeof(obj->'id') = 'number' THEN + result := result || (obj->>'id')::BIGINT; + END IF; + + FOR k, v IN SELECT * FROM jsonb_each(obj) + LOOP + result := result || find_sql_flow_nodes_ids(v); + END LOOP; + ELSIF jsonb_typeof(obj) = 'array' AND jsonb_array_length(obj) > 0 THEN + SELECT array_agg(result_ids) + INTO result + FROM jsonb_array_elements(obj) elem, unnest(find_sql_flow_nodes_ids(elem)) as result_ids; + END IF; + RETURN result; +END; +$$; + + +UPDATE app_version SET value = update_all_modules(value::jsonb)::json; +UPDATE draft SET value = update_all_modules(value::jsonb)::json; +UPDATE flow SET value = update_all_modules(value); +UPDATE flow_version SET value = update_all_modules(value); +UPDATE flow_node SET code = update_string(code) WHERE id IN ( + SELECT v FROM flow_version_lite, unnest(find_sql_flow_nodes_ids(value)) as v +); +UPDATE script SET content = update_string(content) WHERE language IN ('bigquery', 'postgresql', 'duckdb', 'mssql', 'oracledb', 'snowflake', 'mysql'); + +DROP FUNCTION IF EXISTS update_all_modules(jsonb); +DROP FUNCTION IF EXISTS update_string(text); +DROP FUNCTION IF EXISTS find_sql_flow_nodes_ids(jsonb); \ No newline at end of file diff --git a/backend/windmill-common/src/worker.rs b/backend/windmill-common/src/worker.rs index 29d5488be8d60..37b7d5670dca0 100644 --- a/backend/windmill-common/src/worker.rs +++ b/backend/windmill-common/src/worker.rs @@ -663,13 +663,11 @@ fn parse_file(path: &str) -> Option { .flatten() } -#[derive(Copy, Clone)] #[annotations("#")] pub struct RubyAnnotations { pub verbose: bool, } -#[derive(Copy, Clone)] #[annotations("#")] pub struct PythonAnnotations { pub no_cache: bool, @@ -682,7 +680,6 @@ pub struct PythonAnnotations { pub py313: bool, } -#[derive(Copy, Clone)] #[annotations("//")] pub struct GoAnnotations { pub go1_22_compat: bool, @@ -698,13 +695,174 @@ pub struct TypeScriptAnnotations { #[annotations("--")] pub struct SqlAnnotations { - pub return_last_result: bool, + pub return_last_result: bool, // deprecated, use result_collection instead + pub result_collection: SqlResultCollectionStrategy, } #[annotations("#")] pub struct BashAnnotations { pub docker: bool, } + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum SqlResultCollectionStrategy { + LastStatementAllRows, + LastStatementFirstRow, + LastStatementAllRowsScalar, + LastStatementFirstRowScalar, + AllStatementsAllRows, + AllStatementsFirstRow, + AllStatementsAllRowsScalar, + AllStatementsFirstRowScalar, + Legacy, +} + +impl SqlResultCollectionStrategy { + pub fn parse(s: &str) -> Self { + use SqlResultCollectionStrategy::*; + match s { + "last_statement_all_rows" => LastStatementAllRows, + "last_statement_first_row" => LastStatementFirstRow, + "last_statement_all_rows_scalar" => LastStatementAllRowsScalar, + "last_statement_first_row_scalar" => LastStatementFirstRowScalar, + "all_statements_all_rows" => AllStatementsAllRows, + "all_statements_first_row" => AllStatementsFirstRow, + "all_statements_all_rows_scalar" => AllStatementsAllRowsScalar, + "all_statements_first_row_scalar" => AllStatementsFirstRowScalar, + "legacy" => Legacy, + _ => SqlResultCollectionStrategy::default(), + } + } + + pub fn collect_last_statement_only(&self, 
query_count: usize) -> bool { + use SqlResultCollectionStrategy::*; + match self { + LastStatementAllRows + | LastStatementFirstRow + | LastStatementFirstRowScalar + | LastStatementAllRowsScalar => true, + Legacy => query_count == 1, + _ => false, + } + } + pub fn collect_first_row_only(&self) -> bool { + use SqlResultCollectionStrategy::*; + match self { + LastStatementFirstRow + | LastStatementFirstRowScalar + | AllStatementsFirstRow + | AllStatementsFirstRowScalar => true, + _ => false, + } + } + pub fn collect_scalar(&self) -> bool { + use SqlResultCollectionStrategy::*; + match self { + LastStatementFirstRowScalar + | AllStatementsFirstRowScalar + | LastStatementAllRowsScalar + | AllStatementsAllRowsScalar => true, + _ => false, + } + } + + // This function transforms the shape (e.g Row[][] -> Row) + // It is the responsibility of the executor to avoid fetching unnecessary statements/rows + pub fn collect( + &self, + values: Vec>>, + ) -> error::Result> { + let null = || serde_json::value::RawValue::from_string("null".to_string()).unwrap(); + + let values = if self.collect_last_statement_only(values.len()) { + values.into_iter().rev().take(1).collect() + } else { + values + }; + + let values = if self.collect_first_row_only() { + values + .into_iter() + .map(|rows| rows.into_iter().take(1).collect()) + .collect() + } else { + values + }; + + let values = if self.collect_scalar() { + values + .into_iter() + .map(|rows| { + rows.into_iter() + .map(|row| { + // Take the first value in the object + let record = + match serde_json::from_str(row.get()) { + Ok(serde_json::Value::Object(record)) => record, + Ok(_) => return Err(error::Error::ExecutionErr( + "Could not collect sql scalar value from non-object row" + .to_string(), + )), + Err(e) => { + return Err(error::Error::ExecutionErr(format!( + "Could not collect sql scalar value (failed to parse row): {}", + e + ))) + } + }; + let Some((_, value)) = record.iter().next() else { + return Err(error::Error::ExecutionErr( + "Could not collect sql scalar value from empty row".to_string(), + )); + }; + Ok(serde_json::value::RawValue::from_string( + serde_json::to_string(value).map_err(to_anyhow)?, + ) + .map_err(to_anyhow)?) + }) + .collect::>>() + }) + .collect::>>()? + } else { + values + }; + + match ( + self.collect_last_statement_only(values.len()), + self.collect_first_row_only(), + ) { + (true, true) => { + match values + .into_iter() + .last() + .map(|rows| rows.into_iter().next()) + { + Some(Some(row)) => Ok(row.clone()), + _ => Ok(null()), + } + } + (true, false) => match values.into_iter().last() { + Some(rows) => Ok(to_raw_value(&rows)), + None => Ok(null()), + }, + (false, true) => { + let values = values + .into_iter() + .map(|rows| rows.into_iter().next().unwrap_or_else(null)) + .collect::>(); + Ok(to_raw_value(&values)) + } + (false, false) => Ok(to_raw_value(&values)), + } + } +} + +impl Default for SqlResultCollectionStrategy { + fn default() -> Self { + SqlResultCollectionStrategy::LastStatementAllRows + } +} + /// length = 5 /// value = "foo" /// output = "foo " diff --git a/backend/windmill-duckdb-ffi-internal/README_DEV.md b/backend/windmill-duckdb-ffi-internal/README_DEV.md index 97dd8a5eb03cb..0c34ebc069fac 100644 --- a/backend/windmill-duckdb-ffi-internal/README_DEV.md +++ b/backend/windmill-duckdb-ffi-internal/README_DEV.md @@ -12,4 +12,4 @@ INSERT INTO t VALUES (NULL); causes `Constraint Error: NOT NULL constraint failed: t.x` normally, but here we see `Unknown exception in ExecutorTask::Execute`. 
This opaque errors comes directly from the C++ DuckDB library : https://github.com/duckdb/duckdb/blob/f99fed1e0b16a842573f9dad529f6c170a004f6e/src/parallel/executor_task.cpp#L58 -To solve this, we compile duckdb separately from the main backend crate and call it with FFI +To solve this, we compile duckdb separately from the main backend crate and call it with FFI. It has to be loaded dynamically, if it is loaded statically it will still share lib c++ with deno_core and cause issues. diff --git a/backend/windmill-duckdb-ffi-internal/src/lib.rs b/backend/windmill-duckdb-ffi-internal/src/lib.rs index 29112baff946d..501b5c5dee84a 100644 --- a/backend/windmill-duckdb-ffi-internal/src/lib.rs +++ b/backend/windmill-duckdb-ffi-internal/src/lib.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - ffi::{CStr, CString, c_char}, + ffi::{CStr, CString, c_char, c_uint}, ptr::null_mut, }; @@ -27,6 +27,14 @@ pub extern "C" fn free_cstr(string: *mut c_char) -> () { } } +#[unsafe(no_mangle)] +pub extern "C" fn get_version() -> c_uint { + // Increment when making breaking changes to the FFI interface. + // The windmill worker will check that the version matches or else refuse to call + // the FFI functions to avoid undefined behavior. + return 1; +} + #[unsafe(no_mangle)] pub extern "C" fn run_duckdb_ffi( query_block_list: *const *const c_char, @@ -36,6 +44,8 @@ pub extern "C" fn run_duckdb_ffi( base_internal_url: *const c_char, w_id: *const c_char, column_order_ptr: *mut *mut c_char, + collect_last_only: bool, + collect_first_row_only: bool, ) -> *mut c_char { let (r, column_order) = match convert_args( query_block_list, @@ -54,6 +64,8 @@ pub extern "C" fn run_duckdb_ffi( token, base_internal_url, w_id, + collect_last_only, + collect_first_row_only, ) }, ) { @@ -153,6 +165,8 @@ fn run_duckdb_internal<'a>( token: &str, base_internal_url: &str, w_id: &str, + collect_last_only: bool, + collect_first_row_only: bool, ) -> Result<(String, Option>), String> { let conn = duckdb::Connection::open_in_memory().map_err(|e| e.to_string())?; @@ -189,23 +203,23 @@ fn run_duckdb_internal<'a>( )) .map_err(|e| format!("Error setting up S3 secret: {}", e.to_string()))?; - let mut result: Option> = None; + let mut results: Vec>> = vec![]; let mut column_order = None; for (query_block_index, query_block) in query_block_list.enumerate() { - result = Some( - do_duckdb_inner( - &conn, - query_block, - &job_args, - query_block_index != query_block_list_count - 1, - &mut column_order, - ) - .map_err(|e| e.to_string())?, - ); + let result = do_duckdb_inner( + &conn, + query_block, + &job_args, + collect_last_only && query_block_index != query_block_list_count - 1, + collect_first_row_only, + &mut column_order, + ) + .map_err(|e| e.to_string())?; + results.push(result); } - let result = result.unwrap_or_else(|| RawValue::from_string("[]".to_string()).unwrap()); - Ok((result.get().to_string(), column_order)) + let results = serde_json::value::to_raw_value(&results).map_err(|e| e.to_string())?; + Ok((results.get().to_string(), column_order)) } fn do_duckdb_inner( @@ -213,8 +227,9 @@ fn do_duckdb_inner( query: &str, job_args: &HashMap, skip_collect: bool, + collect_first_row_only: bool, column_order: &mut Option>, -) -> Result, String> { +) -> Result>, String> { let mut rows_vec = vec![]; let (query, job_args) = interpolate_named_args(query, &job_args); @@ -226,7 +241,7 @@ fn do_duckdb_inner( .map_err(|e| e.to_string())?; if skip_collect { - return Ok(RawValue::from_string("[]".to_string()).unwrap()); + return Ok(vec![]); } // Statement 
needs to be stepped at least once or stmt.column_names() will panic let mut column_names = None; @@ -266,11 +281,14 @@ fn do_duckdb_inner( return Err(e.to_string()); } } + if collect_first_row_only { + break; + } } *column_order = column_names; - serde_json::value::to_raw_value(&rows_vec).map_err(|e| e.to_string()) + Ok(rows_vec) } // duckdb-rs does not support named parameters, diff --git a/backend/windmill-macros/src/lib.rs b/backend/windmill-macros/src/lib.rs index 25ff8cbaa2141..778ac72b0f415 100644 --- a/backend/windmill-macros/src/lib.rs +++ b/backend/windmill-macros/src/lib.rs @@ -1,93 +1,83 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Ident, ItemStruct, Lit}; +use syn::{parse_macro_input, ItemStruct, Lit, Type}; + +fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(segment) = type_path.path.segments.last() { + return segment.ident == "bool"; + } + } + false +} #[proc_macro_attribute] pub fn annotations(attr: TokenStream, item: TokenStream) -> TokenStream { let input = parse_macro_input!(item as ItemStruct); let name = input.ident.clone(); - let fields = input - .fields - .iter() - .map(|f| f.ident.clone().unwrap()) - .collect::>(); - // Match on the literal to extract the string value - let comm_lit = match parse_macro_input!(attr as Lit) { - Lit::Str(lit_str) => lit_str.value(), // This will give "#" without quotes - _ => panic!("Expected a string literal"), - }; + // Separate fields by type + let mut bool_fields = Vec::new(); + let mut custom_fields = Vec::new(); - // Generate regex - let mut reg = format!("^{}|", &comm_lit); - { - for field in fields.iter() { - reg.push_str(&(field.to_string())); - reg.push_str("\\b"); - } + for field in input.fields.iter() { + let field_name = field.ident.clone().unwrap(); + let field_type = &field.ty; - reg.push_str(r#"|\w+"#); + if is_bool_type(field_type) { + bool_fields.push(field_name); + } else { + custom_fields.push((field_name, field_type)); + } } - // Example of generated regex: - // ^# - // |ann1\b|ann2\b|ann3\b|ann4\b - // |\w+ + let (custom_field_names, custom_field_types): (Vec<_>, Vec<_>) = + custom_fields.into_iter().unzip(); + + // Parse comment literal + let comm_lit = match parse_macro_input!(attr as Lit) { + Lit::Str(lit_str) => lit_str.value(), + _ => panic!("Expected a string literal"), + }; TokenStream::from(quote! { - #[derive(Default, Debug)] + #[derive(Default, Debug, Copy, Clone)] #input - impl std::ops::BitOrAssign for #name{ - fn bitor_assign(&mut self, rhs: Self) { - // Unfold fields - // Read more: https://docs.rs/quote/latest/quote/macro.quote.html#interpolation - #( self.#fields |= rhs.#fields; )* - } - } - impl #name { /// Autogenerated by windmill-macros - pub fn parse(inner_content: &str) -> Self{ + pub fn parse(code: &str) -> Self { let mut res = Self::default(); - lazy_static::lazy_static! 
{ - static ref RE: regex::Regex = regex::Regex::new(#reg).unwrap(); - } - // Create lines stream - let mut lines = inner_content.lines(); - 'outer: while let Some(line) = lines.next() { - // If comment sign(s) on the right place - let mut comms = false; - // New instance - // We will apply it if in line only annotations - let mut new = Self::default(); - - 'inner: for (i, mat) in RE.find_iter(line).enumerate() { - match mat.as_str(){ - #comm_lit if i == 0 => { - comms = true; - continue 'inner; - }, + let mut lines = code.lines(); - // Will expand into something like: - // "ann1" => new.ann1 = true, - // "ann2" => new.ann2 = true, - // "ann3" => new.ann3 = true, - #( stringify!(#fields) => new.#fields = true, )* - // Non annotations - _ => continue 'outer, - }; + while let Some(line) = lines.next() { + if !line.starts_with(#comm_lit) { + break; } + let line = line[#comm_lit.len()..].trim(); + let (key, value) = line.split_once('=').unwrap_or((line, "")); - if !comms { - // We dont want to continue if line does not start with # - return res; - } - // Apply changes - res |= new; + match key { + #( + stringify!(#custom_field_names) => { + if value.is_empty() { + continue; + } + res.#custom_field_names = #custom_field_types::parse(value); + } + )* + #( + stringify!(#bool_fields) => { + res.#bool_fields = true; + } + )* + _ => { + // Unknown key=value annotation + continue; + } + } } - res } } diff --git a/backend/windmill-macros/tests/annotations.rs b/backend/windmill-macros/tests/annotations.rs index fd430dfaeb755..d906b728879ca 100644 --- a/backend/windmill-macros/tests/annotations.rs +++ b/backend/windmill-macros/tests/annotations.rs @@ -24,7 +24,7 @@ mod annotations_tests { } #[annotations("#")] - #[derive(Eq, PartialEq, Copy, Clone)] + #[derive(Eq, PartialEq)] pub struct Annotations { pub ann1: bool, pub ann2: bool, @@ -34,7 +34,7 @@ mod annotations_tests { } #[annotations("//")] - #[derive(Eq, PartialEq, Copy, Clone)] + #[derive(Eq, PartialEq)] pub struct SlashedAnnotations { pub ann1: bool, pub ann2: bool, @@ -43,7 +43,7 @@ mod annotations_tests { } #[annotations("--")] - #[derive(Eq, PartialEq, Copy, Clone)] + #[derive(Eq, PartialEq)] pub struct MinusedAnnotations { pub ann1: bool, pub ann2: bool, diff --git a/backend/windmill-worker/src/bigquery_executor.rs b/backend/windmill-worker/src/bigquery_executor.rs index cf937368504f3..c58b73f55165d 100644 --- a/backend/windmill-worker/src/bigquery_executor.rs +++ b/backend/windmill-worker/src/bigquery_executor.rs @@ -7,7 +7,7 @@ use serde_json::{json, value::RawValue, Value}; use windmill_common::client::AuthedClient; use windmill_common::error::to_anyhow; use windmill_common::s3_helpers::convert_json_line_stream; -use windmill_common::worker::Connection; +use windmill_common::worker::{Connection, SqlResultCollectionStrategy}; use windmill_common::{error::Error, worker::to_raw_value}; use windmill_parser_sql::{ parse_bigquery_sig, parse_db_resource, parse_s3_mode, parse_sql_blocks, @@ -86,9 +86,11 @@ fn do_bigquery_inner<'a>( timeout_ms: u64, column_order: Option<&'a mut Option>>, skip_collect: bool, + first_row_only: bool, http_client: &'a Client, s3: Option, -) -> windmill_common::error::Result>>> { +) -> windmill_common::error::Result>>>> +{ let param_names = parse_sql_statement_named_params(query, '@'); let statement_values = all_statement_values @@ -113,7 +115,7 @@ fn do_bigquery_inner<'a>( .json(&json!({ "query": query, "useLegacySql": false, - "maxResults": 10000, + "maxResults": if first_row_only { 1 } else { 10000 }, "timeoutMs": 
timeout_ms, "queryParameters": statement_values, })) @@ -126,7 +128,7 @@ fn do_bigquery_inner<'a>( match response.error_for_status_ref() { Ok(_) => { if skip_collect { - return Ok(to_raw_value(&Value::Array(vec![]))); + return Ok(vec![]); } else { let result = response.json::().await.map_err(|e| { Error::ExecutionErr(format!( @@ -205,10 +207,10 @@ fn do_bigquery_inner<'a>( convert_json_line_stream(rows_stream.boxed(), s3.format).await?; s3.upload(stream.boxed()).await?; - return Ok(to_raw_value(&s3.to_return_s3_obj())); + return Ok(vec![to_raw_value(&s3.to_return_s3_obj())]); } - Ok(to_raw_value(&rows)) + Ok(rows.iter().map(to_raw_value).collect::>()) } } Err(e) => match response.json::().await { @@ -341,6 +343,11 @@ pub async fn do_bigquery( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; let service_account = CustomServiceAccount::from_json(&database) .map_err(|e| Error::ExecutionErr(e.to_string()))?; @@ -425,52 +432,34 @@ pub async fn do_bigquery( statement_values.insert(arg_n, bigquery_v); } - let result_f = if queries.len() > 1 { - let futures = queries - .iter() - .enumerate() - .map(|(i, x)| { - do_bigquery_inner( - x, - &statement_values, - &project_id, - token.as_str(), - timeout_ms, - None, - annotations.return_last_result && i < queries.len() - 1, - &http_client, - s3.clone(), - ) - }) - .collect::>>()?; - - let f = async { - let mut res: Vec> = vec![]; - - for fut in futures { - let r = fut.await?; - res.push(r); - } - if annotations.return_last_result && res.len() > 0 { - Ok(res.pop().unwrap()) - } else { - Ok(to_raw_value(&res)) - } - }; + let result_f = async move { + let mut results = vec![]; + for (i, q) in queries.iter().enumerate() { + let result = do_bigquery_inner( + q, + &statement_values, + &project_id, + token.as_str(), + timeout_ms, + if i == queries.len() - 1 + && collection_strategy.collect_last_statement_only(queries.len()) + && !collection_strategy.collect_scalar() + { + Some(column_order) + } else { + None + }, + collection_strategy.collect_last_statement_only(queries.len()) + && i < queries.len() - 1, + collection_strategy.collect_first_row_only(), + &http_client, + s3.clone(), + )? + .await?; + results.push(result); + } - f.boxed() - } else { - do_bigquery_inner( - query, - &statement_values, - &project_id, - token.as_str(), - timeout_ms, - Some(column_order), - false, - &http_client, - s3, - )? 
+ collection_strategy.collect(results) }; let r = run_future_with_polling_update_job_poller( diff --git a/backend/windmill-worker/src/duckdb_executor.rs b/backend/windmill-worker/src/duckdb_executor.rs index e3408ffca724d..d0e65db6992f4 100644 --- a/backend/windmill-worker/src/duckdb_executor.rs +++ b/backend/windmill-worker/src/duckdb_executor.rs @@ -1,6 +1,6 @@ use std::cell::RefCell; use std::env; -use std::ffi::{c_char, CStr, CString}; +use std::ffi::{c_char, c_uint, CStr, CString}; use std::ptr::NonNull; use std::sync::{Arc, Mutex}; @@ -12,7 +12,7 @@ use uuid::Uuid; use windmill_common::error::{to_anyhow, Error, Result}; use windmill_common::s3_helpers::{S3Object, S3_PROXY_LAST_ERRORS_CACHE}; use windmill_common::utils::sanitize_string_from_password; -use windmill_common::worker::Connection; +use windmill_common::worker::{Connection, SqlResultCollectionStrategy}; use windmill_common::workspaces::{get_ducklake_from_db_unchecked, DucklakeCatalogResourceType}; use windmill_parser_sql::{parse_duckdb_sig, parse_sql_blocks}; use windmill_queue::{CanceledBy, MiniPulledJob}; @@ -39,6 +39,20 @@ pub async fn do_duckdb( occupancy_metrics: &mut OccupancyMetrics, parent_runnable_path: Option, ) -> Result> { + let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = + if annotations.result_collection == SqlResultCollectionStrategy::Legacy { + // Before result_collection was introduced, duckdb ignored all statements results except the last one + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; + if annotations.return_last_result { + return Err(Error::ExecutionErr( + "return_last_result annotation is deprecated, use result_collection=last_statement_all_rows instead".to_string(), + )); + } + let token = client.token.clone(); let hidden_passwords = Arc::new(Mutex::new(Vec::::new())); @@ -141,11 +155,13 @@ pub async fn do_duckdb( &token, &base_internal_url, &w_id, + collection_strategy, ) }) .await .map_err(|e| Error::from(to_anyhow(e))) .and_then(|r| r); + let (result, column_order) = match result { Ok(r) => r, Err(e) => { @@ -210,6 +226,8 @@ struct DuckDbFfiLib { base_internal_url: *const c_char, w_id: *const c_char, column_order_ptr: *mut *mut c_char, + collect_last_only: bool, + collect_first_row_only: bool, ) -> *mut c_char, >, free_cstr: Symbol<'static, unsafe extern "C" fn(string: *mut c_char) -> ()>, @@ -249,7 +267,28 @@ impl DuckDbFfiLib { )) })? }; + let lib = Box::leak(Box::new(lib)); + + // Version mismatch should only be possible on Windows agent workers + // We check for it because FFI interface mismatch will cause undefined behavior / crashes + unsafe { + let expected_version: c_uint = 1; + let get_version: Symbol<'static, unsafe extern "C" fn() -> c_uint> = + lib.get(b"get_version") + .map_err(|e| return Error::ExecutionErr(format!("Could not find get_version in the duckdb ffi library. If you are not using docker, consider manually upgrading windmill_duckdb_ffi_lib. {}", e.to_string())))?; + let actual_version = get_version(); + if actual_version < expected_version { + return Err(Error::InternalErr( + format!("Incompatible duckdb ffi library version. Expected: {expected_version}, actual: {actual_version}. Please update to the latest windmill_duckdb_ffi_lib."), + )); + } else if actual_version > expected_version { + return Err(Error::InternalErr( + format!("Incompatible duckdb ffi library version. Expected: {expected_version}, actual: {actual_version}. 
Please upgrade your worker to the latest windmill version."), + )); + } + } + Ok(DuckDbFfiLib { run_duckdb_ffi: unsafe { lib.get(b"run_duckdb_ffi").map_err(to_anyhow)? }, free_cstr: unsafe { lib.get(b"free_cstr").map_err(to_anyhow)? }, @@ -265,6 +304,7 @@ fn run_duckdb_ffi_safe<'a>( token: &str, base_internal_url: &str, w_id: &str, + collection_strategy: SqlResultCollectionStrategy, ) -> Result<(Box, Option>)> { let query_block_list = query_block_list .map(|s| { @@ -296,13 +336,18 @@ fn run_duckdb_ffi_safe<'a>( base_internal_url.as_ptr(), w_id.as_ptr(), &mut column_order, + collection_strategy.collect_last_statement_only(query_block_list_count), + collection_strategy.collect_first_row_only(), ); let str = CStr::from_ptr(ptr).to_string_lossy().to_string(); free_cstr(ptr); str }; - let column_order = if column_order.is_null() { + let column_order = if column_order.is_null() + || !collection_strategy.collect_last_statement_only(query_block_list_count) + || collection_strategy.collect_scalar() + { None } else { let str = unsafe { CStr::from_ptr(column_order).to_string_lossy().to_string() }; @@ -313,7 +358,14 @@ fn run_duckdb_ffi_safe<'a>( if result_str.starts_with("ERROR") { Err(Error::ExecutionErr(result_str[6..].to_string())) } else { - let result = serde_json::value::RawValue::from_string(result_str).map_err(to_anyhow)?; + let result = if collection_strategy == SqlResultCollectionStrategy::AllStatementsAllRows { + // Avoid parsing JSON + serde_json::value::RawValue::from_string(result_str).map_err(to_anyhow)? + } else { + let result = + serde_json::from_str::>>>(&result_str).map_err(to_anyhow)?; + collection_strategy.collect(result)? + }; Ok((result, column_order)) } } diff --git a/backend/windmill-worker/src/mssql_executor.rs b/backend/windmill-worker/src/mssql_executor.rs index c64e644be3514..204df30ed60c8 100644 --- a/backend/windmill-worker/src/mssql_executor.rs +++ b/backend/windmill-worker/src/mssql_executor.rs @@ -12,6 +12,7 @@ use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncWriteCompatExt; use uuid::Uuid; use windmill_common::s3_helpers::convert_json_line_stream; +use windmill_common::worker::SqlResultCollectionStrategy; use windmill_common::{ error::{self, to_anyhow, Error}, utils::empty_as_none, @@ -95,6 +96,11 @@ pub async fn do_mssql( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; let mut config = Config::new(); @@ -237,21 +243,20 @@ pub async fn do_mssql( let len = results.len(); let mut json_results = vec![]; for (i, statement_result) in results.into_iter().enumerate() { - if annotations.return_last_result && i < len - 1 { + if collection_strategy.collect_last_statement_only(len) && i < len - 1 { continue; } let mut json_rows = vec![]; for row in statement_result { - let row = row_to_json(row)?; + let row = to_raw_value(&row_to_json(row)?); json_rows.push(row); + if collection_strategy.collect_first_row_only() { + break; + } } json_results.push(json_rows); } - if annotations.return_last_result && json_results.len() > 0 { - Ok(to_raw_value(&json_results.pop().unwrap())) - } else { - Ok(to_raw_value(&json_results)) - } + collection_strategy.collect(json_results) } }; diff --git a/backend/windmill-worker/src/mysql_executor.rs b/backend/windmill-worker/src/mysql_executor.rs index 55cbe35893339..2163ac98903a3 100644 --- a/backend/windmill-worker/src/mysql_executor.rs +++ 
b/backend/windmill-worker/src/mysql_executor.rs @@ -16,7 +16,7 @@ use windmill_common::{ client::AuthedClient, error::{to_anyhow, Error}, s3_helpers::convert_json_line_stream, - worker::{to_raw_value, Connection}, + worker::{to_raw_value, Connection, SqlResultCollectionStrategy}, }; use windmill_parser_sql::{ parse_db_resource, parse_mysql_sig, parse_s3_mode, parse_sql_blocks, @@ -50,8 +50,10 @@ fn do_mysql_inner<'a>( conn: Arc>, column_order: Option<&'a mut Option>>, skip_collect: bool, + first_row_only: bool, s3: Option, -) -> windmill_common::error::Result>>> { +) -> windmill_common::error::Result>>>> +{ let param_names = parse_sql_statement_named_params(query, ':') .into_iter() .map(|x| x.into_bytes()) @@ -76,7 +78,7 @@ fn do_mysql_inner<'a>( .await .map_err(to_anyhow)?; - Ok(to_raw_value(&Value::Array(vec![]))) + Ok(vec![]) } else if let Some(ref s3) = s3 { let query = query.to_string(); let rows_stream = async_stream::stream! { @@ -108,14 +110,23 @@ fn do_mysql_inner<'a>( let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?; s3.upload(stream.boxed()).await?; - Ok(to_raw_value(&s3.to_return_s3_obj())) + Ok(vec![to_raw_value(&s3.to_return_s3_obj())]) } else { - let rows: Vec = conn - .lock() - .await - .exec(query, statement_values) - .await - .map_err(to_anyhow)?; + let rows: Vec = if first_row_only { + conn.lock() + .await + .exec_first(query, statement_values) + .await + .map_err(to_anyhow)? + .into_iter() + .collect() + } else { + conn.lock() + .await + .exec(query, statement_values) + .await + .map_err(to_anyhow)? + }; if let Some(column_order) = column_order { *column_order = Some( @@ -130,12 +141,10 @@ fn do_mysql_inner<'a>( ); } - Ok(to_raw_value( - &rows - .into_iter() - .map(|x| convert_row_to_value(x)) - .collect::>(), - )) + Ok(rows + .into_iter() + .map(|x| to_raw_value(&convert_row_to_value(x))) + .collect::>()) } }; @@ -180,6 +189,11 @@ pub async fn do_mysql( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; let opts = OptsBuilder::default() .db_name(Some(database.database)) @@ -285,45 +299,32 @@ pub async fn do_mysql( let queries = parse_sql_blocks(query); - let result_f = if queries.len() > 1 { - let futures = queries - .iter() - .enumerate() - .map(|(i, x)| { - do_mysql_inner( - x, - &statement_values, - conn_a.clone(), - None, - annotations.return_last_result && i < queries.len() - 1, - s3.clone(), - ) - }) - .collect::>>()?; - - let f = async { - let mut res: Vec> = vec![]; - for fut in futures { - let r = fut.await?; - res.push(r); - } - if annotations.return_last_result && res.len() > 0 { - Ok(res.pop().unwrap()) - } else { - Ok(to_raw_value(&res)) - } - }; + let conn_a_ref = &conn_a; + let result_f = async move { + let mut results = vec![]; + for (i, query) in queries.iter().enumerate() { + let result = do_mysql_inner( + query, + &statement_values, + conn_a_ref.clone(), + if i == queries.len() - 1 + && collection_strategy.collect_last_statement_only(queries.len()) + && !collection_strategy.collect_scalar() + { + Some(column_order) + } else { + None + }, + collection_strategy.collect_last_statement_only(queries.len()) + && i < queries.len() - 1, + collection_strategy.collect_first_row_only(), + s3.clone(), + )? 
+ .await?; + results.push(result); + } - f.boxed() - } else { - do_mysql_inner( - query, - &statement_values, - conn_a.clone(), - Some(column_order), - false, - s3, - )? + collection_strategy.collect(results) }; let result = run_future_with_polling_update_job_poller( @@ -344,11 +345,10 @@ pub async fn do_mysql( pool.disconnect().await.map_err(to_anyhow)?; - let raw_result = windmill_common::worker::to_raw_value(&json!(result)); - *mem_peak = (raw_result.get().len() / 1000) as i32; + *mem_peak = (result.get().len() / 1000) as i32; // And then check that we got back the same string we sent over. - return Ok(raw_result); + return Ok(result); } // 2023-12-01T16:18:00.000Z diff --git a/backend/windmill-worker/src/oracledb_executor.rs b/backend/windmill-worker/src/oracledb_executor.rs index a1c20d1282a31..d83edf01044b1 100644 --- a/backend/windmill-worker/src/oracledb_executor.rs +++ b/backend/windmill-worker/src/oracledb_executor.rs @@ -11,7 +11,7 @@ use serde_json::{json, value::RawValue, Value}; use windmill_common::{ error::{to_anyhow, Error}, s3_helpers::convert_json_line_stream, - worker::{to_raw_value, Connection}, + worker::{to_raw_value, Connection, SqlResultCollectionStrategy}, }; use windmill_queue::MiniPulledJob; @@ -48,8 +48,10 @@ pub fn do_oracledb_inner<'a>( conn: Arc>, column_order: Option<&'a mut Option>>, skip_collect: bool, + first_row_only: bool, s3: Option, -) -> windmill_common::error::Result>>> { +) -> windmill_common::error::Result>>>> +{ let qw = query.trim_end_matches(';').to_string(); let result_f = async move { @@ -85,7 +87,7 @@ pub fn do_oracledb_inner<'a>( .map_err(to_anyhow)? .map_err(to_anyhow)?; - Ok(to_raw_value(&Value::Array(vec![]))) + Ok(vec![]) } else { // We use an mpsc because we need an async stream for s3 mode. However since everything is sync // in rust-oracle, I assumed that calling ResultSet::next() is blocking when it has to refetch. @@ -133,6 +135,9 @@ pub fn do_oracledb_inner<'a>( break; } } + if first_row_only { + break; + } } } _ => { @@ -159,17 +164,16 @@ pub fn do_oracledb_inner<'a>( if let Some(s3) = s3 { let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?; s3.upload(stream.boxed()).await?; - return Ok(to_raw_value(&s3.to_return_s3_obj())); + return Ok(vec![to_raw_value(&s3.to_return_s3_obj())]); } else { let rows: Vec<_> = rows_stream.collect().await; - Ok(to_raw_value( - &rows - .into_iter() - .collect::, _>>() - .map_err(to_anyhow)? - .into_iter() - .collect::>(), - )) + Ok(rows + .into_iter() + .collect::, _>>() + .map_err(to_anyhow)? + .iter() + .map(to_raw_value) + .collect::>()) } } }; @@ -377,6 +381,11 @@ pub async fn do_oracledb( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; let sig = parse_oracledb_sig(query) .map_err(|x| Error::ExecutionErr(x.to_string()))? 
@@ -388,7 +397,7 @@ pub async fn do_oracledb( let (query, args_to_skip) = sanitize_and_interpolate_unsafe_sql_args(query, &sig, &job_args, &reserved_variables)?; - let (statement_values, errors) = get_statement_values(sig.clone(), &job_args, &args_to_skip); + let (_, errors) = get_statement_values(sig.clone(), &job_args, &args_to_skip); if !errors.is_empty() { return Err(Error::ExecutionErr(errors.join("\n"))); @@ -412,40 +421,33 @@ pub async fn do_oracledb( let queries = parse_sql_blocks(&query); - let result_f = if queries.len() > 1 { - let f = async { - let mut res: Vec> = vec![]; - for (i, q) in queries.iter().enumerate() { - let (vals, _) = get_statement_values(sig.clone(), &job_args, &args_to_skip); - let r = do_oracledb_inner( - q, - vals, - conn_a.clone(), - None, - annotations.return_last_result && i < queries.len() - 1, - s3.clone(), - )? - .await?; - res.push(r); - } - - if annotations.return_last_result && res.len() > 0 { - Ok(res.pop().unwrap()) - } else { - Ok(to_raw_value(&res)) - } - }; + let result_f = async move { + let mut results = vec![]; + for (i, q) in queries.iter().enumerate() { + let (vals, _) = get_statement_values(sig.clone(), &job_args, &args_to_skip); + + let result = do_oracledb_inner( + q, + vals, + conn_a.clone(), + if i == queries.len() - 1 + && collection_strategy.collect_last_statement_only(queries.len()) + && !collection_strategy.collect_scalar() + { + Some(column_order) + } else { + None + }, + collection_strategy.collect_last_statement_only(queries.len()) + && i < queries.len() - 1, + collection_strategy.collect_first_row_only(), + s3.clone(), + )? + .await?; + results.push(result); + } - f.boxed() - } else { - do_oracledb_inner( - &query, - statement_values, - conn_a, - Some(column_order), - false, - s3, - )? + collection_strategy.collect(results) }; let result = run_future_with_polling_update_job_poller( diff --git a/backend/windmill-worker/src/pg_executor.rs b/backend/windmill-worker/src/pg_executor.rs index 9bb7c70cda1a7..e240b5d87b01b 100644 --- a/backend/windmill-worker/src/pg_executor.rs +++ b/backend/windmill-worker/src/pg_executor.rs @@ -28,7 +28,9 @@ use uuid::Uuid; use windmill_common::error::to_anyhow; use windmill_common::error::{self, Error}; use windmill_common::s3_helpers::convert_json_line_stream; -use windmill_common::worker::{to_raw_value, Connection, CLOUD_HOSTED}; +use windmill_common::worker::{ + to_raw_value, Connection, SqlResultCollectionStrategy, CLOUD_HOSTED, +}; use windmill_parser::{Arg, Typ}; use windmill_parser_sql::{ parse_db_resource, parse_pg_statement_arg_indices, parse_pgsql_sig, parse_s3_mode, @@ -73,8 +75,9 @@ fn do_postgresql_inner<'a>( column_order: Option<&'a mut Option>>, siz: &'a AtomicUsize, skip_collect: bool, + first_row_only: bool, s3: Option, -) -> error::Result>>> { +) -> error::Result>>>> { let mut query_params = vec![]; let arg_indices = parse_pg_statement_arg_indices(&query); @@ -100,7 +103,7 @@ fn do_postgresql_inner<'a>( let result_f = async move { // Now we can execute a simple statement that just returns its parameter. 
- let mut res: Vec = vec![]; + let mut res: Vec> = vec![]; let query_params = query_params .iter() @@ -125,13 +128,19 @@ fn do_postgresql_inner<'a>( let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?; s3.upload(stream.boxed()).await?; - return Ok(to_raw_value(&s3.to_return_s3_obj())); + return Ok(vec![to_raw_value(&s3.to_return_s3_obj())]); } else { let rows = client .query_raw(&query, query_params) .await .map_err(to_anyhow)?; + let rows = if first_row_only { + rows.take(1).boxed() + } else { + rows.boxed() + }; + let rows = rows.try_collect::>().await.map_err(to_anyhow)?; if let Some(column_order) = column_order { @@ -164,14 +173,14 @@ fn do_postgresql_inner<'a>( } } if let Ok(v) = r { - res.push(v); + res.push(to_raw_value(&v)); } else { return Err(to_anyhow(r.err().unwrap()).into()); } } } - Ok(to_raw_value(&res)) + Ok(res) }; Ok(result_f.boxed()) @@ -216,6 +225,11 @@ pub async fn do_postgresql( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; let sslmode = match database.sslmode.as_deref() { Some("allow") => "prefer".to_string(), @@ -336,47 +350,33 @@ pub async fn do_postgresql( .collect::>(); let size = AtomicUsize::new(0); - let result_f = if queries.len() > 1 { - let futures = queries - .iter() - .enumerate() - .map(|(i, x)| { - do_postgresql_inner( - x.to_string(), - ¶m_idx_to_arg_and_value, - client, - None, - &size, - annotations.return_last_result && i < queries.len() - 1, - s3.clone(), - ) - }) - .collect::>>()?; - - let f = async { - let mut res: Vec> = vec![]; - for fut in futures { - let r = fut.await?; - res.push(r); - } - if annotations.return_last_result && res.len() > 0 { - Ok(res.pop().unwrap()) - } else { - Ok(to_raw_value(&res)) - } - }; + let size_ref = &size; + let result_f = async move { + let mut results = vec![]; + for (i, query) in queries.iter().enumerate() { + let result = do_postgresql_inner( + query.to_string(), + ¶m_idx_to_arg_and_value, + client, + if i == queries.len() - 1 + && collection_strategy.collect_last_statement_only(queries.len()) + && !collection_strategy.collect_scalar() + { + Some(column_order) + } else { + None + }, + size_ref, + collection_strategy.collect_last_statement_only(queries.len()) + && i < queries.len() - 1, + collection_strategy.collect_first_row_only(), + s3.clone(), + )? + .await?; + results.push(result); + } - f.boxed() - } else { - do_postgresql_inner( - query.to_string(), - ¶m_idx_to_arg_and_value, - client, - Some(column_order), - &size, - false, - s3, - )? 
+ collection_strategy.collect(results) }; let result = run_future_with_polling_update_job_poller( diff --git a/backend/windmill-worker/src/snowflake_executor.rs b/backend/windmill-worker/src/snowflake_executor.rs index e657ef7bb00d9..38eff99ae1beb 100644 --- a/backend/windmill-worker/src/snowflake_executor.rs +++ b/backend/windmill-worker/src/snowflake_executor.rs @@ -10,7 +10,7 @@ use sha2::{Digest, Sha256}; use std::collections::HashMap; use windmill_common::error::to_anyhow; use windmill_common::s3_helpers::convert_json_line_stream; -use windmill_common::worker::Connection; +use windmill_common::worker::{Connection, SqlResultCollectionStrategy}; use windmill_common::{error::Error, worker::to_raw_value}; use windmill_parser_sql::{ @@ -130,10 +130,12 @@ fn do_snowflake_inner<'a>( token_is_keypair: bool, column_order: Option<&'a mut Option>>, skip_collect: bool, + first_row_only: bool, http_client: &'a Client, s3: Option, reserved_variables: &HashMap, -) -> windmill_common::error::Result>>> { +) -> windmill_common::error::Result>>>> +{ let sig = parse_snowflake_sig(&query) .map_err(|x| Error::ExecutionErr(x.to_string()))? .args; @@ -179,7 +181,7 @@ fn do_snowflake_inner<'a>( if skip_collect { handle_snowflake_result(result).await?; - Ok(to_raw_value(&Value::Array(vec![]))) + Ok(vec![]) } else { let response = result .parse_snowflake_response::() @@ -254,19 +256,22 @@ fn do_snowflake_inner<'a>( row_map }); + let rows_stream = rows_stream.take(if first_row_only { 1 } else { usize::MAX }); + if let Some(s3) = s3 { let rows_stream = rows_stream.map(|r| serde_json::value::to_value(&r?).map_err(to_anyhow)); let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?; s3.upload(stream.boxed()).await?; - Ok(to_raw_value(&s3.to_return_s3_obj())) + Ok(vec![to_raw_value(&s3.to_return_s3_obj())]) } else { let rows = rows_stream .collect::>() .await .into_iter() + .map(|x| x.map(|v| to_raw_value(&v))) .collect::, _>>()?; - Ok(to_raw_value(&rows)) + Ok(rows) } } }; @@ -312,6 +317,11 @@ pub async fn do_snowflake( }; let annotations = windmill_common::worker::SqlAnnotations::parse(query); + let collection_strategy = if annotations.return_last_result { + SqlResultCollectionStrategy::LastStatementAllRows + } else { + annotations.result_collection + }; // Check if the token is present in db_arg and use it if available let (token, token_is_keypair) = if let Some(token) = db_arg @@ -409,56 +419,38 @@ pub async fn do_snowflake( let reserved_variables = get_reserved_variables(job, &client.token, conn, parent_runnable_path).await?; - let result_f = if queries.len() > 1 { - let futures = queries - .iter() - .enumerate() - .map(|(i, x)| { - do_snowflake_inner( - x, - &snowflake_args, - body.clone(), - &database.account_identifier, - &token, - token_is_keypair, - None, - annotations.return_last_result && i < queries.len() - 1, - &http_client, - s3.clone(), - &reserved_variables, - ) - }) - .collect::>>()?; - - let f = async { - let mut res: Vec> = vec![]; - for fut in futures { - let r = fut.await?; - res.push(r); - } - if annotations.return_last_result && res.len() > 0 { - Ok(res.pop().unwrap()) - } else { - Ok(to_raw_value(&res)) - } - }; + let result_f = async move { + let mut results = vec![]; + for (i, q) in queries.iter().enumerate() { + let result = do_snowflake_inner( + q, + &snowflake_args, + body.clone(), + &database.account_identifier, + &token, + token_is_keypair, + if i == queries.len() - 1 + && collection_strategy.collect_last_statement_only(queries.len()) + && 
!collection_strategy.collect_scalar() + { + Some(column_order) + } else { + None + }, + collection_strategy.collect_last_statement_only(queries.len()) + && i < queries.len() - 1, + collection_strategy.collect_first_row_only(), + &http_client, + s3.clone(), + &reserved_variables, + )? + .await?; + results.push(result); + } - f.boxed() - } else { - do_snowflake_inner( - query, - &snowflake_args, - body.clone(), - &database.account_identifier, - &token, - token_is_keypair, - Some(column_order), - false, - &http_client, - s3.clone(), - &reserved_variables, - )? + collection_strategy.collect(results) }; + let r = run_future_with_polling_update_job_poller( job.id, job.timeout, diff --git a/frontend/src/lib/components/Editor.svelte b/frontend/src/lib/components/Editor.svelte index e1b5ded284a4e..01c0925985eff 100644 --- a/frontend/src/lib/components/Editor.svelte +++ b/frontend/src/lib/components/Editor.svelte @@ -432,11 +432,46 @@ let command: IDisposable | undefined = undefined let sqlTypeCompletor: IDisposable | undefined = $state(undefined) + let resultCollectionCompletor: IDisposable | undefined = $state(undefined) function addSqlTypeCompletions() { - if (sqlTypeCompletor) { - sqlTypeCompletor.dispose() - } + sqlTypeCompletor?.dispose() + resultCollectionCompletor?.dispose() + + resultCollectionCompletor = languages.registerCompletionItemProvider('sql', { + triggerCharacters: ['='], + provideCompletionItems: function (model, position) { + const lineContent = model.getLineContent(position.lineNumber) + const match = lineContent.match(/^--\s*result_collection=/) + if (!match) { + return { suggestions: [] } + } + const word = model.getWordUntilPosition(position) + const range = { + startLineNumber: position.lineNumber, + endLineNumber: position.lineNumber, + startColumn: word.startColumn, + endColumn: word.endColumn + } + const suggestions = [ + 'last_statement_all_rows', + 'last_statement_first_row', + 'last_statement_all_rows_scalar', + 'last_statement_first_row_scalar', + 'all_statements_all_rows', + 'all_statements_first_row', + 'all_statements_all_rows_scalar', + 'all_statements_first_row_scalar' + ].map((label) => ({ + label: label, + kind: languages.CompletionItemKind.Function, + insertText: label, + range, + sortText: 'a' + })) + return { suggestions } + } + }) sqlTypeCompletor = languages.registerCompletionItemProvider('sql', { triggerCharacters: scriptLang === 'postgresql' ? [':'] : ['('], provideCompletionItems: function (model, position) { @@ -1581,6 +1616,7 @@ sqlSchemaCompletor && sqlSchemaCompletor.dispose() autocompletor && autocompletor.dispose() sqlTypeCompletor && sqlTypeCompletor.dispose() + resultCollectionCompletor && resultCollectionCompletor.dispose() preprocessorCompletor && preprocessorCompletor.dispose() timeoutModel && clearTimeout(timeoutModel) loadTimeout && clearTimeout(loadTimeout) @@ -1649,7 +1685,7 @@ $effect(() => { initialized && lang === 'sql' && scriptLang ? 
untrack(() => addSqlTypeCompletions()) - : sqlTypeCompletor?.dispose() + : (sqlTypeCompletor?.dispose(), resultCollectionCompletor?.dispose()) }) $effect(() => { diff --git a/frontend/src/lib/script_helpers.ts b/frontend/src/lib/script_helpers.ts index 75003d837d816..6011813d6bb3c 100644 --- a/frontend/src/lib/script_helpers.ts +++ b/frontend/src/lib/script_helpers.ts @@ -248,9 +248,9 @@ export async function main(message: string, name: string, step_id: string) { } ` -const POSTGRES_INIT_CODE = `-- to pin the database use '-- database f/your/path' +const POSTGRES_INIT_CODE = `-- result_collection=last_statement_all_rows +-- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' --- to only return the result of the last query use '--return_last_result' -- $1 name1 = default arg -- $2 name2 -- $3 name3 @@ -259,7 +259,8 @@ INSERT INTO demo VALUES (\$1::TEXT, \$2::INT, \$3::TEXT[]) RETURNING *; UPDATE demo SET col2 = \$4::INT WHERE col2 = \$2::INT; ` -const MYSQL_INIT_CODE = `-- to pin the database use '-- database f/your/path' +const MYSQL_INIT_CODE = `-- result_collection=last_statement_all_rows +-- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' -- :name1 (text) = default arg -- :name2 (int) @@ -268,7 +269,8 @@ INSERT INTO demo VALUES (:name1, :name2); UPDATE demo SET col2 = :name3 WHERE col2 = :name2; ` -const BIGQUERY_INIT_CODE = `-- to pin the database use '-- database f/your/path' +const BIGQUERY_INIT_CODE = `-- result_collection=last_statement_all_rows +-- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' -- @name1 (string) = default arg -- @name2 (integer) @@ -278,7 +280,8 @@ INSERT INTO \`demodb.demo\` VALUES (@name1, @name2, @name3); UPDATE \`demodb.demo\` SET col2 = @name4 WHERE col2 = @name2; ` -const ORACLEDB_INIT_CODE = `-- to pin the database use '-- database f/your/path' +const ORACLEDB_INIT_CODE = `-- result_collection=last_statement_all_rows +-- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' -- :name1 (text) = default arg -- :name2 (int) @@ -287,7 +290,8 @@ INSERT INTO demo VALUES (:name1, :name2); UPDATE demo SET col2 = :name3 WHERE col2 = :name2; ` -const SNOWFLAKE_INIT_CODE = `-- to pin the database use '-- database f/your/path' +const SNOWFLAKE_INIT_CODE = `-- result_collection=last_statement_all_rows +-- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' -- ? name1 (varchar) = default arg -- ? name2 (int) @@ -297,7 +301,7 @@ INSERT INTO demo VALUES (?, ?); UPDATE demo SET col2 = ? WHERE col2 = ?; ` -const MSSQL_INIT_CODE = `-- return_last_result +const MSSQL_INIT_CODE = `-- result_collection=last_statement_all_rows -- to pin the database use '-- database f/your/path' -- to stream a large query result to your workspace storage use '-- s3' -- @P1 name1 (varchar) = default arg @@ -307,7 +311,8 @@ INSERT INTO demo VALUES (@P1, @P2); UPDATE demo SET col2 = @P3 WHERE col2 = @P2; ` -const DUCKDB_INIT_CODE = `-- $name (text) = Ben +const DUCKDB_INIT_CODE = `-- result_collection=last_statement_all_rows +-- $name (text) = Ben -- $age (text) = 20 -- -- $friends_csv (s3object)
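
For reference, a minimal standalone sketch of what the new `-- result_collection=...` annotation does to result shapes, mirroring the `SqlResultCollectionStrategy::collect` logic added in backend/windmill-common/src/worker.rs. The types and names below are simplified assumptions for illustration (plain serde_json::Value instead of RawValue, and only three of the nine strategies), not the crate's actual API:

use serde_json::{json, Value};

#[derive(Clone, Copy)]
enum Strategy {
    LastStatementAllRows,        // default: rows of the last statement only
    LastStatementFirstRowScalar, // first column of the first row of the last statement
    AllStatementsAllRows,        // everything, as Row[][]
}

// `statements`: one Vec of row objects per SQL statement, in execution order.
fn collect(strategy: Strategy, statements: Vec<Vec<Value>>) -> Value {
    match strategy {
        Strategy::AllStatementsAllRows => json!(statements),
        Strategy::LastStatementAllRows => {
            json!(statements.into_iter().last().unwrap_or_default())
        }
        Strategy::LastStatementFirstRowScalar => statements
            .into_iter()
            .last()
            .and_then(|rows| rows.into_iter().next())
            .and_then(|row| row.as_object().and_then(|obj| obj.values().next().cloned()))
            .unwrap_or(Value::Null),
    }
}

fn main() {
    // e.g. an INSERT ... RETURNING followed by a SELECT count(*)
    let results = vec![
        vec![json!({"id": 1, "name": "a"})],
        vec![json!({"count": 42})],
    ];
    assert_eq!(
        collect(Strategy::AllStatementsAllRows, results.clone()),
        json!([[{"id": 1, "name": "a"}], [{"count": 42}]])
    );
    assert_eq!(
        collect(Strategy::LastStatementAllRows, results.clone()),
        json!([{"count": 42}])
    );
    assert_eq!(collect(Strategy::LastStatementFirstRowScalar, results), json!(42));
}

In the executors themselves the chosen strategy is also used to avoid fetching data that would be discarded (for example BigQuery requests maxResults: 1 for first-row strategies and DuckDB stops stepping after the first row), so collect only reshapes what was actually returned. The old `-- return_last_result` annotation is kept as a deprecated alias for `result_collection=last_statement_all_rows` in most executors (DuckDB rejects it with an explicit error), and the migration pins existing SQL scripts and flow steps to `result_collection=legacy` to preserve their current behavior.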