diff --git a/snowflake-api/.gitignore b/snowflake-api/.gitignore new file mode 100644 index 0000000..711174a --- /dev/null +++ b/snowflake-api/.gitignore @@ -0,0 +1,3 @@ +/target +.env +Cargo.lock \ No newline at end of file diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml index b0aaedf..0bd32b6 100644 --- a/snowflake-api/Cargo.toml +++ b/snowflake-api/Cargo.toml @@ -24,6 +24,7 @@ async-trait = "0.1" base64 = "0.21" bytes = "1" futures = "0.3" +http = "1" log = "0.4" object_store = { version = "0.9", features = ["aws"] } regex = "1" @@ -33,6 +34,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "rustls-tls", ] } reqwest-middleware = "0.2" +task-local-extensions = "0.1" reqwest-retry = "0.3" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -40,6 +42,7 @@ snowflake-jwt = { version = "0.3.0", optional = true } thiserror = "1" url = "2" uuid = { version = "1", features = ["v4"] } + polars-io = { version = ">=0.32", features = ["json", "ipc_streaming"], optional = true} polars-core = { version = ">=0.32", optional = true} @@ -50,3 +53,7 @@ arrow = { version = "50", features = ["prettyprint"] } clap = { version = "4", features = ["derive"] } pretty_env_logger = "0.5" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +mockito = "1.3.1" +tracing-subscriber = "0.3" +serde_urlencoded = "0.7.1" +dashmap = "5" diff --git a/snowflake-api/examples/tracing/src/main.rs b/snowflake-api/examples/tracing/src/main.rs index 3345a74..8f6cb9a 100644 --- a/snowflake-api/examples/tracing/src/main.rs +++ b/snowflake-api/examples/tracing/src/main.rs @@ -60,7 +60,7 @@ async fn main() -> Result<()> { #[tracing::instrument(name = "snowflake_api", skip(api))] async fn run_in_span(api: &snowflake_api::SnowflakeApi) -> anyhow::Result<()> { - let res = api.exec("select 'hello from snowflake' as col1;").await?; + let res = api.exec("select 1;").await?; match res { QueryResult::Arrow(a) => { diff --git 
a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs
index e7087e1..36d6bde 100644
--- a/snowflake-api/src/connection.rs
+++ b/snowflake-api/src/connection.rs
@@ -1,12 +1,13 @@
+use http::uri::Scheme;
 use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
 use reqwest_middleware::ClientWithMiddleware;
 use reqwest_retry::policies::ExponentialBackoff;
 use reqwest_retry::RetryTransientMiddleware;
 use std::collections::HashMap;
-use std::time::{SystemTime, UNIX_EPOCH};
 use thiserror::Error;
 use url::Url;
-use uuid::Uuid;
+
+use crate::middleware::UuidMiddleware;
 
 #[derive(Error, Debug)]
 pub enum ConnectionError {
@@ -73,13 +74,21 @@ impl QueryType {
 pub struct Connection {
     // no need for Arc as it's already inside the reqwest client
     client: ClientWithMiddleware,
+    // host suffix appended directly after the account identifier when building request URLs
+    base_url: String,
+    // URL scheme used when building request URLs
+    scheme: Scheme,
 }
 
 impl Connection {
     pub fn new() -> Result<Self, ConnectionError> {
         let client = Self::default_client_builder()?;
 
-        Ok(Self::new_with_middware(client.build()))
+        Ok(Self::new_with_middware(
+            client.build(),
+            None,
+            Some(Scheme::HTTPS),
+        ))
     }
 
     /// Allow a user to provide their own middleware
@@ -89,11 +98,26 @@ impl Connection {
     /// use snowflake_api::connection::Connection;
     /// let mut client = Connection::default_client_builder();
     /// // modify the client builder here
-    /// let connection = Connection::new_with_middware(client.unwrap().build());
+    /// let connection = Connection::new_with_middware(
+    ///     client.unwrap().build(),
+    ///     None,
+    ///     Some(http::uri::Scheme::HTTPS),
+    /// );
     /// ```
     /// This is not intended to be called directly, but is used by `SnowflakeApiBuilder::with_client`
-    pub fn new_with_middware(client: ClientWithMiddleware) -> Self {
-        Self { client }
+    ///
+    /// `base_url` is the host suffix placed immediately after the account identifier
+    /// (defaults to `.snowflakecomputing.com`); `scheme` defaults to HTTPS.
+    pub fn new_with_middware(
+        client: ClientWithMiddleware,
+        base_url: Option<String>,
+        scheme: Option<Scheme>,
+    ) -> Self {
+        Self {
+            client,
+            base_url: base_url.unwrap_or_else(|| ".snowflakecomputing.com".to_owned()),
+            scheme: scheme.unwrap_or(Scheme::HTTPS),
+        }
     }
 
     pub fn default_client_builder() -> Result<reqwest_middleware::ClientBuilder, ConnectionError> {
@@ -110,7 +134,8 @@ impl Connection {
         let client = client.build()?;
 
         Ok(reqwest_middleware::ClientBuilder::new(client)
-            .with(RetryTransientMiddleware::new_with_policy(retry_policy)))
+            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
+            .with(UuidMiddleware))
     }
 
     /// Perform request of given query type with extra body or parameters
@@ -126,27 +151,12 @@ impl Connection {
     ) -> Result<R, ConnectionError> {
         let context = query_type.query_context();
 
-        let request_id = Uuid::new_v4();
-        let request_guid = Uuid::new_v4();
-        let client_start_time = SystemTime::now()
-            .duration_since(UNIX_EPOCH)
-            .unwrap()
-            .as_secs()
-            .to_string();
-        // fixme: update uuid's on the retry
-        let request_id = request_id.to_string();
-        let request_guid = request_guid.to_string();
-
-        let mut get_params = vec![
-            ("clientStartTime", client_start_time.as_str()),
-            ("requestId", request_id.as_str()),
-            ("request_guid", request_guid.as_str()),
-        ];
+        let mut get_params = vec![];
         get_params.extend_from_slice(extra_get_params);
 
         let url = format!(
-            "https://{}.snowflakecomputing.com/{}",
-            &account_identifier, context.path
+            "{}://{}{}/{}",
+            self.scheme, &account_identifier, self.base_url, context.path
         );
         let url = Url::parse_with_params(&url, get_params)?;
 
@@ -197,3 +207,94 @@ impl Connection {
         Ok(bytes)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use dashmap::DashMap;
+    use http::uri::Scheme;
+    use serde_json::json;
+    use std::sync::Arc;
+    use uuid::Uuid;
+
+    /// A failing endpoint must be retried, and every attempt must carry a fresh `requestId`.
+    #[tokio::test]
+    async fn test_request() {
+        // `try_init` rather than `init`: a global subscriber may already be installed.
+        let _ = tracing_subscriber::fmt::try_init();
+
+        // Let mockito pick a free port; a hard-coded one fails whenever it is already taken.
+        let mut server = mockito::Server::new_async().await;
+
+        let client = Connection::default_client_builder().unwrap();
+        let conn = Connection::new_with_middware(
+            client.build(),
+            Some(server.host_with_port()),
+            Some(Scheme::HTTP),
+        );
+
+        let ctx = QueryType::LoginRequest.query_context();
+
+        // using a dashmap to capture the requestIds across
+        // all requests to our mock server
+        let request_ids = Arc::new(DashMap::new());
+        let request_ids_clone = Arc::clone(&request_ids);
+
+        let mock = server
+            .mock("POST", "/session/v1/login-request")
+            .match_query(mockito::Matcher::Any)
+            // force an error to validate retries
+            .with_status(500)
+            .with_header("content-type", ctx.accept_mime)
+            // mechanism to validate the request body (feed it back to the client)
+            .with_body_from_request(move |request| {
+                let query = request
+                    .path_and_query()
+                    .split_once('?')
+                    .map_or("", |(_, query)| query);
+                let params: HashMap<String, String> =
+                    serde_urlencoded::from_str(query).unwrap_or_default();
+                let request_id = params.get("requestId").cloned().unwrap_or_default();
+
+                request_ids_clone.insert(request_id.clone(), true);
+
+                json!({"error": "an error happened", "requestId": request_id})
+                    .to_string()
+                    .into_bytes()
+            })
+            // NOTE(review): presumably one initial attempt plus three retries;
+            // confirm against the retry policy in `default_client_builder`
+            .expect(4)
+            .create_async()
+            .await;
+
+        match conn
+            .request::<serde_json::Value>(
+                QueryType::LoginRequest,
+                "",
+                &[],
+                None,
+                json!({"query": "SELECT 1"}),
+            )
+            .await
+        {
+            Ok(res) => assert_eq!(res["error"], "an error happened"),
+            // best effort: the request-id bookkeeping below is what this test verifies
+            Err(e) => log::error!("Error: {}", e),
+        }
+
+        // assert that all requests were made with different requestIds
+        assert_eq!(request_ids.len(), 4);
+
+        for entry in request_ids.iter() {
+            let request_id = entry.key();
+            log::info!("Captured Request ID: {}", request_id);
+            assert!(
+                Uuid::parse_str(request_id).is_ok(),
+                "requestId is not a UUID: {request_id}"
+            );
+        }
+
+        mock.assert_async().await;
+    }
+}
diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs
index 41f9d0a..2278694 100644
--- a/snowflake-api/src/lib.rs
+++ b/snowflake-api/src/lib.rs
@@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
 use base64::Engine;
 use bytes::{Buf, Bytes};
 use futures::future::try_join_all;
+use http::uri::Scheme;
 use object_store::aws::AmazonS3Builder;
 use
object_store::local::LocalFileSystem;
 use object_store::ObjectStore;
@@ -42,6 +43,8 @@ use crate::responses::{
 };
 
 pub mod connection;
+
+mod middleware;
 #[cfg(feature = "polars")]
 mod polars;
 mod requests;
@@ -218,7 +221,11 @@ impl SnowflakeApiBuilder {
 
     pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
         let connection = match self.client {
-            Some(client) => Arc::new(Connection::new_with_middware(client)),
+            Some(client) => Arc::new(Connection::new_with_middware(
+                client,
+                None,
+                Some(Scheme::HTTPS),
+            )),
             None => Arc::new(Connection::new()?),
         };
 
diff --git a/snowflake-api/src/middleware.rs b/snowflake-api/src/middleware.rs
new file mode 100644
index 0000000..d02102d
--- /dev/null
+++ b/snowflake-api/src/middleware.rs
@@ -0,0 +1,47 @@
+//! Request middleware that stamps per-request query parameters onto outgoing requests.
+
+use reqwest::Request;
+use reqwest::Response;
+use reqwest_middleware::{Middleware, Next, Result as MiddlewareResult};
+
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use task_local_extensions::Extensions;
+use uuid::Uuid;
+
+/// Adds `clientStartTime`, `requestId` and `request_guid` query parameters to every request
+/// passing through it, generating fresh values on each call to [`Middleware::handle`].
+#[derive(Debug, Default, Clone, Copy)]
+pub struct UuidMiddleware;
+
+#[async_trait::async_trait]
+impl Middleware for UuidMiddleware {
+    /// Appends the generated parameters to the request URL, then hands the request to `next`.
+    async fn handle(
+        &self,
+        req: Request,
+        extensions: &mut Extensions,
+        next: Next<'_>,
+    ) -> MiddlewareResult<Response> {
+        let request_id = Uuid::new_v4();
+        let request_guid = Uuid::new_v4();
+        // Seconds since the Unix epoch; a clock set before the epoch yields 0 rather than a panic.
+        let client_start_time = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .map_or(0, |elapsed| elapsed.as_secs());
+
+        // `req` is owned, so edit it in place: cloning fails for bodies that cannot be
+        // cloned (e.g. streams) and would force an `unwrap`.
+        let mut req = req;
+
+        // `query_pairs_mut` keeps any existing query and inserts `&` only when needed, so a
+        // URL without a query does not end up as `?&clientStartTime=...`.
+        req.url_mut()
+            .query_pairs_mut()
+            .append_pair("clientStartTime", &client_start_time.to_string())
+            .append_pair("requestId", &request_id.to_string())
+            .append_pair("request_guid", &request_guid.to_string());
+
+        next.run(req, extensions).await
+    }
+}
diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs
index dee26c3..0fa61cd 100644
--- a/snowflake-api/src/responses.rs
+++ b/snowflake-api/src/responses.rs
@@ -165,9 +165,9 @@ pub struct ExecResponseRowType {
     pub nullable: bool,
 }
 
-//
fixme: is it good idea to keep this as an enum if more types could be added in future? #[derive(Deserialize, Debug)] #[serde(rename_all = "snake_case")] +#[non_exhaustive] pub enum SnowflakeType { Fixed, Real,