From c8c6e73f7e7a08346e40ca071b8d7b825e8e074b Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Sun, 5 May 2024 18:18:45 -0700 Subject: [PATCH 1/3] Add support for async query response --- snowflake-api/src/connection.rs | 51 ++++++++++++++++++++++----------- snowflake-api/src/lib.rs | 40 ++++++++++++++++++++++++-- snowflake-api/src/responses.rs | 11 ++++++- 3 files changed, 83 insertions(+), 19 deletions(-) diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..0c146ea 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -29,8 +29,9 @@ pub enum ConnectionError { /// Container for query parameters /// This API has different endpoints and MIME types for different requests struct QueryContext { - path: &'static str, + path: String, accept_mime: &'static str, + method: reqwest::Method } pub enum QueryType { @@ -39,30 +40,40 @@ pub enum QueryType { CloseSession, JsonQuery, ArrowQuery, + ArrowQueryResult(String), } - impl QueryType { - const fn query_context(&self) -> QueryContext { + fn query_context(&self) -> QueryContext { match self { Self::LoginRequest => QueryContext { - path: "session/v1/login-request", + path: "session/v1/login-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::TokenRequest => QueryContext { - path: "/session/token-request", + path: "/session/token-request".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::CloseSession => QueryContext { - path: "session", + path: "session".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::JsonQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::ArrowQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), + accept_mime: "application/snowflake", + method: reqwest::Method::POST, + }, + Self::ArrowQueryResult(query_result_url) => QueryContext { + path: query_result_url.to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::GET, }, } } @@ -163,14 +174,22 @@ impl Connection { } // todo: persist client to use connection polling - let resp = self - .client - .post(url) - .headers(headers) - .json(&body) - .send() - .await?; - + let resp = match context.method { + reqwest::Method::POST => self + .client + .post(url) + .headers(headers) + .json(&body) + .send() + .await?, + reqwest::Method::GET => self + .client + .get(url) + .headers(headers) + .send() + .await?, + _ => panic!("Unsupported method"), + }; Ok(resp.json::().await?) } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 1fa7b36..3d8e78c 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -407,6 +407,7 @@ impl SnowflakeApi { match resp { ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::PutGet(pg) => put::put(pg).await, ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -430,14 +431,21 @@ impl SnowflakeApi { } async fn exec_arrow_raw(&self, sql: &str) -> Result { - let resp = self + let mut resp = self .run_sql::(sql, QueryType::ArrowQuery) .await?; log::debug!("Got query response: {:?}", resp); + if let ExecResponse::QueryAsync(data) = &resp { + log::debug!("Got async exec response"); + resp = self.get_async_exec_result(&data.data.get_result_url).await?; + log::debug!("Got result for async exec: {:?}", resp); + } + let resp = match resp { // processable response ExecResponse::Query(qr) => Ok(qr), + ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -504,10 +512,38 @@ impl SnowflakeApi { &self.account_identifier, &[], Some(&parts.session_token_auth_header), - body, + Some(body), ) .await?; Ok(resp) } + + pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result{ + log::debug!("Getting async exec result: {}", query_result_url); + + let mut delay = 1; // Initial delay of 1 second + + loop { + let parts = self.session.get_token().await?; + let resp = self + .connection + .request::( + QueryType::ArrowQueryResult(query_result_url.to_string()), + &self.account_identifier, + &[], + Some(&parts.session_token_auth_header), + serde_json::Value::default() + ) + .await?; + + if let ExecResponse::QueryAsync(_) = &resp { + // simple exponential retry with a maximum wait time of 5 seconds + tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; + delay = (delay * 2).min(5); // cap delay to 5 seconds + } else { + return Ok(resp); + } + }; + } } diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index b8a3e68..11034ce 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -7,6 +7,7 @@ use serde::Deserialize; #[serde(untagged)] pub enum ExecResponse { Query(QueryExecResponse), + QueryAsync(QueryAsyncExecResponse), PutGet(PutGetExecResponse), Error(ExecErrorResponse), } @@ -34,6 +35,7 @@ pub struct BaseRestResponse { pub type PutGetExecResponse = BaseRestResponse; pub type QueryExecResponse = BaseRestResponse; +pub type QueryAsyncExecResponse = BaseRestResponse; pub type ExecErrorResponse = BaseRestResponse; pub type AuthErrorResponse = BaseRestResponse; pub type AuthenticatorResponse = BaseRestResponse; @@ -54,7 +56,7 @@ pub struct ExecErrorResponseData { pub pos: Option, // fixme: only valid for exec query response error? present in any exec query response? - pub query_id: String, + pub query_id: Option, pub sql_state: String, } @@ -151,6 +153,13 @@ pub struct QueryExecResponseData { // `sendResultTime`, `queryResultFormat`, `queryContext` also exist } +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct QueryAsyncExecResponseData { + pub query_id: String, + pub get_result_url: String, +} + #[derive(Deserialize, Debug)] pub struct ExecResponseRowType { pub name: String, From c68710819faddded8a2ee69df3c66814c8b68c3f Mon Sep 17 00:00:00 2001 From: Dmitriy Mazurin Date: Fri, 6 Sep 2024 14:54:27 +0100 Subject: [PATCH 2/3] Merge pull request #52 from andrusha/clippy_fixes clippy temporary fixes --- snowflake-api/README.md | 2 +- snowflake-api/src/responses.rs | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/snowflake-api/README.md b/snowflake-api/README.md index b08f911..2ba2c71 100644 --- a/snowflake-api/README.md +++ b/snowflake-api/README.md @@ -18,7 +18,7 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep - [x] PUT support [example](./examples/filetransfer.rs) - [ ] GET support - [x] AWS integration -- [ ] GCloud integration +- [ ] `GCloud` integration - [ ] Azure integration - [x] Parallel uploading of small files - [x] Glob support for PUT (eg `*.csv`) diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index 11034ce..4d48f67 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -13,7 +13,8 @@ pub enum ExecResponse { } // todo: add close session response, which should be just empty? -#[allow(clippy::large_enum_variant)] +// FIXME: dead_code +#[allow(clippy::large_enum_variant, dead_code)] #[derive(Deserialize, Debug)] #[serde(untagged)] pub enum AuthResponse { @@ -62,6 +63,8 @@ pub struct ExecErrorResponseData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct AuthErrorResponseData { pub authn_method: String, } @@ -74,6 +77,8 @@ pub struct NameValueParameter { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME +#[allow(dead_code)] pub struct LoginResponseData { pub session_id: i64, pub token: String, @@ -88,6 +93,8 @@ pub struct LoginResponseData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct SessionInfo { pub database_name: Option, pub schema_name: Option, @@ -97,6 +104,8 @@ pub struct SessionInfo { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct AuthenticatorResponseData { pub token_url: String, pub sso_url: String, @@ -105,6 +114,8 @@ pub struct AuthenticatorResponseData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct RenewSessionResponseData { pub session_token: String, pub validity_in_seconds_s_t: i64, From d8241fcd788545177a09dd418b630e85396fdddd Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Tue, 24 Sep 2024 18:55:12 -0700 Subject: [PATCH 3/3] Fix clippy and formatting --- snowflake-api/src/connection.rs | 26 ++++++++++------------ snowflake-api/src/lib.rs | 39 +++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index 0c146ea..8ab66b9 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -31,7 +31,7 @@ pub enum ConnectionError { struct QueryContext { path: String, accept_mime: &'static str, - method: reqwest::Method + method: reqwest::Method, } pub enum QueryType { @@ -175,20 +175,16 @@ impl Connection { // todo: persist client to use connection polling let resp = match context.method { - reqwest::Method::POST => self - .client - .post(url) - .headers(headers) - .json(&body) - .send() - .await?, - reqwest::Method::GET => self - .client - .get(url) - .headers(headers) - .send() - .await?, - _ => panic!("Unsupported method"), + reqwest::Method::POST => { + self.client + .post(url) + .headers(headers) + .json(&body) + .send() + .await? + } + reqwest::Method::GET => self.client.get(url).headers(headers).send().await?, + _ => panic!("Unsupported method"), }; Ok(resp.json::().await?) } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 3d8e78c..483665d 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -406,8 +406,9 @@ impl SnowflakeApi { log::debug!("Got PUT response: {:?}", resp); match resp { - ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), - ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::Query(_) | ExecResponse::QueryAsync(_) => { + Err(SnowflakeApiError::UnexpectedResponse) + } ExecResponse::PutGet(pg) => put::put(pg).await, ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -438,15 +439,18 @@ impl SnowflakeApi { if let ExecResponse::QueryAsync(data) = &resp { log::debug!("Got async exec response"); - resp = self.get_async_exec_result(&data.data.get_result_url).await?; + resp = self + .get_async_exec_result(&data.data.get_result_url) + .await?; log::debug!("Got result for async exec: {:?}", resp); } let resp = match resp { // processable response ExecResponse::Query(qr) => Ok(qr), - ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse), - ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::PutGet(_) | ExecResponse::QueryAsync(_) => { + Err(SnowflakeApiError::UnexpectedResponse) + } ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, e.message.unwrap_or_default(), @@ -519,7 +523,10 @@ impl SnowflakeApi { Ok(resp) } - pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result{ + pub async fn get_async_exec_result( + &self, + query_result_url: &String, + ) -> Result { log::debug!("Getting async exec result: {}", query_result_url); let mut delay = 1; // Initial delay of 1 second @@ -527,15 +534,15 @@ impl SnowflakeApi { loop { let parts = self.session.get_token().await?; let resp = self - .connection - .request::( - QueryType::ArrowQueryResult(query_result_url.to_string()), - &self.account_identifier, - &[], - Some(&parts.session_token_auth_header), - serde_json::Value::default() - ) - .await?; + .connection + .request::( + QueryType::ArrowQueryResult(query_result_url.to_string()), + &self.account_identifier, + &[], + Some(&parts.session_token_auth_header), + serde_json::Value::default(), + ) + .await?; if let ExecResponse::QueryAsync(_) = &resp { // simple exponential retry with a maximum wait time of 5 seconds @@ -544,6 +551,6 @@ impl SnowflakeApi { } else { return Ok(resp); } - }; + } } }