diff --git a/snowflake-api/README.md b/snowflake-api/README.md index 7eb76f2..6a906e4 100644 --- a/snowflake-api/README.md +++ b/snowflake-api/README.md @@ -8,7 +8,7 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep - [x] Single statements [example](./examples/run_sql.rs) - [ ] Multiple statements -- [ ] Async requests (is it needed if whole library is async?) +- [ ] Async requests (to allow for long-running queries and multi-statement) - [x] Query results in [Arrow](https://arrow.apache.org/) - [x] Chunked query results - [x] Password, certificate, env auth diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 91dd4a9..2968938 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -1,6 +1,6 @@ #![doc( - issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues", - test(no_crate_inject) +issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues", +test(no_crate_inject) )] #![doc = include_str!("../README.md")] #![warn(clippy::all, clippy::pedantic)] @@ -13,6 +13,7 @@ clippy::future_not_send, // This one seems like something we should eventually f clippy::missing_panics_doc )] +use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::io; use std::path::Path; @@ -31,10 +32,10 @@ use regex::Regex; use reqwest_middleware::ClientWithMiddleware; use thiserror::Error; -use crate::connection::{Connection, ConnectionError}; use responses::ExecResponse; use session::{AuthError, Session}; +use crate::connection::{Connection, ConnectionError}; use crate::connection::QueryType; use crate::requests::ExecRequest; use crate::responses::{ @@ -395,8 +396,9 @@ impl SnowflakeApi { log::debug!("Got PUT response: {:?}", resp); match resp { - ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::PutGet(pg) => self.put(pg).await, + // put-get by design is async, and isn't a query response + ExecResponse::MultiStatementQuery(_) | ExecResponse::AsyncQuery(_) | ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, e.message.unwrap_or_default(), @@ -479,6 +481,8 @@ impl SnowflakeApi { // processable response ExecResponse::Query(qr) => Ok(qr), ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::AsyncQuery(_) => Err(SnowflakeApiError::Unimplemented("Async queries, ie the ones returning a handle to query id".to_owned())), + ExecResponse::MultiStatementQuery(_) => Err(SnowflakeApiError::Unimplemented("Multi-statement queries are not implemented as they require polling on the user-side".to_owned())), ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, e.message.unwrap_or_default(), @@ -505,7 +509,7 @@ impl SnowflakeApi { self.connection .get_chunk(&chunk.url, &resp.data.chunk_headers) })) - .await?; + .await?; // fixme: should base64 chunk go first? // fixme: if response is chunked is it both base64 + chunks or just chunks? @@ -529,12 +533,17 @@ impl SnowflakeApi { log::debug!("Executing: {}", sql_text); let parts = self.session.get_token().await?; + // todo: move clientStartTime, requestId, request_guid from request parameters to request body into this map + let mut parameters = HashMap::new(); + parameters.insert("MULTI_STATEMENT_COUNT".to_owned(), Self::count_statements(sql_text).to_string()); let body = ExecRequest { sql_text: sql_text.to_string(), async_exec: false, sequence_id: parts.sequence_id, is_internal: false, + describe_only: false, + parameters, }; let resp = self @@ -550,4 +559,17 @@ impl SnowflakeApi { Ok(resp) } + + fn count_statements(sql_text: &str) -> usize { + // fixme: find better way to count split + let count = sql_text.chars().filter(|&c| c == ';').count(); + + if count == 0 { + // non-terminated single query is still a single query + 1 + } else { + // what if there are multiple queries, but the last one is not ;-terminated? + count + } + } } diff --git a/snowflake-api/src/requests.rs b/snowflake-api/src/requests.rs index 77b0434..0e5ed9e 100644 --- a/snowflake-api/src/requests.rs +++ b/snowflake-api/src/requests.rs @@ -1,12 +1,16 @@ +use std::collections::HashMap; + use serde::Serialize; #[derive(Serialize, Debug)] #[serde(rename_all = "camelCase")] pub struct ExecRequest { pub sql_text: String, - pub async_exec: bool, - pub sequence_id: u64, - pub is_internal: bool, + pub async_exec: bool, // fixme: doesn't exist in .NET + pub sequence_id: u64, // fixme: doesn't exist in .NET + pub is_internal: bool, // fixme: doesn't exist in .NET + pub describe_only: bool, // fixme: optional in GO, required in .NET + pub parameters: HashMap, } #[derive(Serialize, Debug)] diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index dee26c3..4fc0013 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -6,9 +6,12 @@ use serde::Deserialize; #[derive(Deserialize, Debug)] #[serde(untagged)] pub enum ExecResponse { - Query(QueryExecResponse), PutGet(PutGetExecResponse), - Error(ExecErrorResponse), + AsyncQuery(AsyncQueryResponse), + MultiStatementQuery(MultiStatementQueryResponse), + Query(QueryExecResponse), + // before-last since has intersecting fields + Error(ExecErrorResponse), // last since essentially catch-all } // todo: add close session response, which should be just empty? @@ -32,6 +35,8 @@ pub struct BaseRestResponse { pub data: D, } +pub type MultiStatementQueryResponse = BaseRestResponse; +pub type AsyncQueryResponse = BaseRestResponse; pub type PutGetExecResponse = BaseRestResponse; pub type QueryExecResponse = BaseRestResponse; pub type ExecErrorResponse = BaseRestResponse; @@ -124,15 +129,21 @@ pub struct QueryExecResponseData { // is base64-encoded Arrow IPC payload pub rowset_base64: Option, pub total: i64, - pub returned: i64, // unused in .NET - pub query_id: String, // unused in .NET + pub returned: i64, + // unused in .NET + pub query_id: String, + // unused in .NET pub database_provider: Option, - pub final_database_name: Option, // unused in .NET + pub final_database_name: Option, + // unused in .NET pub final_schema_name: Option, - pub final_warehouse_name: Option, // unused in .NET - pub final_role_name: String, // unused in .NET + pub final_warehouse_name: Option, + // unused in .NET + pub final_role_name: String, + // unused in .NET // only present on SELECT queries - pub number_of_binds: Option, // unused in .NET + pub number_of_binds: Option, + // unused in .NET // todo: deserialize into enum pub statement_type_id: i64, pub version: i64, @@ -143,12 +154,6 @@ pub struct QueryExecResponseData { pub qrmk: Option, #[serde(default)] // chunks are present pub chunk_headers: HashMap, - // when async query is run (ping pong request?) - pub get_result_url: Option, - // multi-statement response, comma-separated - pub result_ids: Option, - // `progressDesc`, and `queryAbortAfterSecs` are not used but exist in .NET - // `sendResultTime`, `queryResultFormat`, `queryContext` also exist } #[derive(Deserialize, Debug)] @@ -304,3 +309,24 @@ pub struct PutGetEncryptionMaterial { pub query_id: String, pub smk_id: i64, } + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct AsyncQueryResponseData { + pub query_id: String, + pub get_result_url: String, + pub query_aborts_after_secs: i64, + pub progress_desc: Option, +} + +// fixme: this is not correct, but useful +// since the response will include more fields from [`QueryExecResponseData`] +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct MultiStatementQueryResponseData { + pub query_id: String, + // comma-separated + pub result_ids: String, + // comma-separated + pub result_types: String, +} \ No newline at end of file