Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,41 @@ use chrono::Utc;
use reqwest::Client;
use serde_json::{json, Value};

use crate::{Error, Result, SnowflakeAuthMethod, SnowflakeClientConfig};
use crate::{Error, Result, SnowflakeAuthMethod, SnowflakeClientConfig, SnowflakeConnectionConfig};

use self::key_pair::generate_jwt_from_key_pair;

fn get_base_url(
config: &SnowflakeClientConfig,
connection_config: &Option<SnowflakeConnectionConfig>,
) -> String {
if let Some(connection_config) = connection_config {
let host = &connection_config.host;
let port = connection_config
.port
.map(|p| format!(":{p}"))
.unwrap_or_else(|| "".to_string());
let protocol = connection_config
.protocol
.clone()
.unwrap_or_else(|| "https".to_string());

format!("{protocol}://{host}{port}")
} else {
format!("https://{}.snowflakecomputing.com", config.account)
}
}

/// Login to Snowflake and return a session token.
pub(super) async fn login(
http: &Client,
username: &str,
auth: &SnowflakeAuthMethod,
config: &SnowflakeClientConfig,
connection_config: &Option<SnowflakeConnectionConfig>,
) -> Result<String> {
let url = format!(
"https://{account}.snowflakecomputing.com/session/v1/login-request",
account = config.account
);
let base_url = get_base_url(config, connection_config);
let url = format!("{base_url}/session/v1/login-request");

let mut queries = vec![];
if let Some(warehouse) = &config.warehouse {
Expand Down Expand Up @@ -86,6 +106,10 @@ fn login_request_data(
"AUTHENTICATOR": "SNOWFLAKE_JWT"
}))
}
SnowflakeAuthMethod::Oauth { token } => Ok(json!({
"AUTHENTICATOR": "OAUTH",
"TOKEN": token
})),
}
}

Expand Down
50 changes: 49 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub struct SnowflakeClient {
username: String,
auth: SnowflakeAuthMethod,
config: SnowflakeClientConfig,
connection_config: Option<SnowflakeConnectionConfig>,
}

#[derive(Default, Clone)]
Expand All @@ -72,13 +73,23 @@ pub struct SnowflakeClientConfig {
pub timeout: Option<Duration>,
}

#[derive(Default, Clone)]
struct SnowflakeConnectionConfig {
host: String,
port: Option<u16>,
protocol: Option<String>,
}

#[derive(Clone)]
pub enum SnowflakeAuthMethod {
Password(String),
KeyPair {
encrypted_pem: String,
password: Vec<u8>,
},
Oauth {
token: String,
},
}

impl SnowflakeClient {
Expand All @@ -93,6 +104,7 @@ impl SnowflakeClient {
username: username.to_string(),
auth,
config,
connection_config: None,
})
}

Expand All @@ -110,16 +122,52 @@ impl SnowflakeClient {
username: self.username,
auth: self.auth,
config: self.config,
connection_config: self.connection_config,
})
}

pub fn with_address(
self,
host: &str,
port: Option<u16>,
protocol: Option<String>,
) -> Result<Self> {
Ok(Self {
http: self.http,
username: self.username,
auth: self.auth,
config: self.config,
connection_config: Some(SnowflakeConnectionConfig {
host: host.to_string(),
port,
protocol,
}),
})
}

pub async fn create_session(&self) -> Result<SnowflakeSession> {
let session_token = login(&self.http, &self.username, &self.auth, &self.config).await?;
let session_token = login(
&self.http,
&self.username,
&self.auth,
&self.config,
&self.connection_config,
)
.await?;
Ok(SnowflakeSession {
http: self.http.clone(),
account: self.config.account.clone(),
session_token,
timeout: self.config.timeout,
host: self
.connection_config
.as_ref()
.map(|conf| conf.host.clone()),
port: self.connection_config.as_ref().and_then(|conf| conf.port),
protocol: self
.connection_config
.as_ref()
.and_then(|conf| conf.protocol.clone()),
})
}
}
63 changes: 47 additions & 16 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,33 @@ pub struct QueryExecutor {
row_set: Mutex<Option<Vec<Vec<Option<String>>>>>,
}

fn get_base_url(sess: &SnowflakeSession) -> String {
let host = sess
.host
.clone()
.unwrap_or_else(|| format!("{}.snowflakecomputing.com", sess.account));
let port = sess.port.map(|p| format!(":{p}")).unwrap_or_default();
let protocol = sess.protocol.clone().unwrap_or_else(|| "https".to_string());

format!("{protocol}://{host}{port}")
}

impl QueryExecutor {
pub(super) async fn create<Q: Into<QueryRequest>>(
sess: &SnowflakeSession,
request: Q,
) -> Result<Self> {
let SnowflakeSession {
http,
account,
session_token,
timeout,
..
} = sess;
let timeout = timeout.unwrap_or(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS));

let request_id = uuid::Uuid::new_v4();
let url = format!(
r"https://{account}.snowflakecomputing.com/queries/v1/query-request?requestId={request_id}"
);
let base_url = get_base_url(sess);
let url = format!(r"{base_url}/queries/v1/query-request?requestId={request_id}");

let request: QueryRequest = request.into();
let response = http
Expand All @@ -64,17 +74,27 @@ impl QueryExecutor {
return Err(Error::Communication(body));
}

let mut response: SnowflakeResponse =
let response: SnowflakeRawResponse =
serde_json::from_str(&body).map_err(|e| Error::Json(e, body))?;

if let Some(SESSION_EXPIRED) = response.code.as_deref() {
return Err(Error::SessionExpired);
}

if !response.success {
return Err(Error::Communication(response.message.unwrap_or_default()));
}

let mut response: SnowflakeResponse = response.try_into()?;

let response_code = response.code.as_deref();
if response_code == Some(QUERY_IN_PROGRESS_ASYNC_CODE)
|| response_code == Some(QUERY_IN_PROGRESS_CODE)
{
match response.data.get_result_url {
Some(result_url) => {
response =
poll_for_async_results(http, account, &result_url, session_token, timeout)
poll_for_async_results(http, &result_url, session_token, timeout, base_url)
.await?
}
None => {
Expand All @@ -83,14 +103,6 @@ impl QueryExecutor {
}
}

if let Some(SESSION_EXPIRED) = response.code.as_deref() {
return Err(Error::SessionExpired);
}

if !response.success {
return Err(Error::Communication(response.message.unwrap_or_default()));
}

if let Some(format) = response.data.query_result_format {
if format != "json" {
return Err(Error::UnsupportedFormat(format.clone()));
Expand Down Expand Up @@ -212,15 +224,15 @@ impl QueryExecutor {

async fn poll_for_async_results(
http: &Client,
account: &str,
result_url: &str,
session_token: &str,
timeout: Duration,
base_url: String,
) -> Result<SnowflakeResponse> {
let start = Instant::now();
while start.elapsed() < timeout {
sleep(Duration::from_secs(10)).await;
let url = format!("https://{account}.snowflakecomputing.com{}", result_url);
let url = format!("{base_url}{result_url}");

let resp = http
.get(url)
Expand Down Expand Up @@ -357,7 +369,26 @@ struct RawQueryResponseChunk {
#[derive(serde::Deserialize, Debug)]
struct SnowflakeResponse {
data: RawQueryResponse,
code: Option<String>,
}

#[derive(serde::Deserialize, Debug)]
struct SnowflakeRawResponse {
data: Option<RawQueryResponse>,
message: Option<String>,
success: bool,
code: Option<String>,
}

impl TryInto<SnowflakeResponse> for SnowflakeRawResponse {
type Error = crate::Error;

fn try_into(self) -> std::result::Result<SnowflakeResponse, Self::Error> {
Ok(SnowflakeResponse {
data: self
.data
.ok_or_else(|| crate::error::Error::Decode("Expected data property".into()))?,
code: self.code,
})
}
}
3 changes: 3 additions & 0 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use crate::{

pub struct SnowflakeSession {
pub(super) http: reqwest::Client,
pub(super) host: Option<String>,
pub(super) port: Option<u16>,
pub(super) protocol: Option<String>,
pub(super) account: String,
pub(super) session_token: String,
pub(super) timeout: Option<Duration>,
Expand Down
11 changes: 11 additions & 0 deletions tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ pub fn connect() -> Result<SnowflakeClient> {
let warehouse = std::env::var("SNOWFLAKE_WAREHOUSE").ok();
let database = std::env::var("SNOWFLAKE_DATABASE").ok();
let schema = std::env::var("SNOWFLAKE_SCHEMA").ok();
let host = std::env::var("SNOWFLAKE_HOST").ok();
let port = std::env::var("SNOWFLAKE_PORT")
.ok()
.and_then(|var| var.parse().ok());
let protocol = std::env::var("SNOWFLAKE_PROTOCOL").ok();

let client = SnowflakeClient::new(
&username,
Expand All @@ -23,5 +28,11 @@ pub fn connect() -> Result<SnowflakeClient> {
},
)?;

let client = if let Some(ref host) = host {
client.with_address(host, port, protocol)?
} else {
client
};

Ok(client)
}