diff --git a/snowflake-api/README.md b/snowflake-api/README.md index b08f911..2ba2c71 100644 --- a/snowflake-api/README.md +++ b/snowflake-api/README.md @@ -18,7 +18,7 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep - [x] PUT support [example](./examples/filetransfer.rs) - [ ] GET support - [x] AWS integration -- [ ] GCloud integration +- [ ] `GCloud` integration - [ ] Azure integration - [x] Parallel uploading of small files - [x] Glob support for PUT (eg `*.csv`) diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..8ab66b9 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -29,8 +29,9 @@ pub enum ConnectionError { /// Container for query parameters /// This API has different endpoints and MIME types for different requests struct QueryContext { - path: &'static str, + path: String, accept_mime: &'static str, + method: reqwest::Method, } pub enum QueryType { @@ -39,30 +40,40 @@ pub enum QueryType { CloseSession, JsonQuery, ArrowQuery, + ArrowQueryResult(String), } - impl QueryType { - const fn query_context(&self) -> QueryContext { + fn query_context(&self) -> QueryContext { match self { Self::LoginRequest => QueryContext { - path: "session/v1/login-request", + path: "session/v1/login-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::TokenRequest => QueryContext { - path: "/session/token-request", + path: "/session/token-request".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::CloseSession => QueryContext { - path: "session", + path: "session".to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::POST, }, Self::JsonQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), accept_mime: "application/json", + method: reqwest::Method::POST, }, Self::ArrowQuery => QueryContext { - path: "queries/v1/query-request", + path: "queries/v1/query-request".to_string(), + accept_mime: "application/snowflake", + method: reqwest::Method::POST, + }, + Self::ArrowQueryResult(query_result_url) => QueryContext { + path: query_result_url.to_string(), accept_mime: "application/snowflake", + method: reqwest::Method::GET, }, } } @@ -163,14 +174,18 @@ impl Connection { } // todo: persist client to use connection polling - let resp = self - .client - .post(url) - .headers(headers) - .json(&body) - .send() - .await?; - + let resp = match context.method { + reqwest::Method::POST => { + self.client + .post(url) + .headers(headers) + .json(&body) + .send() + .await? + } + reqwest::Method::GET => self.client.get(url).headers(headers).send().await?, + _ => panic!("Unsupported method"), + }; Ok(resp.json::().await?) } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 1fa7b36..483665d 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -406,7 +406,9 @@ impl SnowflakeApi { log::debug!("Got PUT response: {:?}", resp); match resp { - ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::Query(_) | ExecResponse::QueryAsync(_) => { + Err(SnowflakeApiError::UnexpectedResponse) + } ExecResponse::PutGet(pg) => put::put(pg).await, ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, @@ -430,15 +432,25 @@ impl SnowflakeApi { } async fn exec_arrow_raw(&self, sql: &str) -> Result { - let resp = self + let mut resp = self .run_sql::(sql, QueryType::ArrowQuery) .await?; log::debug!("Got query response: {:?}", resp); + if let ExecResponse::QueryAsync(data) = &resp { + log::debug!("Got async exec response"); + resp = self + .get_async_exec_result(&data.data.get_result_url) + .await?; + log::debug!("Got result for async exec: {:?}", resp); + } + let resp = match resp { // processable response ExecResponse::Query(qr) => Ok(qr), - ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse), + ExecResponse::PutGet(_) | ExecResponse::QueryAsync(_) => { + Err(SnowflakeApiError::UnexpectedResponse) + } ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError( e.data.error_code, e.message.unwrap_or_default(), @@ -504,10 +516,41 @@ impl SnowflakeApi { &self.account_identifier, &[], Some(&parts.session_token_auth_header), - body, + Some(body), ) .await?; Ok(resp) } + + pub async fn get_async_exec_result( + &self, + query_result_url: &String, + ) -> Result { + log::debug!("Getting async exec result: {}", query_result_url); + + let mut delay = 1; // Initial delay of 1 second + + loop { + let parts = self.session.get_token().await?; + let resp = self + .connection + .request::( + QueryType::ArrowQueryResult(query_result_url.to_string()), + &self.account_identifier, + &[], + Some(&parts.session_token_auth_header), + serde_json::Value::default(), + ) + .await?; + + if let ExecResponse::QueryAsync(_) = &resp { + // simple exponential retry with a maximum wait time of 5 seconds + tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; + delay = (delay * 2).min(5); // cap delay to 5 seconds + } else { + return Ok(resp); + } + } + } } diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index b8a3e68..4d48f67 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -7,12 +7,14 @@ use serde::Deserialize; #[serde(untagged)] pub enum ExecResponse { Query(QueryExecResponse), + QueryAsync(QueryAsyncExecResponse), PutGet(PutGetExecResponse), Error(ExecErrorResponse), } // todo: add close session response, which should be just empty? -#[allow(clippy::large_enum_variant)] +// FIXME: dead_code +#[allow(clippy::large_enum_variant, dead_code)] #[derive(Deserialize, Debug)] #[serde(untagged)] pub enum AuthResponse { @@ -34,6 +36,7 @@ pub struct BaseRestResponse { pub type PutGetExecResponse = BaseRestResponse; pub type QueryExecResponse = BaseRestResponse; +pub type QueryAsyncExecResponse = BaseRestResponse; pub type ExecErrorResponse = BaseRestResponse; pub type AuthErrorResponse = BaseRestResponse; pub type AuthenticatorResponse = BaseRestResponse; @@ -54,12 +57,14 @@ pub struct ExecErrorResponseData { pub pos: Option, // fixme: only valid for exec query response error? present in any exec query response? - pub query_id: String, + pub query_id: Option, pub sql_state: String, } #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct AuthErrorResponseData { pub authn_method: String, } @@ -72,6 +77,8 @@ pub struct NameValueParameter { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME +#[allow(dead_code)] pub struct LoginResponseData { pub session_id: i64, pub token: String, @@ -86,6 +93,8 @@ pub struct LoginResponseData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct SessionInfo { pub database_name: Option, pub schema_name: Option, @@ -95,6 +104,8 @@ pub struct SessionInfo { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct AuthenticatorResponseData { pub token_url: String, pub sso_url: String, @@ -103,6 +114,8 @@ pub struct AuthenticatorResponseData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] +// FIXME: dead_code +#[allow(dead_code)] pub struct RenewSessionResponseData { pub session_token: String, pub validity_in_seconds_s_t: i64, @@ -151,6 +164,13 @@ pub struct QueryExecResponseData { // `sendResultTime`, `queryResultFormat`, `queryContext` also exist } +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct QueryAsyncExecResponseData { + pub query_id: String, + pub get_result_url: String, +} + #[derive(Deserialize, Debug)] pub struct ExecResponseRowType { pub name: String,