Skip to content
10 changes: 6 additions & 4 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ default = ["cert-auth"]
polars = ["dep:polars-core", "dep:polars-io"]

[dependencies]
arrow = "53"
arrow = "54.2.1"
async-trait = "0.1"
base64 = "0.22"
bytes = "1"
futures = "0.3"
futures-util = "0.3"
log = "0.4"
regex = "1"
reqwest = { version = "0.12", default-features = false, features = [
reqwest = { version = "=0.12.12", default-features = false, features = [
"gzip",
"json",
"rustls-tls",
"stream",
] }
reqwest-middleware = { version = "0.3", features = ["json"] }
reqwest-middleware = { version = "0.3.3", features = ["json"] }
reqwest-retry = "0.6"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand All @@ -54,7 +56,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }

[dev-dependencies]
anyhow = "1"
arrow = { version = "53", features = ["prettyprint"] }
arrow = { version = "54.2.1", features = ["prettyprint"] }
clap = { version = "4", features = ["derive"] }
pretty_env_logger = "0.5"
tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] }
2 changes: 1 addition & 1 deletion snowflake-api/examples/polars/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn main() -> Result<()> {
}

async fn run_and_print(api: &SnowflakeApi, sql: &str) -> Result<()> {
let res = api.exec_raw(sql).await?;
let res = api.exec_raw(sql, false).await?;

let df = DataFrame::try_from(res)?;
// alternatively, you can use the `try_into` method on the response
Expand Down
77 changes: 53 additions & 24 deletions snowflake-api/examples/run_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ extern crate snowflake_api;

use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use clap::{ArgAction, Parser};
use futures_util::StreamExt;
use std::fs;

use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{responses::ExecResponse, QueryResult, RawQueryResult, SnowflakeApi};

#[derive(clap::ValueEnum, Clone, Debug)]
enum Output {
Expand Down Expand Up @@ -56,6 +57,12 @@ struct Args {
#[arg(long)]
#[arg(value_enum, default_value_t = Output::Arrow)]
output: Output,

#[arg(long)]
host: Option<String>,

#[clap(long, action = ArgAction::Set)]
stream: bool,
}

#[tokio::main]
Expand Down Expand Up @@ -89,30 +96,52 @@ async fn main() -> Result<()> {
_ => {
panic!("Either private key path or password must be set")
}
};

match args.output {
Output::Arrow => {
let res = api.exec(&args.sql).await?;
match res {
QueryResult::Arrow(a) => {
println!("{}", pretty_format_batches(&a).unwrap());
}
QueryResult::Json(j) => {
println!("{j}");
}
QueryResult::Empty => {
println!("Query finished successfully")
}
}
// add optional host
.with_host(args.host);

if args.stream {
let resp = api.exec_raw(&args.sql, true).await?;

if let RawQueryResult::Stream(mut bytes_stream) = resp {
let mut chunks = vec![];
while let Some(bytes) = bytes_stream.next().await {
chunks.push(bytes?);
}

let bytes = chunks.into_iter().flatten().collect::<Vec<u8>>();
let resp = serde_json::from_slice::<ExecResponse>(&bytes).unwrap();
let raw_query_result = api.parse_arrow_raw_response(resp).await.unwrap();
let batches = raw_query_result.deserialize_arrow().unwrap();

if let QueryResult::Arrow(a) = batches {
println!("{}", pretty_format_batches(&a).unwrap());
}
}
Output::Json => {
let res = api.exec_json(&args.sql).await?;
println!("{res}");
}
Output::Query => {
let res = api.exec_response(&args.sql).await?;
println!("{:?}", res);
} else {
match args.output {
Output::Arrow => {
let res = api.exec(&args.sql).await?;
match res {
QueryResult::Arrow(a) => {
println!("{}", pretty_format_batches(&a).unwrap());
}
QueryResult::Json(j) => {
println!("{j}");
}
QueryResult::Empty => {
println!("Query finished successfully")
}
}
}
Output::Json => {
let res = api.exec_json(&args.sql).await?;
println!("{res}");
}
Output::Query => {
let res = api.exec_response(&args.sql).await?;
println!("{:?}", res);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion snowflake-api/examples/tracing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ version = "0.1.0"

[dependencies]
anyhow = "1"
arrow = { version = "53", features = ["prettyprint"] }
arrow = { version = "54.2.1", features = ["prettyprint"] }
dotenv = "0.15"
snowflake-api = { path = "../../../snowflake-api" }

Expand Down
44 changes: 34 additions & 10 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest::Response;
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
Expand Down Expand Up @@ -113,17 +114,15 @@ impl Connection {
.with(RetryTransientMiddleware::new_with_policy(retry_policy)))
}

/// Perform request of given query type with extra body or parameters
// todo: implement soft error handling
// todo: is there better way to not repeat myself?
pub async fn request<R: serde::de::DeserializeOwned>(
pub async fn send_request(
&self,
query_type: QueryType,
account_identifier: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
) -> Result<R, ConnectionError> {
host: Option<&str>,
) -> Result<Response, ConnectionError> {
let context = query_type.query_context();

let request_id = Uuid::new_v4();
Expand All @@ -144,10 +143,10 @@ impl Connection {
];
get_params.extend_from_slice(extra_get_params);

let url = format!(
"https://{}.snowflakecomputing.com/{}",
&account_identifier, context.path
);
let base_url = host
.map(str::to_string)
.unwrap_or_else(|| format!("https://{}.snowflakecomputing.com", &account_identifier));
let url = format!("{base_url}/{}", context.path);
let url = Url::parse_with_params(&url, get_params)?;

let mut headers = HeaderMap::new();
Expand All @@ -162,7 +161,6 @@ impl Connection {
headers.append(header::AUTHORIZATION, auth_val);
}

// todo: persist client to use connection polling
let resp = self
.client
.post(url)
Expand All @@ -171,6 +169,32 @@ impl Connection {
.send()
.await?;

Ok(resp)
}

/// Perform request of given query type with extra body or parameters
// todo: implement soft error handling
// todo: is there better way to not repeat myself?
pub async fn request<R: serde::de::DeserializeOwned>(
&self,
query_type: QueryType,
account_identifier: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
host: Option<&str>,
) -> Result<R, ConnectionError> {
let resp = self
.send_request(
query_type,
account_identifier,
extra_get_params,
auth,
body,
host,
)
.await?;

Ok(resp.json::<R>().await?)
}

Expand Down
Loading