diff --git a/Cargo.lock b/Cargo.lock index 16f8f3bada2a..810fa87d4c4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5415,6 +5415,7 @@ dependencies = [ "async-trait", "bollard", "clap", + "cookie", "databend-common-exception", "env_logger 0.11.8", "futures-util", @@ -5431,6 +5432,7 @@ dependencies = [ "testcontainers-modules", "thiserror 1.0.69", "tokio", + "url", "walkdir", ] diff --git a/src/query/service/src/servers/http/middleware/session_header.rs b/src/query/service/src/servers/http/middleware/session_header.rs index 57ccdf9ca0dd..9b34eaf46217 100644 --- a/src/query/service/src/servers/http/middleware/session_header.rs +++ b/src/query/service/src/servers/http/middleware/session_header.rs @@ -117,18 +117,15 @@ impl ClientSession { headers: &HeaderMap, caps: &mut ClientCapabilities, ) -> Result, String> { - if let Some(v) = headers.get(HEADER_SESSION) { - caps.session_header = true; - let v = v.to_str().unwrap().to_string().trim().to_owned(); - let s = if v.is_empty() { - // note that curl -H "X-xx:" not work - Self::new_session(false) - } else { - let header = decode_json_header(HEADER_SESSION, v.as_str())?; - Self::old_session(false, header) - }; - Ok(Some(s)) - } else if caps.session_header { + if caps.session_header { + if let Some(v) = headers.get(HEADER_SESSION) { + caps.session_header = true; + let v = v.to_str().unwrap().to_string().trim().to_owned(); + if !v.is_empty() { + let header = decode_json_header(HEADER_SESSION, &v)?; + return Ok(Some(Self::old_session(false, header))); + }; + } Ok(Some(Self::new_session(false))) } else { Ok(None) diff --git a/src/query/service/src/servers/http/v1/http_query_handlers.rs b/src/query/service/src/servers/http/v1/http_query_handlers.rs index a3ade94c40bb..366b34a87743 100644 --- a/src/query/service/src/servers/http/v1/http_query_handlers.rs +++ b/src/query/service/src/servers/http/v1/http_query_handlers.rs @@ -346,7 +346,6 @@ async fn query_state_handler( let http_query_manager = HttpQueryManager::instance(); match http_query_manager.get_query(&query_id) { Some(query) => { - query.check_client_session_id(&ctx.client_session_id)?; if let Some(reason) = query.check_removed() { Err(query_id_removed(&query_id, reason)) } else { diff --git a/tests/sqllogictests/Cargo.toml b/tests/sqllogictests/Cargo.toml index fd8547589c6f..365250f7cbbf 100644 --- a/tests/sqllogictests/Cargo.toml +++ b/tests/sqllogictests/Cargo.toml @@ -18,6 +18,7 @@ async-recursion = { workspace = true } async-trait = { workspace = true } bollard = { workspace = true } clap = { workspace = true } +cookie = { workspace = true } databend-common-exception = { workspace = true } env_logger = { workspace = true } futures-util = { workspace = true } @@ -34,6 +35,7 @@ testcontainers = { workspace = true } testcontainers-modules = { workspace = true, features = ["mysql", "redis"] } thiserror = { workspace = true } tokio = { workspace = true } +url = { workspace = true } walkdir = { workspace = true } [lints] diff --git a/tests/sqllogictests/src/client/global_cookie_store.rs b/tests/sqllogictests/src/client/global_cookie_store.rs new file mode 100644 index 000000000000..498e2d82c9cd --- /dev/null +++ b/tests/sqllogictests/src/client/global_cookie_store.rs @@ -0,0 +1,62 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::RwLock; + +use cookie::Cookie; +use reqwest::cookie::CookieStore; +use reqwest::header::HeaderValue; +use url::Url; + +pub(crate) struct GlobalCookieStore { + cookies: RwLock>>, +} + +impl GlobalCookieStore { + pub fn new() -> Self { + GlobalCookieStore { + cookies: RwLock::new(HashMap::new()), + } + } +} + +impl CookieStore for GlobalCookieStore { + fn set_cookies(&self, cookie_headers: &mut dyn Iterator, _url: &Url) { + let iter = cookie_headers + .filter_map(|val| std::str::from_utf8(val.as_bytes()).ok()) + .filter_map(|kv| Cookie::parse(kv).map(|c| c.into_owned()).ok()); + + let mut guard = self.cookies.write().unwrap(); + for cookie in iter { + guard.insert(cookie.name().to_string(), cookie); + } + } + + fn cookies(&self, _url: &Url) -> Option { + let guard = self.cookies.read().unwrap(); + let s: String = guard + .values() + .map(|cookie| cookie.name_value()) + .map(|(name, value)| format!("{name}={value}")) + .collect::>() + .join("; "); + + if s.is_empty() { + return None; + } + + HeaderValue::from_str(&s).ok() + } +} diff --git a/tests/sqllogictests/src/client/http_client.rs b/tests/sqllogictests/src/client/http_client.rs index f8597387c2bd..3485a39f108a 100644 --- a/tests/sqllogictests/src/client/http_client.rs +++ b/tests/sqllogictests/src/client/http_client.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; use std::time::Instant; @@ -20,22 +21,18 @@ use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest::Client; use reqwest::ClientBuilder; -use reqwest::Response; use serde::Deserialize; use sqllogictest::DBOutput; use sqllogictest::DefaultColumnType; -use crate::error::DSqlLogicTestError::Databend; +use crate::client::global_cookie_store::GlobalCookieStore; use crate::error::Result; use crate::util::parser_rows; use crate::util::HttpSessionConf; -const SESSION_HEADER: &str = "X-DATABEND-SESSION"; - pub struct HttpClient { pub client: Client, pub session_token: String, - pub session_headers: HeaderMap, pub debug: bool, pub session: Option, pub port: u16, @@ -86,59 +83,28 @@ impl HttpClient { header.insert("Accept", HeaderValue::from_str("application/json").unwrap()); header.insert( "X-DATABEND-CLIENT-CAPS", - HeaderValue::from_str("session_header").unwrap(), + HeaderValue::from_str("session_cookie").unwrap(), ); + let cookie_provider = GlobalCookieStore::new(); let client = ClientBuilder::new() + .cookie_provider(Arc::new(cookie_provider)) .default_headers(header) // https://github.com/hyperium/hyper/issues/2136#issuecomment-589488526 .http2_keep_alive_timeout(Duration::from_secs(15)) .pool_max_idle_per_host(0) .build()?; - let mut session_headers = HeaderMap::new(); - session_headers.insert(SESSION_HEADER, HeaderValue::from_str("").unwrap()); - let mut res = Self { - client, - session_token: "".to_string(), - session_headers, - session: None, - debug: false, - port, - }; - res.login().await?; - Ok(res) - } - async fn update_session_header(&mut self, response: Response) -> Result { - if let Some(value) = response.headers().get(SESSION_HEADER) { - let session_header = value.to_str().unwrap().to_owned(); - if !session_header.is_empty() { - self.session_headers - .insert(SESSION_HEADER, value.to_owned()); - return Ok(response); - } - } - let meta = format!("response={response:?}"); - let data = response.text().await.unwrap(); - Err(Databend( - format!("{} is empty, {meta}, {data}", SESSION_HEADER,).into(), - )) - } + let url = format!("http://127.0.0.1:{}/v1/session/login", port); - async fn login(&mut self) -> Result<()> { - let url = format!("http://127.0.0.1:{}/v1/session/login", self.port); - let response = self - .client + let session_token = client .post(&url) - .headers(self.session_headers.clone()) .body("{}") .basic_auth("root", Some("")) .send() .await .inspect_err(|e| { println!("fail to send to {}: {:?}", url, e); - })?; - let response = self.update_session_header(response).await?; - self.session_token = response + })? .json::() .await .inspect_err(|e| { @@ -147,7 +113,14 @@ impl HttpClient { .tokens .unwrap() .session_token; - Ok(()) + + Ok(Self { + client, + session_token, + session: None, + debug: false, + port, + }) } pub async fn query(&mut self, sql: &str) -> Result> { @@ -204,43 +177,43 @@ impl HttpClient { } // Send request and get response by json format - async fn post_query(&mut self, sql: &str, url: &str) -> Result { + async fn post_query(&self, sql: &str, url: &str) -> Result { let mut query = HashMap::new(); query.insert("sql", serde_json::to_value(sql)?); if let Some(session) = &self.session { query.insert("session", serde_json::to_value(session)?); } - let response = self + Ok(self .client .post(url) - .headers(self.session_headers.clone()) .json(&query) .bearer_auth(&self.session_token) .send() .await .inspect_err(|e| { println!("fail to send to {}: {:?}", url, e); - })?; - let response = self.update_session_header(response).await?; - Ok(response.json::().await.inspect_err(|e| { - println!("fail to decode json when call {}: {:?}", url, e); - })?) + })? + .json::() + .await + .inspect_err(|e| { + println!("fail to decode json when call {}: {:?}", url, e); + })?) } - async fn poll_query_result(&mut self, url: &str) -> Result { - let response = self + async fn poll_query_result(&self, url: &str) -> Result { + Ok(self .client .get(url) .bearer_auth(&self.session_token) - .headers(self.session_headers.clone()) .send() .await .inspect_err(|e| { println!("fail to send to {}: {:?}", url, e); - })?; - let response = self.update_session_header(response).await?; - Ok(response.json::().await.inspect_err(|e| { - println!("fail to decode json when call {}: {:?}", url, e); - })?) + })? + .json::() + .await + .inspect_err(|e| { + println!("fail to decode json when call {}: {:?}", url, e); + })?) } } diff --git a/tests/sqllogictests/src/client/mod.rs b/tests/sqllogictests/src/client/mod.rs index d12b284f24d7..bf089df32746 100644 --- a/tests/sqllogictests/src/client/mod.rs +++ b/tests/sqllogictests/src/client/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod global_cookie_store; mod http_client; mod mysql_client; mod ttc_client; diff --git a/tests/suites/1_stateful/09_http_handler/09_0007_session.py b/tests/suites/1_stateful/09_http_handler/09_0007_session.py index 32b22f574c96..35c65c291c73 100755 --- a/tests/suites/1_stateful/09_http_handler/09_0007_session.py +++ b/tests/suites/1_stateful/09_http_handler/09_0007_session.py @@ -10,6 +10,7 @@ from requests import Response HEADER_SESSION = "X-DATABEND-SESSION" +HEADER_CAPS = "X-DATABEND-CLIENT-CAPS" # Define the URLs and credentials query_url = "http://localhost:8000/v1/query" login_url = "http://localhost:8000/v1/session/login" @@ -22,10 +23,9 @@ def wrapper(self, *args, **kwargs): print(f"---- {func.__name__}{args[:1]}") resp: Response = func(self, *args, **kwargs) self.session_header = resp.headers.get(HEADER_SESSION) + json_str = base64.urlsafe_b64decode(self.session_header) last = self.session_header_json - self.session_header_json = json.loads( - base64.urlsafe_b64decode(self.session_header) - ) + self.session_header_json = json.loads(json_str) if last: if last["id"] != self.session_header_json["id"]: print( @@ -72,7 +72,7 @@ def login(self): auth=auth, headers={ "Content-Type": "application/json", - "X-DATABEND-CLIENT-CAPS": "session_header", + HEADER_CAPS: "session_header", }, json=payload, ) @@ -83,7 +83,10 @@ def do_logout(self, _case_id): response = self.client.post( logout_url, auth=auth, - headers={HEADER_SESSION: self.session_header}, + headers={ + HEADER_CAPS: "session_header", + HEADER_SESSION: self.session_header + }, ) return response @@ -95,6 +98,7 @@ def do_query(self, query, url=query_url): auth=auth, headers={ "Content-Type": "application/json", + HEADER_CAPS: "session_header", HEADER_SESSION: self.session_header, }, json=query_payload, @@ -109,7 +113,7 @@ def set_fake_last_refresh_time(self): ).decode("ascii") -def main(): +def test_session(): client = Client() client.login() @@ -134,6 +138,32 @@ def main(): pprint(query_resp.get("session").get("need_keep_alive")) +# without X-DATABEND-CLIENT-CAPS: +# 1. query still works +# 2. X-DATABEND-SESSION is ignored +def test_no_session(): + client = requests.session() + payload = {"sql": "select * from numbers(100)", "pagination": {"max_rows_per_page": 2}} + resp = client.post( + query_url, + auth=auth, + headers={"Content-Type": "application/json", HEADER_SESSION: "xxx"}, + json=payload, + ) + resp = resp.json() + next_uri = resp.get("next_uri") + resp = client.get( + f"http://localhost:8000/{next_uri}", + auth=auth, + ) + resp = resp.json() + assert len(resp["data"]) == 2, resp + +def main(): + test_no_session() + test_session() + + if __name__ == "__main__": import logging