diff --git a/Cargo.lock b/Cargo.lock index dc443ac8c..e9c045d11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1951,6 +1951,7 @@ dependencies = [ "futures", "hyper", "hyper-rustls", + "kawa", "libc", "mio", "rustls", diff --git a/bin/src/cli.rs b/bin/src/cli.rs index 454dc9dbd..c529b9ac1 100644 --- a/bin/src/cli.rs +++ b/bin/src/cli.rs @@ -454,6 +454,8 @@ pub enum HttpFrontendCmd { method: Option, #[clap(long = "tags", help = "Specify tag (key-value pair) to apply on front-end (example: 'key=value, other-key=other-value')", value_parser = parse_tags)] tags: Option>, + #[clap(help = "the frontend uses http2 with prio-knowledge")] + h2: Option, }, #[clap(name = "remove")] Remove { diff --git a/bin/src/ctl/command.rs b/bin/src/ctl/command.rs index d291b1e3f..a96aea86b 100644 --- a/bin/src/ctl/command.rs +++ b/bin/src/ctl/command.rs @@ -434,7 +434,7 @@ impl CommandManager { if let Some(response_content) = response.content { let certs = match response_content.content_type { Some(ContentType::CertificatesWithFingerprints(certs)) => certs.certs, - _ => bail!(format!("Wrong response content {:?}", response_content)), + _ => bail!(format!("Wrong response content {response_content:?}")), }; if certs.is_empty() { bail!("No certificates match your request."); diff --git a/bin/src/ctl/display.rs b/bin/src/ctl/display.rs index 78a8eb66a..8250a8fb4 100644 --- a/bin/src/ctl/display.rs +++ b/bin/src/ctl/display.rs @@ -611,7 +611,7 @@ pub fn print_cluster_responses( clusters_table.set_format(*prettytable::format::consts::FORMAT_BOX_CHARS); let mut header = vec![cell!("cluster id")]; for worker_id in worker_responses.map.keys() { - header.push(cell!(format!("worker {}", worker_id))); + header.push(cell!(format!("worker {worker_id}"))); } header.push(cell!("desynchronized")); clusters_table.add_row(Row::new(header)); @@ -659,14 +659,14 @@ pub fn print_certificates_by_worker( } for (worker_id, response_content) in response_contents.iter() { - println!("Worker {}", worker_id); + 
println!("Worker {worker_id}"); match &response_content.content_type { Some(ContentType::CertificatesByAddress(list)) => { for certs in list.certificates.iter() { println!("\t{}:", certs.address); for summary in certs.certificate_summaries.iter() { - println!("\t\t{}", summary); + println!("\t\t{summary}"); } println!(); diff --git a/bin/src/ctl/request_builder.rs b/bin/src/ctl/request_builder.rs index ba9706dd0..c1b38d537 100644 --- a/bin/src/ctl/request_builder.rs +++ b/bin/src/ctl/request_builder.rs @@ -202,6 +202,7 @@ impl CommandManager { method, cluster_id: route, tags, + h2, } => self.send_request( RequestType::AddHttpFrontend(RequestHttpFrontend { cluster_id: route.into(), @@ -214,6 +215,7 @@ impl CommandManager { Some(tags) => tags, None => BTreeMap::new(), }, + h2: h2.unwrap_or(false), }) .into(), ), @@ -250,6 +252,7 @@ impl CommandManager { method, cluster_id: route, tags, + h2, } => self.send_request( RequestType::AddHttpsFrontend(RequestHttpFrontend { cluster_id: route.into(), @@ -262,6 +265,7 @@ impl CommandManager { Some(tags) => tags, None => BTreeMap::new(), }, + h2: h2.unwrap_or(false), }) .into(), ), diff --git a/command/src/command.proto b/command/src/command.proto index 408bde6e0..d9fad822b 100644 --- a/command/src/command.proto +++ b/command/src/command.proto @@ -161,6 +161,7 @@ message HttpsListenerConfig { // The tickets allow the client to resume a session. This protects the client // agains session tracking. Defaults to 4. 
required uint64 send_tls13_tickets = 20; + repeated AlpnProtocol alpn = 21; } // details of an TCP listener @@ -221,6 +222,7 @@ message RequestHttpFrontend { required RulePosition position = 6 [default = TREE]; // custom tags to identify the frontend in the access logs map tags = 7; + required bool h2 = 8; } message RequestTcpFrontend { @@ -339,6 +341,11 @@ enum TlsVersion { TLS_V1_3 = 5; } +enum AlpnProtocol { + Http11 = 0; + H2 = 1; +} + // A cluster is what binds a frontend to backends with routing rules message Cluster { required string cluster_id = 1; diff --git a/command/src/config.rs b/command/src/config.rs index efe1cdc26..395199744 100644 --- a/command/src/config.rs +++ b/command/src/config.rs @@ -62,7 +62,7 @@ use toml; use crate::{ certificate::split_certificate_chain, proto::command::{ - request::RequestType, ActivateListener, AddBackend, AddCertificate, CertificateAndKey, + request::RequestType, ActivateListener, AddBackend, AddCertificate, AlpnProtocol, CertificateAndKey, Cluster, HttpListenerConfig, HttpsListenerConfig, ListenerType, LoadBalancingAlgorithms, LoadBalancingParams, LoadMetric, MetricsConfiguration, PathRule, ProxyProtocolConfig, Request, RequestHttpFrontend, RequestTcpFrontend, RulePosition, TcpListenerConfig, @@ -72,8 +72,10 @@ use crate::{ ObjectKind, }; +pub const DEFAULT_ALPN: [AlpnProtocol; 1] = [AlpnProtocol::Http11]; + /// provides all supported cipher suites exported by Rustls TLS -/// provider as it support only strongly secure ones. +/// provider as it supports only strongly secure ones. 
/// /// See the [documentation](https://docs.rs/rustls/latest/rustls/static.ALL_CIPHER_SUITES.html) pub const DEFAULT_RUSTLS_CIPHER_LIST: [&str; 9] = [ @@ -236,6 +238,7 @@ pub enum ConfigError { pub struct ListenerBuilder { pub address: String, pub protocol: Option, + pub alpn: Option>, pub public_address: Option, /// path to the 404 html file pub answer_404: Option, @@ -311,7 +314,7 @@ impl ListenerBuilder { } } - pub fn with_public_address(&mut self, public_address: Option) -> &mut Self + pub fn with_public_address(mut self, public_address: Option) -> Self where S: ToString, { @@ -321,7 +324,7 @@ impl ListenerBuilder { self } - pub fn with_answer_404_path(&mut self, answer_404_path: Option) -> &mut Self + pub fn with_answer_404_path(mut self, answer_404_path: Option) -> Self where S: ToString, { @@ -331,7 +334,7 @@ impl ListenerBuilder { self } - pub fn with_answer_503_path(&mut self, answer_503_path: Option) -> &mut Self + pub fn with_answer_503_path(mut self, answer_503_path: Option) -> Self where S: ToString, { @@ -341,27 +344,27 @@ impl ListenerBuilder { self } - pub fn with_tls_versions(&mut self, tls_versions: Vec) -> &mut Self { + pub fn with_tls_versions(mut self, tls_versions: Vec) -> Self { self.tls_versions = Some(tls_versions); self } - pub fn with_cipher_list(&mut self, cipher_list: Option>) -> &mut Self { + pub fn with_cipher_list(mut self, cipher_list: Option>) -> Self { self.cipher_list = cipher_list; self } - pub fn with_cipher_suites(&mut self, cipher_suites: Option>) -> &mut Self { + pub fn with_cipher_suites(mut self, cipher_suites: Option>) -> Self { self.cipher_suites = cipher_suites; self } - pub fn with_expect_proxy(&mut self, expect_proxy: bool) -> &mut Self { + pub fn with_expect_proxy(mut self, expect_proxy: bool) -> Self { self.expect_proxy = Some(expect_proxy); self } - pub fn with_sticky_name(&mut self, sticky_name: Option) -> &mut Self + pub fn with_sticky_name(mut self, sticky_name: Option) -> Self where S: ToString, { @@ -371,7 
+374,7 @@ impl ListenerBuilder { self } - pub fn with_certificate(&mut self, certificate: S) -> &mut Self + pub fn with_certificate(mut self, certificate: S) -> Self where S: ToString, { @@ -379,12 +382,12 @@ impl ListenerBuilder { self } - pub fn with_certificate_chain(&mut self, certificate_chain: String) -> &mut Self { + pub fn with_certificate_chain(mut self, certificate_chain: String) -> Self { self.certificate = Some(certificate_chain); self } - pub fn with_key(&mut self, key: String) -> &mut Self + pub fn with_key(mut self, key: String) -> Self where S: ToString, { @@ -392,22 +395,22 @@ impl ListenerBuilder { self } - pub fn with_front_timeout(&mut self, front_timeout: Option) -> &mut Self { + pub fn with_front_timeout(mut self, front_timeout: Option) -> Self { self.front_timeout = front_timeout; self } - pub fn with_back_timeout(&mut self, back_timeout: Option) -> &mut Self { + pub fn with_back_timeout(mut self, back_timeout: Option) -> Self { self.back_timeout = back_timeout; self } - pub fn with_connect_timeout(&mut self, connect_timeout: Option) -> &mut Self { + pub fn with_connect_timeout(mut self, connect_timeout: Option) -> Self { self.connect_timeout = connect_timeout; self } - pub fn with_request_timeout(&mut self, request_timeout: Option) -> &mut Self { + pub fn with_request_timeout(mut self, request_timeout: Option) -> Self { self.request_timeout = request_timeout; self } @@ -432,29 +435,28 @@ impl ListenerBuilder { } /// build an HTTP listener with config timeouts, using defaults if no config is provided - pub fn to_http(&mut self, config: Option<&Config>) -> Result { + pub fn to_http(mut self, config: Option<&Config>) -> Result { if self.protocol != Some(ListenerProtocol::Http) { return Err(ConfigError::WrongListenerProtocol { expected: ListenerProtocol::Http, - found: self.protocol.to_owned(), + found: self.protocol, }); } - if let Some(config) = config { - self.assign_config_timeouts(config); - } + let _address = self.parse_address()?; + let 
_public_address = self.parse_public_address()?; let (answer_404, answer_503) = self.get_404_503_answers()?; - let _address = self.parse_address()?; - - let _public_address = self.parse_public_address()?; + if let Some(config) = config { + self.assign_config_timeouts(config); + } let configuration = HttpListenerConfig { - address: self.address.clone(), - public_address: self.public_address.clone(), + address: self.address, + public_address: self.public_address, expect_proxy: self.expect_proxy.unwrap_or(false), - sticky_name: self.sticky_name.clone(), + sticky_name: self.sticky_name, front_timeout: self.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT), back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT), connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT), @@ -468,34 +470,51 @@ impl ListenerBuilder { } /// build an HTTPS listener using defaults if no config or values were provided upstream - pub fn to_tls(&mut self, config: Option<&Config>) -> Result { + pub fn to_tls(mut self, config: Option<&Config>) -> Result { if self.protocol != Some(ListenerProtocol::Https) { return Err(ConfigError::WrongListenerProtocol { expected: ListenerProtocol::Https, - found: self.protocol.to_owned(), + found: self.protocol, }); } + let _address = self.parse_address()?; + let _public_address = self.parse_public_address()?; + + let (answer_404, answer_503) = self.get_404_503_answers()?; + + if let Some(config) = config { + self.assign_config_timeouts(config); + } + + let default_alpn = DEFAULT_ALPN.into_iter().map(|p| p as i32).collect(); + + let alpn = self + .alpn + .as_ref() + .map(|alpn| alpn.iter().map(|p| *p as i32).collect()) + .unwrap_or(default_alpn); + let default_cipher_list = DEFAULT_RUSTLS_CIPHER_LIST .into_iter() .map(String::from) .collect(); - let cipher_list = self.cipher_list.clone().unwrap_or(default_cipher_list); + let cipher_list = self.cipher_list.unwrap_or(default_cipher_list); let default_cipher_suites = DEFAULT_CIPHER_SUITES 
.into_iter() .map(String::from) .collect(); - let cipher_suites = self.cipher_suites.clone().unwrap_or(default_cipher_suites); + let cipher_suites = self.cipher_suites.unwrap_or(default_cipher_suites); - let signature_algorithms: Vec = DEFAULT_SIGNATURE_ALGORITHMS + let signature_algorithms = DEFAULT_SIGNATURE_ALGORITHMS .into_iter() .map(String::from) .collect(); - let groups_list: Vec = DEFAULT_GROUPS_LIST.into_iter().map(String::from).collect(); + let groups_list = DEFAULT_GROUPS_LIST.into_iter().map(String::from).collect(); let versions = match self.tls_versions { None => vec![TlsVersion::TlsV12 as i32, TlsVersion::TlsV13 as i32], @@ -532,23 +551,10 @@ impl ListenerBuilder { .map(split_certificate_chain) .unwrap_or_else(Vec::new); - let (answer_404, answer_503) = self - .get_404_503_answers() - //.with_context(|| "Could not get 404 and 503 answers from file system") - ?; - - let _address = self.parse_address()?; - - let _public_address = self.parse_public_address()?; - - if let Some(config) = config { - self.assign_config_timeouts(config); - } - let https_listener_config = HttpsListenerConfig { - address: self.address.clone(), - sticky_name: self.sticky_name.clone(), - public_address: self.public_address.clone(), + address: self.address, + sticky_name: self.sticky_name, + public_address: self.public_address, cipher_list, versions, expect_proxy: self.expect_proxy.unwrap_or(false), @@ -568,22 +574,22 @@ impl ListenerBuilder { send_tls13_tickets: self .send_tls13_tickets .unwrap_or(DEFAULT_SEND_TLS_13_TICKETS), + alpn, }; Ok(https_listener_config) } /// build an HTTPS listener using defaults if no config or values were provided upstream - pub fn to_tcp(&mut self, config: Option<&Config>) -> Result { + pub fn to_tcp(mut self, config: Option<&Config>) -> Result { if self.protocol != Some(ListenerProtocol::Tcp) { return Err(ConfigError::WrongListenerProtocol { expected: ListenerProtocol::Tcp, - found: self.protocol.to_owned(), + found: self.protocol, }); } let 
_address = self.parse_address()?; - let _public_address = self.parse_public_address()?; if let Some(config) = config { @@ -591,8 +597,8 @@ impl ListenerBuilder { } Ok(TcpListenerConfig { - address: self.address.clone(), - public_address: self.public_address.clone(), + address: self.address, + public_address: self.public_address, expect_proxy: self.expect_proxy.unwrap_or(false), front_timeout: self.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT), back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT), @@ -661,6 +667,11 @@ pub enum PathRuleType { Equals, } +/// Congruent with command.proto +fn default_rule_position() -> RulePosition { + RulePosition::Tree +} + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct FileClusterFrontendConfig { @@ -676,9 +687,10 @@ pub struct FileClusterFrontendConfig { pub certificate_chain: Option, #[serde(default)] pub tls_versions: Vec, - #[serde(default)] + #[serde(default = "default_rule_position")] pub position: RulePosition, pub tags: Option>, + pub h2: Option, } impl FileClusterFrontendConfig { @@ -764,6 +776,7 @@ impl FileClusterFrontendConfig { path, method: self.method.clone(), tags: self.tags.clone(), + h2: self.h2.unwrap_or(false), }) } } @@ -789,6 +802,7 @@ pub struct FileClusterConfig { pub frontends: Vec, pub backends: Vec, pub protocol: FileClusterProtocolConfig, + pub http_version: Option, pub sticky_session: Option, pub https_redirect: Option, #[serde(default)] @@ -914,6 +928,7 @@ pub struct HttpFrontendConfig { #[serde(default)] pub position: RulePosition, pub tags: Option>, + pub h2: bool, } impl HttpFrontendConfig { @@ -949,6 +964,7 @@ impl HttpFrontendConfig { path: self.path.clone(), method: self.method.clone(), position: self.position.into(), + h2: self.h2, tags, }) .into(), @@ -963,6 +979,7 @@ impl HttpFrontendConfig { path: self.path.clone(), method: self.method.clone(), position: self.position.into(), + h2: self.h2, tags, }) .into(), @@ -1285,19 
+1302,19 @@ impl ConfigBuilder { } } - fn push_tls_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> { + fn push_tls_listener(&mut self, listener: ListenerBuilder) -> Result<(), ConfigError> { let listener = listener.to_tls(Some(&self.built))?; self.built.https_listeners.push(listener); Ok(()) } - fn push_http_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> { + fn push_http_listener(&mut self, listener: ListenerBuilder) -> Result<(), ConfigError> { let listener = listener.to_http(Some(&self.built))?; self.built.http_listeners.push(listener); Ok(()) } - fn push_tcp_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> { + fn push_tcp_listener(&mut self, listener: ListenerBuilder) -> Result<(), ConfigError> { let listener = listener.to_tcp(Some(&self.built))?; self.built.tcp_listeners.push(listener); Ok(()) diff --git a/command/src/proto/display.rs b/command/src/proto/display.rs index fd2e2dad7..96918ef49 100644 --- a/command/src/proto/display.rs +++ b/command/src/proto/display.rs @@ -32,9 +32,9 @@ impl Display for CertificateSummary { impl Display for QueryCertificatesFilters { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { if let Some(d) = self.domain.clone() { - write!(f, "domain:{}", d) + write!(f, "domain:{d}") } else if let Some(fp) = self.fingerprint.clone() { - write!(f, "domain:{}", fp) + write!(f, "domain:{fp}") } else { write!(f, "all certificates") } diff --git a/command/src/request.rs b/command/src/request.rs index c8350b23b..84ecd8dec 100644 --- a/command/src/request.rs +++ b/command/src/request.rs @@ -210,6 +210,7 @@ impl RequestHttpFrontend { } })?, tags: Some(self.tags), + h2: self.h2, }) } } diff --git a/command/src/response.rs b/command/src/response.rs index eb6ed6992..3bd83acae 100644 --- a/command/src/response.rs +++ b/command/src/response.rs @@ -39,6 +39,7 @@ pub struct HttpFrontend { #[serde(default)] pub position: RulePosition, pub tags: Option>, + 
pub h2: bool, } impl From for RequestHttpFrontend { @@ -54,6 +55,7 @@ impl From for RequestHttpFrontend { path: val.path, method: val.method, position: val.position.into(), + h2: val.h2, tags, } } diff --git a/command/src/state.rs b/command/src/state.rs index 04af19958..033bd863b 100644 --- a/command/src/state.rs +++ b/command/src/state.rs @@ -478,7 +478,7 @@ impl ConfigState { if tcp_frontends.contains(&tcp_frontend) { return Err(StateError::Exists { kind: ObjectKind::TcpFrontend, - id: format!("{:?}", tcp_frontend), + id: format!("{tcp_frontend:?}"), }); } @@ -497,7 +497,7 @@ impl ConfigState { .get_mut(&front_to_remove.cluster_id) .ok_or(StateError::NotFound { kind: ObjectKind::TcpFrontend, - id: format!("{:?}", front_to_remove), + id: format!("{front_to_remove:?}"), })?; let len = tcp_frontends.len(); diff --git a/e2e/Cargo.toml b/e2e/Cargo.toml index e590a2678..e26c6f632 100644 --- a/e2e/Cargo.toml +++ b/e2e/Cargo.toml @@ -5,6 +5,7 @@ rust-version = "1.70.0" edition = "2021" [dependencies] +kawa = "0.6.3" futures = "^0.3.28" hyper = { version = "^0.14.27", features = ["client", "http1"] } hyper-rustls = { version = "^0.24.1", default-features = false, features = ["webpki-tokio", "http1", "tls12", "logging"] } diff --git a/e2e/src/http_utils/mod.rs b/e2e/src/http_utils/mod.rs index 7ca7a5bbd..fe91efcb5 100644 --- a/e2e/src/http_utils/mod.rs +++ b/e2e/src/http_utils/mod.rs @@ -24,14 +24,20 @@ pub fn http_request, S2: Into, S3: Into, S4: In ) } -// the default value for the 404 error, as provided in the command lib, -// used as default for listeners -pub fn default_404_answer() -> String { - String::from(include_str!("../../../command/assets/404.html")) -} +use kawa; +use std::io::Write; -// the default value for the 503 error, as provided in the command lib, -// used as default for listeners -pub fn default_503_answer() -> String { - String::from(include_str!("../../../command/assets/503.html")) +/// the default kawa answer for the error code provided, converted 
to HTTP/1.1 +pub fn default_answer(code: u16) -> String { + let mut kawa_answer = kawa::Kawa::new( + kawa::Kind::Response, + kawa::Buffer::new(kawa::SliceBuffer(&mut [])), + ); + sozu_lib::protocol::mux::fill_default_answer(&mut kawa_answer, code); + kawa_answer.prepare(&mut kawa::h1::converter::H1BlockConverter); + let out = kawa_answer.as_io_slice(); + let mut writer = std::io::BufWriter::new(Vec::new()); + writer.write_vectored(&out).expect("WRITE"); + let result = unsafe { std::str::from_utf8_unchecked(writer.buffer()) }; + result.to_string() } diff --git a/e2e/src/mock/async_backend.rs b/e2e/src/mock/async_backend.rs index 56c452b62..b0ff7d25c 100644 --- a/e2e/src/mock/async_backend.rs +++ b/e2e/src/mock/async_backend.rs @@ -35,7 +35,7 @@ impl BackendHandle { let name = name.into(); let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1); let (mut aggregator_tx, aggregator_rx) = mpsc::channel::(1); - let listener = TcpListener::bind(address).expect("could not bind"); + let listener = TcpListener::bind(address).expect(&format!("could not bind on: {address}")); let mut clients = Vec::new(); let thread_name = name.to_owned(); diff --git a/e2e/src/mock/client.rs b/e2e/src/mock/client.rs index 2a1adae60..4b89bcc81 100644 --- a/e2e/src/mock/client.rs +++ b/e2e/src/mock/client.rs @@ -39,7 +39,7 @@ impl Client { /// Establish a TCP connection with its address, /// register the yielded TCP stream, apply timeouts pub fn connect(&mut self) { - let stream = TcpStream::connect(self.address).expect("could not connect"); + let stream = TcpStream::connect(self.address).expect(&format!("could not connect to: {}", self.address)); stream .set_read_timeout(Some(Duration::from_millis(100))) .expect("could not set read timeout"); diff --git a/e2e/src/mock/sync_backend.rs b/e2e/src/mock/sync_backend.rs index 14712bb7f..ecd770149 100644 --- a/e2e/src/mock/sync_backend.rs +++ b/e2e/src/mock/sync_backend.rs @@ -44,7 +44,7 @@ impl Backend { /// Binds itself to its address, stores the 
yielded TCP listener pub fn connect(&mut self) { - let listener = TcpListener::bind(self.address).expect("could not bind"); + let listener = TcpListener::bind(self.address).expect(&format!("could not bind on: {}", self.address)); let timeout = Duration::from_millis(100); let timeout = libc::timeval { tv_sec: 0, diff --git a/e2e/src/tests/tests.rs b/e2e/src/tests/tests.rs index 4672257b8..4123e8ec4 100644 --- a/e2e/src/tests/tests.rs +++ b/e2e/src/tests/tests.rs @@ -17,7 +17,7 @@ use sozu_command_lib::{ }; use crate::{ - http_utils::{default_404_answer, default_503_answer, http_ok_response, http_request}, + http_utils::{default_answer, http_ok_response, http_request}, mock::{ aggregator::SimpleAggregator, async_backend::BackendHandle as AsyncBackend, @@ -672,7 +672,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(default_404_answer())); + assert_eq!(response, Some(default_answer(404))); assert_eq!(client.receive(), None); worker.send_proxy_request_type(RequestType::AddHttpFrontend(RequestHttpFrontend { @@ -687,7 +687,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(default_503_answer())); + assert_eq!(response, Some(default_answer(503))); assert_eq!(client.receive(), None); let back_address = create_local_address(); @@ -705,12 +705,9 @@ fn try_http_behaviors() -> State { client.connect(); client.send(); - let expected_response = String::from( - "HTTP/1.1 400 Bad Request\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n", - ); let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(expected_response)); + assert_eq!(response, Some(default_answer(400))); assert_eq!(client.receive(), None); let mut backend = SyncBackend::new("backend", back_address, "TEST\r\n\r\n"); @@ -724,13 +721,10 @@ fn try_http_behaviors() -> State { let request = 
backend.receive(0); backend.send(0); - let expected_response = String::from( - "HTTP/1.1 502 Bad Gateway\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n", - ); let response = client.receive(); println!("request: {request:?}"); println!("response: {response:?}"); - assert_eq!(response, Some(expected_response)); + assert_eq!(response, Some(default_answer(502))); assert_eq!(client.receive(), None); info!("expecting 200"); @@ -782,7 +776,8 @@ fn try_http_behaviors() -> State { && response.ends_with(&expected_response_end) ); - info!("server closes, expecting 503"); + // FIXME: do we want 502 or 503??? + info!("server closes, expecting 502"); // TODO: what if the client continue to use the closed stream client.connect(); client.send(); @@ -793,7 +788,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("request: {request:?}"); println!("response: {response:?}"); - assert_eq!(response, Some(default_503_answer())); + assert_eq!(response, Some(default_answer(502))); assert_eq!(client.receive(), None); worker.send_proxy_request_type(RequestType::RemoveBackend(RemoveBackend { @@ -1200,7 +1195,9 @@ pub fn try_stick() -> State { backend1.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + assert!(request.unwrap().starts_with( + "GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\n" + )); assert!(response.unwrap().starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSet-Cookie: SOZUBALANCEID=sticky_cluster_0-0; Path=/\r\nSozu-Id:")); // invalid sticky_session @@ -1213,7 +1210,9 @@ pub fn try_stick() -> State { backend2.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + 
assert!(request.unwrap().starts_with( + "GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\n" + )); assert!(response.unwrap().starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSet-Cookie: SOZUBALANCEID=sticky_cluster_0-1; Path=/\r\nSozu-Id:")); // good sticky_session (force use backend2, round-robin would have chosen backend1) @@ -1226,7 +1225,9 @@ pub fn try_stick() -> State { backend2.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + assert!(request.unwrap().starts_with( + "GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\n" + )); assert!(response .unwrap() .starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSozu-Id:")); diff --git a/lib/src/http.rs b/lib/src/http.rs index f00fa9de4..fead82726 100644 --- a/lib/src/http.rs +++ b/lib/src/http.rs @@ -39,14 +39,14 @@ use crate::{ answers::HttpAnswers, parser::{hostname_and_port, Method}, }, + mux::{self, Mux}, proxy_protocol::expect::ExpectProxyProtocol, - Http, Pipe, SessionState, + Pipe, SessionState, }, router::{Route, Router}, server::{ListenSession, ListenToken, ProxyChannel, Server, SessionManager}, socket::server_bind, timer::TimeoutContainer, - util::UnwrapLog, AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, SessionMetrics, SessionResult, StateMachineBuilder, StateResult, @@ -66,7 +66,8 @@ StateMachineBuilder! { /// 3. 
WebSocket (passthrough) enum HttpStateMachine impl SessionState { Expect(ExpectProxyProtocol), - Http(Http), + // Http(Http), + Mux(Mux), WebSocket(Pipe), } } @@ -125,22 +126,42 @@ impl HttpSession { gauge_add!("protocol.http", 1); let session_address = sock.peer_addr().ok(); - HttpStateMachine::Http(Http::new( - answers.clone(), + let frontend = mux::Connection::new_h1_server(sock, container_frontend_timeout); + let router = mux::Router::new( configured_backend_timeout, configured_connect_timeout, - configured_frontend_timeout, - container_frontend_timeout, - sock, - token, listener.clone(), - pool.clone(), - Protocol::HTTP, + ); + let mut context = mux::Context::new(pool.clone()); + context + .create_stream(request_id, 1 << 16) + .ok_or(AcceptError::BufferCapacityReached)?; + HttpStateMachine::Mux(Mux { + configured_frontend_timeout, + frontend_token: token, + frontend, + router, public_address, - request_id, - session_address, - sticky_name.clone(), - )?) + peer_address: session_address, + sticky_name: sticky_name.clone(), + context, + }) + // HttpStateMachine::Http(Http::new( + // answers.clone(), + // configured_backend_timeout, + // configured_connect_timeout, + // configured_frontend_timeout, + // container_frontend_timeout, + // sock, + // token, + // listener.clone(), + // pool.clone(), + // Protocol::HTTP, + // public_address, + // request_id, + // session_address, + // sticky_name.clone(), + // )?) 
}; let metrics = SessionMetrics::new(Some(wait_time)); @@ -164,7 +185,8 @@ impl HttpSession { pub fn upgrade(&mut self) -> SessionIsToBeClosed { debug!("HTTP::upgrade"); let new_state = match self.state.take() { - HttpStateMachine::Http(http) => self.upgrade_http(http), + // HttpStateMachine::Http(http) => self.upgrade_http(http), + HttpStateMachine::Mux(mux) => self.upgrade_mux(mux), HttpStateMachine::Expect(expect) => self.upgrade_expect(expect), HttpStateMachine::WebSocket(ws) => self.upgrade_websocket(ws), HttpStateMachine::FailedUpgrade(_) => unreachable!(), @@ -191,64 +213,144 @@ impl HttpSession { .map(|add| (add.destination(), add.source())) { Some((Some(public_address), Some(session_address))) => { - let mut http = Http::new( - self.answers.clone(), + let frontend = mux::Connection::new_h1_server( + expect.frontend, + expect.container_frontend_timeout, + ); + let router = mux::Router::new( self.configured_backend_timeout, self.configured_connect_timeout, - self.configured_frontend_timeout, - expect.container_frontend_timeout, - expect.frontend, - expect.frontend_token, self.listener.clone(), - self.pool.clone(), - Protocol::HTTP, + ); + let mut context = mux::Context::new(self.pool.clone()); + context.create_stream(expect.request_id, 1 << 16)?; + let mut mux = Mux { + configured_frontend_timeout: self.configured_frontend_timeout, + frontend_token: self.frontend_token, + frontend, + router, public_address, - expect.request_id, - Some(session_address), - self.sticky_name.clone(), - ) - .ok()?; - http.frontend_readiness.event = expect.frontend_readiness.event; + peer_address: Some(session_address), + sticky_name: self.sticky_name.clone(), + context, + }; + mux.frontend.readiness_mut().event = expect.frontend_readiness.event; gauge_add!("protocol.proxy.expect", -1); gauge_add!("protocol.http", 1); - Some(HttpStateMachine::Http(http)) + Some(HttpStateMachine::Mux(mux)) } _ => None, } } - fn upgrade_http(&mut self, http: Http) -> Option { - debug!("http 
switching to ws"); - let front_token = self.frontend_token; - let back_token = unwrap_msg!(http.backend_token); - let ws_context = http.websocket_context(); + // fn upgrade_http(&mut self, http: Http) -> Option { + // debug!("http switching to ws"); + // let front_token = self.frontend_token; + // let back_token = unwrap_msg!(http.backend_token); + // let ws_context = http.context.websocket_context(); + + // let mut container_frontend_timeout = http.container_frontend_timeout; + // let mut container_backend_timeout = http.container_backend_timeout; + // container_frontend_timeout.reset(); + // container_backend_timeout.reset(); + + // let mut pipe = Pipe::new( + // http.response_stream.storage.buffer, + // http.backend_id, + // http.backend_socket, + // http.backend, + // Some(container_backend_timeout), + // Some(container_frontend_timeout), + // http.cluster_id, + // http.request_stream.storage.buffer, + // front_token, + // http.frontend_socket, + // self.listener.clone(), + // Protocol::HTTP, + // http.context.id, + // http.context.session_address, + // Some(ws_context), + // ); + + // pipe.frontend_readiness.event = http.frontend_readiness.event; + // pipe.backend_readiness.event = http.backend_readiness.event; + // pipe.set_back_token(back_token); + + // gauge_add!("protocol.http", -1); + // gauge_add!("protocol.ws", 1); + // gauge_add!("http.active_requests", -1); + // gauge_add!("websocket.active_requests", 1); + // Some(HttpStateMachine::WebSocket(pipe)) + // } + + fn upgrade_mux(&mut self, mut mux: Mux) -> Option { + debug!("mux switching to ws"); + let stream = mux.context.streams.pop().unwrap(); + + let (frontend_readiness, frontend_socket, mut container_frontend_timeout) = + match mux.frontend { + mux::Connection::H1(mux::ConnectionH1 { + readiness, + socket, + timeout_container, + .. 
+ }) => (readiness, socket, timeout_container), + mux::Connection::H2(_) => { + error!("Only h1<->h1 connections can upgrade to websocket"); + return None; + } + }; + + let mux::StreamState::Linked(back_token) = stream.state else { + error!("Upgrading stream should be linked to a backend"); + return None; + }; + let backend = mux.router.backends.remove(&back_token).unwrap(); + let (cluster_id, backend_readiness, backend_socket, mut container_backend_timeout) = + match backend { + mux::Connection::H1(mux::ConnectionH1 { + position: mux::Position::Client(mux::BackendStatus::Connected(cluster_id)), + readiness, + socket, + timeout_container, + .. + }) => (cluster_id, readiness, socket, timeout_container), + mux::Connection::H1(_) => { + error!("The backend disconnected just after upgrade, abort"); + return None; + } + mux::Connection::H2(_) => { + error!("Only h1<->h1 connections can upgrade to websocket"); + return None; + } + }; + + let ws_context = stream.context.websocket_context(); - let mut container_frontend_timeout = http.container_frontend_timeout; - let mut container_backend_timeout = http.container_backend_timeout; container_frontend_timeout.reset(); container_backend_timeout.reset(); let mut pipe = Pipe::new( - http.response_stream.storage.buffer, - http.backend_id, - http.backend_socket, - http.backend, + stream.back.storage.buffer, + None, + Some(backend_socket), + None, Some(container_backend_timeout), Some(container_frontend_timeout), - http.cluster_id, - http.request_stream.storage.buffer, - front_token, - http.frontend_socket, + Some(cluster_id), + stream.front.storage.buffer, + self.frontend_token, + frontend_socket, self.listener.clone(), Protocol::HTTP, - http.context.id, - http.context.session_address, + stream.context.id, + stream.context.session_address, Some(ws_context), ); - pipe.frontend_readiness.event = http.frontend_readiness.event; - pipe.backend_readiness.event = http.backend_readiness.event; + pipe.frontend_readiness.event = 
frontend_readiness.event; + pipe.backend_readiness.event = backend_readiness.event; pipe.set_back_token(back_token); gauge_add!("protocol.http", -1); @@ -277,7 +379,8 @@ impl ProxySession for HttpSession { // Restore gauges match self.state.marker() { StateMarker::Expect => gauge_add!("protocol.proxy.expect", -1), - StateMarker::Http => gauge_add!("protocol.http", -1), + // StateMarker::Http => gauge_add!("protocol.http", -1), + StateMarker::Mux => gauge_add!("protocol.http", -1), StateMarker::WebSocket => { gauge_add!("protocol.ws", -1); gauge_add!("websocket.active_requests", -1); @@ -287,7 +390,7 @@ impl ProxySession for HttpSession { if self.state.failed() { match self.state.marker() { StateMarker::Expect => incr!("http.upgrade.expect.failed"), - StateMarker::Http => incr!("http.upgrade.http.failed"), + StateMarker::Mux => incr!("http.upgrade.http.failed"), StateMarker::WebSocket => incr!("http.upgrade.ws.failed"), } return; @@ -461,7 +564,7 @@ impl L7ListenerHandler for HttpListener { let now = Instant::now(); - if let Route::ClusterId(cluster) = &route { + if let Route::Cluster { id: cluster, .. 
} = &route { time!( "frontend_matching_time", cluster, @@ -671,7 +774,7 @@ impl HttpProxy { if !socket_errors.is_empty() { return Err(ProxyError::SoftStop { proxy_protocol: "HTTP".to_string(), - error: format!("Error deregistering listen sockets: {:?}", socket_errors), + error: format!("Error deregistering listen sockets: {socket_errors:?}"), }); } @@ -694,7 +797,7 @@ impl HttpProxy { if !socket_errors.is_empty() { return Err(ProxyError::HardStop { proxy_protocol: "HTTP".to_string(), - error: format!("Error deregistering listen sockets: {:?}", socket_errors), + error: format!("Error deregistering listen sockets: {socket_errors:?}"), }); } @@ -1395,7 +1498,7 @@ mod tests { ); println!("http client write: {w:?}"); - let expected_answer = "HTTP/1.1 301 Moved Permanently\r\nContent-Length: 0\r\nLocation: https://localhost/redirected?true\r\n\r\n"; + let expected_answer = "HTTP/1.1 301 Moved Permanently\r\nLocation: https://localhost/redirected?true\r\nContent-Length: 0\r\n\r\n"; let mut buffer = [0; 4096]; let mut index = 0; loop { @@ -1467,6 +1570,7 @@ mod tests { position: RulePosition::Tree, cluster_id: Some(cluster_id1), tags: None, + h2: false, }) .expect("Could not add http frontend"); fronts @@ -1478,6 +1582,7 @@ mod tests { position: RulePosition::Tree, cluster_id: Some(cluster_id2), tags: None, + h2: false, }) .expect("Could not add http frontend"); fronts @@ -1489,6 +1594,7 @@ mod tests { position: RulePosition::Tree, cluster_id: Some(cluster_id3), tags: None, + h2: false, }) .expect("Could not add http frontend"); fronts @@ -1500,6 +1606,7 @@ mod tests { position: RulePosition::Tree, cluster_id: Some("cluster_1".to_owned()), tags: None, + h2: false, }) .expect("Could not add http frontend"); @@ -1531,19 +1638,31 @@ mod tests { let frontend5 = listener.frontend_from_request("domain", "/", &Method::Get); assert_eq!( frontend1.expect("should find frontend"), - Route::ClusterId("cluster_1".to_string()) + Route::Cluster { + id: "cluster_1".to_string(), + h2: 
false + } ); assert_eq!( frontend2.expect("should find frontend"), - Route::ClusterId("cluster_1".to_string()) + Route::Cluster { + id: "cluster_1".to_string(), + h2: false + } ); assert_eq!( frontend3.expect("should find frontend"), - Route::ClusterId("cluster_2".to_string()) + Route::Cluster { + id: "cluster_2".to_string(), + h2: false + } ); assert_eq!( frontend4.expect("should find frontend"), - Route::ClusterId("cluster_3".to_string()) + Route::Cluster { + id: "cluster_3".to_string(), + h2: false + } ); assert!(frontend5.is_err()); } diff --git a/lib/src/https.rs b/lib/src/https.rs index da3364857..98e4ea99e 100644 --- a/lib/src/https.rs +++ b/lib/src/https.rs @@ -33,10 +33,10 @@ use sozu_command::{ config::DEFAULT_CIPHER_SUITES, logging, proto::command::{ - request::RequestType, response_content::ContentType, AddCertificate, CertificateSummary, - CertificatesByAddress, Cluster, HttpsListenerConfig, ListOfCertificatesByAddress, - ListenerType, RemoveCertificate, RemoveListener, ReplaceCertificate, RequestHttpFrontend, - ResponseContent, TlsVersion, + request::RequestType, response_content::ContentType, AddCertificate, AlpnProtocol, + CertificateSummary, CertificatesByAddress, Cluster, HttpsListenerConfig, + ListOfCertificatesByAddress, ListenerType, RemoveCertificate, RemoveListener, + ReplaceCertificate, RequestHttpFrontend, ResponseContent, TlsVersion, }, ready::Ready, request::WorkerRequest, @@ -49,14 +49,14 @@ use crate::{ backends::BackendMap, pool::Pool, protocol::{ - h2::Http2, http::{ answers::HttpAnswers, parser::{hostname_and_port, Method}, }, + mux::{self, Mux}, proxy_protocol::expect::ExpectProxyProtocol, rustls::TlsHandshake, - Http, Pipe, SessionState, + Pipe, SessionState, }, router::{Route, Router}, server::{ListenSession, ListenToken, ProxyChannel, Server, SessionManager, SessionToken}, @@ -69,9 +69,6 @@ use crate::{ SessionMetrics, SessionResult, StateMachineBuilder, StateResult, }; -// const SERVER_PROTOS: &[&str] = &["http/1.1", "h2"]; 
-const SERVER_PROTOS: &[&str] = &["http/1.1"]; - #[derive(Debug, Clone, PartialEq, Eq)] pub struct TlsCluster { cluster_id: String, @@ -89,17 +86,13 @@ StateMachineBuilder! { enum HttpsStateMachine impl SessionState { Expect(ExpectProxyProtocol, ServerConnection), Handshake(TlsHandshake), - Http(Http), + Mux(Mux), + // Http(Http), WebSocket(Pipe), - Http2(Http2) -> todo!("H2"), + // Http2(Http2) -> todo!("H2"), } } -pub enum AlpnProtocols { - H2, - Http11, -} - pub struct HttpsSession { answers: Rc>, configured_backend_timeout: Duration, @@ -191,8 +184,9 @@ impl HttpsSession { let new_state = match self.state.take() { HttpsStateMachine::Expect(expect, ssl) => self.upgrade_expect(expect, ssl), HttpsStateMachine::Handshake(handshake) => self.upgrade_handshake(handshake), - HttpsStateMachine::Http(http) => self.upgrade_http(http), - HttpsStateMachine::Http2(_) => self.upgrade_http2(), + // HttpsStateMachine::Http(http) => self.upgrade_http(http), + HttpsStateMachine::Mux(mux) => self.upgrade_mux(mux), + // HttpsStateMachine::Http2(_) => self.upgrade_http2(), HttpsStateMachine::WebSocket(wss) => self.upgrade_websocket(wss), HttpsStateMachine::FailedUpgrade(_) => unreachable!(), }; @@ -255,12 +249,6 @@ impl HttpsSession { } fn upgrade_handshake(&mut self, handshake: TlsHandshake) -> Option { - // Add 1st routing phase - // - get SNI - // - get ALPN - // - find corresponding listener - // - determine next protocol (tcps, https ,http2) - let sni = handshake.session.server_name(); let alpn = handshake.session.alpn_protocol(); let alpn = alpn.and_then(|alpn| from_utf8(alpn).ok()); @@ -270,15 +258,16 @@ impl HttpsSession { ); let alpn = match alpn { - Some("http/1.1") => AlpnProtocols::Http11, - Some("h2") => AlpnProtocols::H2, + Some("http/1.1") => AlpnProtocol::Http11, + Some("h2") => AlpnProtocol::H2, Some(other) => { error!("Unsupported ALPN protocol: {}", other); return None; } // Some client don't fill in the ALPN protocol, in this case we default to Http/1.1 - None => 
AlpnProtocols::Http11, + None => AlpnProtocol::Http11, }; + println!("ALPN: {alpn:?}"); if let Some(version) = handshake.session.protocol_version() { incr!(rustls_version_str(version)); @@ -293,80 +282,146 @@ impl HttpsSession { }; gauge_add!("protocol.tls.handshake", -1); - match alpn { - AlpnProtocols::Http11 => { - let mut http = Http::new( - self.answers.clone(), - self.configured_backend_timeout, - self.configured_connect_timeout, - self.configured_frontend_timeout, - handshake.container_frontend_timeout, - front_stream, - self.frontend_token, - self.listener.clone(), - self.pool.clone(), - Protocol::HTTPS, - self.public_address, - handshake.request_id, - self.peer_address, - self.sticky_name.clone(), - ) - .ok()?; - - http.frontend_readiness.event = handshake.frontend_readiness.event; - gauge_add!("protocol.https", 1); - Some(HttpsStateMachine::Http(http)) + let router = mux::Router::new( + self.configured_backend_timeout, + self.configured_connect_timeout, + self.listener.clone(), + ); + let mut context = mux::Context::new(self.pool.clone()); + let mut frontend = match alpn { + AlpnProtocol::Http11 => { + context.create_stream(handshake.request_id, 1 << 16)?; + mux::Connection::new_h1_server(front_stream, handshake.container_frontend_timeout) } - AlpnProtocols::H2 => { - let mut http = Http2::new( - front_stream, - self.frontend_token, - self.pool.clone(), - Some(self.public_address), - None, - self.sticky_name.clone(), - ); - - http.frontend.readiness.event = handshake.frontend_readiness.event; + AlpnProtocol::H2 => mux::Connection::new_h2_server( + front_stream, + self.pool.clone(), + handshake.container_frontend_timeout, + )?, + }; + frontend.readiness_mut().event = handshake.frontend_readiness.event; + + gauge_add!("protocol.https", 1); + Some(HttpsStateMachine::Mux(Mux { + configured_frontend_timeout: self.configured_frontend_timeout, + frontend_token: self.frontend_token, + frontend, + context, + router, + public_address: self.public_address, + 
peer_address: self.peer_address, + sticky_name: self.sticky_name.clone(), + })) + } + + // fn upgrade_http(&self, http: Http) -> Option { + // debug!("https switching to wss"); + // let front_token = self.frontend_token; + // let back_token = unwrap_msg!(http.backend_token); + // let ws_context = http.context.websocket_context(); + + // let mut container_frontend_timeout = http.container_frontend_timeout; + // let mut container_backend_timeout = http.container_backend_timeout; + // container_frontend_timeout.reset(); + // container_backend_timeout.reset(); + + // let mut pipe = Pipe::new( + // http.response_stream.storage.buffer, + // http.backend_id, + // http.backend_socket, + // http.backend, + // Some(container_backend_timeout), + // Some(container_frontend_timeout), + // http.cluster_id, + // http.request_stream.storage.buffer, + // front_token, + // http.frontend_socket, + // self.listener.clone(), + // Protocol::HTTP, + // http.context.id, + // http.context.session_address, + // Some(ws_context), + // ); + + // pipe.frontend_readiness.event = http.frontend_readiness.event; + // pipe.backend_readiness.event = http.backend_readiness.event; + // pipe.set_back_token(back_token); + + // gauge_add!("protocol.https", -1); + // gauge_add!("protocol.wss", 1); + // gauge_add!("http.active_requests", -1); + // gauge_add!("websocket.active_requests", 1); + // Some(HttpsStateMachine::WebSocket(pipe)) + // } + + fn upgrade_mux(&self, mut mux: Mux) -> Option { + debug!("mux switching to wss"); + let stream = mux.context.streams.pop().unwrap(); + + let (frontend_readiness, frontend_socket, mut container_frontend_timeout) = + match mux.frontend { + mux::Connection::H1(mux::ConnectionH1 { + readiness, + socket, + timeout_container, + .. 
+ }) => (readiness, socket, timeout_container), + mux::Connection::H2(_) => { + error!("Only h1<->h1 connections can upgrade to websocket"); + return None; + } + }; - gauge_add!("protocol.http2", 1); - Some(HttpsStateMachine::Http2(http)) - } - } - } + let mux::StreamState::Linked(back_token) = stream.state else { + error!("Upgrading stream should be linked to a backend"); + return None; + }; + let backend = mux.router.backends.remove(&back_token).unwrap(); + let (cluster_id, backend_readiness, backend_socket, mut container_backend_timeout) = + match backend { + mux::Connection::H1(mux::ConnectionH1 { + position: mux::Position::Client(mux::BackendStatus::Connected(cluster_id)), + readiness, + socket, + timeout_container, + .. + }) => (cluster_id, readiness, socket, timeout_container), + mux::Connection::H1(_) => { + error!("The backend disconnected just after upgrade, abort"); + return None; + } + mux::Connection::H2(_) => { + error!("Only h1<->h1 connections can upgrade to websocket"); + return None; + } + }; - fn upgrade_http(&self, http: Http) -> Option { - debug!("https switching to wss"); - let front_token = self.frontend_token; - let back_token = unwrap_msg!(http.backend_token); - let ws_context = http.websocket_context(); + let ws_context = stream.context.websocket_context(); - let mut container_frontend_timeout = http.container_frontend_timeout; - let mut container_backend_timeout = http.container_backend_timeout; container_frontend_timeout.reset(); container_backend_timeout.reset(); let mut pipe = Pipe::new( - http.response_stream.storage.buffer, - http.backend_id, - http.backend_socket, - http.backend, + stream.back.storage.buffer, + None, + Some(backend_socket), + None, Some(container_backend_timeout), Some(container_frontend_timeout), - http.cluster_id, - http.request_stream.storage.buffer, - front_token, - http.frontend_socket, + Some(cluster_id), + stream.front.storage.buffer, + self.frontend_token, + frontend_socket, self.listener.clone(), - 
Protocol::HTTP, - http.context.id, - http.context.session_address, + Protocol::HTTPS, + stream.context.id, + stream.context.session_address, Some(ws_context), ); - pipe.frontend_readiness.event = http.frontend_readiness.event; - pipe.backend_readiness.event = http.backend_readiness.event; + pipe.frontend_readiness.event = frontend_readiness.event; + pipe.backend_readiness.event = backend_readiness.event; pipe.set_back_token(back_token); gauge_add!("protocol.https", -1); @@ -376,10 +431,6 @@ impl HttpsSession { Some(HttpsStateMachine::WebSocket(pipe)) } - fn upgrade_http2(&self) -> Option { - todo!() - } - fn upgrade_websocket( &self, wss: Pipe, @@ -403,21 +454,20 @@ impl ProxySession for HttpsSession { match self.state.marker() { StateMarker::Expect => gauge_add!("protocol.proxy.expect", -1), StateMarker::Handshake => gauge_add!("protocol.tls.handshake", -1), - StateMarker::Http => gauge_add!("protocol.https", -1), + // StateMarker::Http => gauge_add!("protocol.https", -1), + StateMarker::Mux => gauge_add!("protocol.https", -1), StateMarker::WebSocket => { gauge_add!("protocol.wss", -1); gauge_add!("websocket.active_requests", -1); - } - StateMarker::Http2 => gauge_add!("protocol.http2", -1), + } // StateMarker::Http2 => gauge_add!("protocol.http2", -1), } if self.state.failed() { match self.state.marker() { StateMarker::Expect => incr!("https.upgrade.expect.failed"), StateMarker::Handshake => incr!("https.upgrade.handshake.failed"), - StateMarker::Http => incr!("https.upgrade.http.failed"), StateMarker::WebSocket => incr!("https.upgrade.wss.failed"), - StateMarker::Http2 => incr!("https.upgrade.http2.failed"), + StateMarker::Mux => incr!("https.upgrade.http.failed"), } return; } @@ -466,18 +516,24 @@ impl ProxySession for HttpsSession { token, super::ready_to_string(events) ); + println!("EVENT: {token:?}->{events:?}"); self.last_event = Instant::now(); self.metrics.wait_start(); self.state.update_readiness(token, events); } fn ready(&mut self, session: Rc>) -> 
SessionIsToBeClosed { + let start = std::time::Instant::now(); + println!("READY {start:?}"); self.metrics.service_start(); let session_result = self.state .ready(session.clone(), self.proxy.clone(), &mut self.metrics); + let end = std::time::Instant::now(); + println!("READY END {end:?} -> {:?}", end.duration_since(start)); + let to_be_closed = match session_result { SessionResult::Close => true, SessionResult::Continue => false, @@ -561,7 +617,6 @@ impl L7ListenerHandler for HttpsListener { let (remaining_input, (hostname, _)) = match hostname_and_port(host.as_bytes()) { Ok(tuple) => tuple, Err(parse_error) => { - // parse_error contains a slice of given_host, which should NOT escape this scope return Err(FrontendFromRequestError::HostParse { host: host.to_owned(), error: parse_error.to_string(), @@ -587,7 +642,7 @@ impl L7ListenerHandler for HttpsListener { let now = Instant::now(); - if let Route::ClusterId(cluster) = &route { + if let Route::Cluster { id: cluster, .. } = &route { time!( "frontend_matching_time", cluster, @@ -760,11 +815,20 @@ impl HttpsListener { .with_cert_resolver(resolver); server_config.send_tls13_tickets = config.send_tls13_tickets as usize; - let mut protocols = SERVER_PROTOS + let protocols = config + .alpn .iter() - .map(|proto| proto.as_bytes().to_vec()) + .filter_map(|protocol| match AlpnProtocol::try_from(*protocol) { + Ok(AlpnProtocol::Http11) => Some("http/1.1"), + Ok(AlpnProtocol::H2) => Some("h2"), + other_protocol => { + error!("unsupported ALPN protocol: {:?}", other_protocol); + None + } + }) + .map(|protocol| protocol.as_bytes().to_vec()) .collect::>(); - server_config.alpn_protocols.append(&mut protocols); + server_config.alpn_protocols = protocols; Ok(server_config) } @@ -872,7 +936,7 @@ impl HttpsProxy { if !socket_errors.is_empty() { return Err(ProxyError::SoftStop { proxy_protocol: "HTTPS".to_string(), - error: format!("Error deregistering listen sockets: {:?}", socket_errors), + error: format!("Error deregistering 
listen sockets: {socket_errors:?}"), }); } @@ -895,7 +959,7 @@ impl HttpsProxy { if !socket_errors.is_empty() { return Err(ProxyError::HardStop { proxy_protocol: "HTTPS".to_string(), - error: format!("Error deregistering listen sockets: {:?}", socket_errors), + error: format!("Error deregistering listen sockets: {socket_errors:?}"), }); } @@ -1619,25 +1683,37 @@ mod tests { "lolcatho.st".as_bytes(), &PathRule::Prefix(uri1), &MethodRule::new(None), - &Route::ClusterId(cluster_id1.clone()) + &Route::Cluster { + id: cluster_id1.clone(), + h2: false + } )); assert!(fronts.add_tree_rule( "lolcatho.st".as_bytes(), &PathRule::Prefix(uri2), &MethodRule::new(None), - &Route::ClusterId(cluster_id2) + &Route::Cluster { + id: cluster_id2, + h2: false + } )); assert!(fronts.add_tree_rule( "lolcatho.st".as_bytes(), &PathRule::Prefix(uri3), &MethodRule::new(None), - &Route::ClusterId(cluster_id3) + &Route::Cluster { + id: cluster_id3, + h2: false + } )); assert!(fronts.add_tree_rule( "other.domain".as_bytes(), &PathRule::Prefix("test".to_string()), &MethodRule::new(None), - &Route::ClusterId(cluster_id1) + &Route::Cluster { + id: cluster_id1, + h2: false + } )); let address: StdSocketAddr = FromStr::from_str("127.0.0.1:1032") @@ -1679,25 +1755,37 @@ mod tests { let frontend1 = listener.frontend_from_request("lolcatho.st", "/", &Method::Get); assert_eq!( frontend1.expect("should find a frontend"), - Route::ClusterId("cluster_1".to_string()) + Route::Cluster { + id: "cluster_1".to_string(), + h2: false + } ); println!("TEST {}", line!()); let frontend2 = listener.frontend_from_request("lolcatho.st", "/test", &Method::Get); assert_eq!( frontend2.expect("should find a frontend"), - Route::ClusterId("cluster_1".to_string()) + Route::Cluster { + id: "cluster_1".to_string(), + h2: false + } ); println!("TEST {}", line!()); let frontend3 = listener.frontend_from_request("lolcatho.st", "/yolo/test", &Method::Get); assert_eq!( frontend3.expect("should find a frontend"), - 
Route::ClusterId("cluster_2".to_string()) + Route::Cluster { + id: "cluster_2".to_string(), + h2: false + } ); println!("TEST {}", line!()); let frontend4 = listener.frontend_from_request("lolcatho.st", "/yolo/swag", &Method::Get); assert_eq!( frontend4.expect("should find a frontend"), - Route::ClusterId("cluster_3".to_string()) + Route::Cluster { + id: "cluster_3".to_string(), + h2: false + } ); println!("TEST {}", line!()); let frontend5 = listener.frontend_from_request("domain", "/", &Method::Get); diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 18706cd6b..55041ec4b 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -515,9 +515,7 @@ macro_rules! StateMachineBuilder { /// leaving a FailedUpgrade in its place. /// The FailedUpgrade retains the marker of the previous State. fn take(&mut self) -> $state_name { - let mut owned_state = $state_name::FailedUpgrade(self.marker()); - std::mem::swap(&mut owned_state, self); - owned_state + std::mem::replace(self, $state_name::FailedUpgrade(self.marker())) } _fn_impl!{front_socket(&, self) -> &mio::net::TcpStream} } @@ -620,6 +618,8 @@ pub enum BackendConnectionError { MaxConnectionRetries(Option), #[error("the sessions slab has reached maximum capacity")] MaxSessionsMemory, + #[error("the checkout pool has reached maximum capacity")] + MaxBuffers, #[error("error from the backend: {0}")] Backend(BackendError), #[error("failed to retrieve the cluster: {0}")] @@ -639,6 +639,8 @@ pub enum RetrieveClusterError { UnauthorizedRoute, #[error("{0}")] RetrieveFrontend(FrontendFromRequestError), + #[error("https redirect")] + HttpsRedirect, } /// Used in sessions diff --git a/lib/src/protocol/kawa_h1/editor.rs b/lib/src/protocol/kawa_h1/editor.rs index 177a86c74..67abd464a 100644 --- a/lib/src/protocol/kawa_h1/editor.rs +++ b/lib/src/protocol/kawa_h1/editor.rs @@ -6,9 +6,10 @@ use std::{ use rusty_ulid::Ulid; use crate::{ + logs::Endpoint, pool::Checkout, protocol::http::{parser::compare_no_case, GenericHttpStream, Method}, - 
Protocol, + Protocol, RetrieveClusterError, }; /// This is the container used to store and use information about the session from within a Kawa parser callback @@ -36,7 +37,7 @@ pub struct HttpContext { pub user_agent: Option, // ========== Read only - /// signals wether Kawa should write a "Connection" header with a "close" value (request and response) + /// signals whether Kawa should write a "Connection" header with a "close" value (request and response) pub closing: bool, /// the value of the custom header, named "Sozu-Id", that Kawa should write (request and response) pub id: Ulid, @@ -337,4 +338,40 @@ impl HttpContext { val: kawa::Store::from_string(self.id.to_string()), })); } + + // -> host, path, method + pub fn extract_route(&self) -> Result<(&str, &str, &Method), RetrieveClusterError> { + let given_method = self.method.as_ref().ok_or(RetrieveClusterError::NoMethod)?; + let given_authority = self + .authority + .as_deref() + .ok_or(RetrieveClusterError::NoHost)?; + let given_path = self.path.as_deref().ok_or(RetrieveClusterError::NoPath)?; + + Ok((given_authority, given_path, given_method)) + } + + /// Format the context of the websocket into a loggable String + pub fn websocket_context(&self) -> String { + Endpoint::Http { + method: self.method.as_ref(), + authority: self.authority.as_deref(), + path: self.path.as_deref(), + status: self.status, + reason: self.reason.as_deref(), + } + .to_string() + } + + pub fn reset(&mut self) { + self.keep_alive_backend = true; + self.sticky_session_found = None; + self.method = None; + self.authority = None; + self.path = None; + self.status = None; + self.reason = None; + self.user_agent = None; + self.id = Ulid::generate(); + } } diff --git a/lib/src/protocol/kawa_h1/mod.rs b/lib/src/protocol/kawa_h1/mod.rs index d1a42bf1b..4ac3dc9ff 100644 --- a/lib/src/protocol/kawa_h1/mod.rs +++ b/lib/src/protocol/kawa_h1/mod.rs @@ -219,10 +219,7 @@ impl Http Http String { - format!( - "{}", - Endpoint::Http { - method: 
self.context.method.as_ref(), - authority: self.context.authority.as_deref(), - path: self.context.path.as_deref(), - status: self.context.status, - reason: self.context.reason.as_deref(), - } - ) - } - pub fn log_request(&self, metrics: &SessionMetrics, message: Option<&str>) { let listener = self.listener.borrow(); let tags = self.context.authority.as_ref().and_then(|host| { @@ -1078,32 +1061,11 @@ impl Http host, path, method - pub fn extract_route(&self) -> Result<(&str, &str, &Method), RetrieveClusterError> { - let given_method = self - .context - .method - .as_ref() - .ok_or(RetrieveClusterError::NoMethod)?; - let given_authority = self - .context - .authority - .as_deref() - .ok_or(RetrieveClusterError::NoHost)?; - let given_path = self - .context - .path - .as_deref() - .ok_or(RetrieveClusterError::NoPath)?; - - Ok((given_authority, given_path, given_method)) - } - fn cluster_id_from_request( &mut self, proxy: Rc>, ) -> Result { - let (host, uri, method) = match self.extract_route() { + let (host, uri, method) = match self.context.extract_route() { Ok(tuple) => tuple, Err(cluster_error) => { self.set_answer(DefaultAnswerStatus::Answer400, None); @@ -1125,7 +1087,7 @@ impl Http cluster_id, + Route::Cluster { id, .. 
} => id, Route::Deny => { self.set_answer(DefaultAnswerStatus::Answer401, None); return Err(RetrieveClusterError::UnauthorizedRoute); diff --git a/lib/src/protocol/mod.rs b/lib/src/protocol/mod.rs index 1342173a8..88dde18cb 100644 --- a/lib/src/protocol/mod.rs +++ b/lib/src/protocol/mod.rs @@ -1,5 +1,6 @@ pub mod h2; pub mod kawa_h1; +pub mod mux; pub mod pipe; pub mod proxy_protocol; pub mod rustls; diff --git a/lib/src/protocol/mux/converter.rs b/lib/src/protocol/mux/converter.rs new file mode 100644 index 000000000..aafcda8bf --- /dev/null +++ b/lib/src/protocol/mux/converter.rs @@ -0,0 +1,171 @@ +use std::str::from_utf8_unchecked; + +use kawa::{AsBuffer, Block, BlockConverter, Chunk, Flags, Kawa, Pair, StatusLine, Store}; + +use crate::protocol::http::parser::compare_no_case; + +use super::{ + parser::{FrameHeader, FrameType, H2Error}, + serializer::{gen_frame_header, gen_rst_stream}, + StreamId, +}; + +pub struct H2BlockConverter<'a> { + pub stream_id: StreamId, + pub encoder: &'a mut hpack::Encoder<'static>, + pub out: Vec, +} + +impl<'a, T: AsBuffer> BlockConverter for H2BlockConverter<'a> { + fn call(&mut self, block: Block, kawa: &mut Kawa) { + let buffer = kawa.storage.buffer(); + match block { + Block::StatusLine => match kawa.detached.status_line.pop() { + StatusLine::Request { + method, + authority, + path, + .. + } => { + self.encoder + .encode_header_into((b":method", method.data(buffer)), &mut self.out) + .unwrap(); + self.encoder + .encode_header_into((b":authority", authority.data(buffer)), &mut self.out) + .unwrap(); + self.encoder + .encode_header_into((b":path", path.data(buffer)), &mut self.out) + .unwrap(); + self.encoder + .encode_header_into((b":scheme", b"https"), &mut self.out) + .unwrap(); + } + StatusLine::Response { status, .. 
} => { + self.encoder + .encode_header_into((b":status", status.data(buffer)), &mut self.out) + .unwrap(); + } + StatusLine::Unknown => unreachable!(), + }, + Block::Cookies => { + if kawa.detached.jar.is_empty() { + return; + } + for cookie in kawa + .detached + .jar + .drain(..) + .filter(|cookie| !cookie.is_elided()) + { + let cookie = [cookie.key.data(buffer), b"=", cookie.val.data(buffer)].concat(); + self.encoder + .encode_header_into((b"cookie", &cookie), &mut self.out) + .unwrap(); + } + } + Block::Header(Pair { + key: Store::Empty, .. + }) => { + // elided header + } + Block::Header(Pair { key, val }) => { + { + let key = key.data(buffer); + let val = val.data(buffer); + if compare_no_case(key, b"connection") + || compare_no_case(key, b"host") + || compare_no_case(key, b"http2-settings") + || compare_no_case(key, b"keep-alive") + || compare_no_case(key, b"proxy-connection") + || compare_no_case(key, b"te") && !compare_no_case(val, b"trailers") + || compare_no_case(key, b"trailer") + || compare_no_case(key, b"transfer-encoding") + || compare_no_case(key, b"upgrade") + { + println!("Elided H2 header: {}", unsafe { from_utf8_unchecked(key) }); + return; + } + } + self.encoder + .encode_header_into( + (&key.data(buffer).to_ascii_lowercase(), val.data(buffer)), + &mut self.out, + ) + .unwrap(); + } + Block::ChunkHeader(_) => { + // this converter doesn't align H1 chunks on H2 data frames + } + Block::Chunk(Chunk { data }) => { + let mut header = [0; 9]; + let payload_len = match &data { + Store::Empty => 0, + Store::Detached(s) | Store::Slice(s) => s.len, + Store::Static(s) => s.len() as u32, + Store::Alloc(a, i) => a.len() as u32 - i, + }; + gen_frame_header( + &mut header, + &FrameHeader { + payload_len, + frame_type: FrameType::Data, + flags: 0, + stream_id: self.stream_id, + }, + ) + .unwrap(); + kawa.push_out(Store::new_vec(&header)); + kawa.push_out(data); + kawa.push_delimiter() + } + Block::Flags(Flags { + end_header, + end_stream, + .. 
+ }) => { + if end_header { + let payload = std::mem::replace(&mut self.out, Vec::new()); + let mut header = [0; 9]; + let flags = if end_stream { 1 } else { 0 } | if end_header { 4 } else { 0 }; + gen_frame_header( + &mut header, + &FrameHeader { + payload_len: payload.len() as u32, + frame_type: FrameType::Headers, + flags, + stream_id: self.stream_id, + }, + ) + .unwrap(); + kawa.push_out(Store::new_vec(&header)); + kawa.push_out(Store::Alloc(payload.into_boxed_slice(), 0)); + } else if end_stream { + if kawa.is_error() { + let mut frame = [0; 13]; + gen_rst_stream(&mut frame, self.stream_id, H2Error::InternalError).unwrap(); + kawa.push_out(Store::new_vec(&frame)); + } else { + let mut header = [0; 9]; + gen_frame_header( + &mut header, + &FrameHeader { + payload_len: 0, + frame_type: FrameType::Data, + flags: 1, + stream_id: self.stream_id, + }, + ) + .unwrap(); + kawa.push_out(Store::new_vec(&header)); + } + } + if end_header || end_stream { + kawa.push_delimiter() + } + } + } + } + fn finalize(&mut self, _kawa: &mut Kawa) { + assert!(self.out.is_empty()); + } +} diff --git a/lib/src/protocol/mux/h1.rs b/lib/src/protocol/mux/h1.rs new file mode 100644 index 000000000..822bcaa36 --- /dev/null +++ b/lib/src/protocol/mux/h1.rs @@ -0,0 +1,266 @@ +use sozu_command::ready::Ready; + +use crate::{ + println_, + protocol::mux::{ + debug_kawa, forcefully_terminate_answer, set_default_answer, update_readiness_after_read, + update_readiness_after_write, BackendStatus, Context, Endpoint, GlobalStreamId, MuxResult, + Position, StreamState, + }, + socket::SocketHandler, + timer::TimeoutContainer, + Readiness, +}; + +pub struct ConnectionH1 { + pub position: Position, + pub readiness: Readiness, + pub requests: usize, + pub socket: Front, + /// note: a Server H1 will always reference stream 0, but a client can reference any stream + pub stream: GlobalStreamId, + pub timeout_container: TimeoutContainer, +} + +impl std::fmt::Debug for ConnectionH1 { + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectionH1") + .field("position", &self.position) + .field("readiness", &self.readiness) + .field("socket", &self.socket.socket_ref()) + .field("stream", &self.stream) + .finish() + } +} + +impl ConnectionH1 { + pub fn readable(&mut self, context: &mut Context, mut endpoint: E) -> MuxResult + where + E: Endpoint, + { + println_!("======= MUX H1 READABLE {:?}", self.position); + self.timeout_container.reset(); + let stream = &mut context.streams[self.stream]; + let parts = stream.split(&self.position); + let kawa = parts.rbuffer; + let (size, status) = self.socket.socket_read(kawa.storage.space()); + kawa.storage.fill(size); + if update_readiness_after_read(size, status, &mut self.readiness) { + return MuxResult::Continue; + } + + let was_main_phase = kawa.is_main_phase(); + kawa::h1::parse(kawa, parts.context); + debug_kawa(kawa); + if kawa.is_error() { + match self.position { + Position::Client(_) => { + let StreamState::Linked(token) = stream.state else { unreachable!() }; + let global_stream_id = self.stream; + self.readiness.interest.remove(Ready::ALL); + self.end_stream(global_stream_id, context); + endpoint.end_stream(token, global_stream_id, context); + } + Position::Server => { + set_default_answer(stream, &mut self.readiness, 400); + } + } + return MuxResult::Continue; + } + if kawa.is_terminated() { + self.timeout_container.cancel(); + self.readiness.interest.remove(Ready::READABLE); + } + if kawa.is_main_phase() { + if let StreamState::Linked(token) = stream.state { + endpoint + .readiness_mut(token) + .interest + .insert(Ready::WRITABLE) + } + match self.position { + Position::Server => { + if !was_main_phase { + self.requests += 1; + println_!("REQUESTS: {}", self.requests); + stream.state = StreamState::Link + } + } + Position::Client(_) => {} + } + }; + MuxResult::Continue + } + + pub fn writable(&mut self, context: &mut Context, mut endpoint: E) -> MuxResult + where + E: Endpoint, + 
{ + println_!("======= MUX H1 WRITABLE {:?}", self.position); + self.timeout_container.reset(); + let stream = &mut context.streams[self.stream]; + let kawa = stream.wbuffer(&self.position); + kawa.prepare(&mut kawa::h1::BlockConverter); + debug_kawa(kawa); + let bufs = kawa.as_io_slice(); + if bufs.is_empty() { + self.readiness.interest.remove(Ready::WRITABLE); + return MuxResult::Continue; + } + let (size, status) = self.socket.socket_write_vectored(&bufs); + kawa.consume(size); + if update_readiness_after_write(size, status, &mut self.readiness) { + return MuxResult::Continue; + } + + if kawa.is_terminated() && kawa.is_completed() { + match self.position { + Position::Client(_) => self.readiness.interest.insert(Ready::READABLE), + Position::Server => { + if stream.context.closing { + return MuxResult::CloseSession; + } + let kawa = &mut stream.back; + match kawa.detached.status_line { + kawa::StatusLine::Response { code: 101, .. } => { + println!("============== HANDLE UPGRADE!"); + return MuxResult::Upgrade; + } + kawa::StatusLine::Response { code: 100, .. } => { + println!("============== HANDLE CONTINUE!"); + // after a 100 continue, we expect the client to continue with its request + self.timeout_container.reset(); + self.readiness.interest.insert(Ready::READABLE); + kawa.clear(); + return MuxResult::Continue; + } + kawa::StatusLine::Response { code: 103, .. 
} => { + println!("============== HANDLE EARLY HINT!"); + if let StreamState::Linked(token) = stream.state { + // after a 103 early hints, we expect the backend to send its response + endpoint + .readiness_mut(token) + .interest + .insert(Ready::READABLE); + kawa.clear(); + return MuxResult::Continue; + } else { + return MuxResult::CloseSession; + } + } + _ => {} + } + let old_state = std::mem::replace(&mut stream.state, StreamState::Unlinked); + if stream.context.keep_alive_frontend { + self.timeout_container.reset(); + println!("{old_state:?} {:?}", self.readiness); + if let StreamState::Linked(token) = old_state { + println!("{:?}", endpoint.readiness(token)); + endpoint.end_stream(token, self.stream, context); + } + self.readiness.interest.insert(Ready::READABLE); + let stream = &mut context.streams[self.stream]; + stream.context.reset(); + stream.back.clear(); + stream.back.storage.clear(); + stream.front.clear(); + // do not stream.front.storage.clear() because of H1 pipelining + stream.attempts = 0; + } else { + return MuxResult::CloseSession; + } + } + } + } + MuxResult::Continue + } + + pub fn force_disconnect(&mut self) -> MuxResult { + match self.position { + Position::Client(_) => { + self.position = Position::Client(BackendStatus::Disconnecting); + self.readiness.event = Ready::HUP; + MuxResult::Continue + } + Position::Server => MuxResult::CloseSession, + } + } + + pub fn close(&mut self, context: &mut Context, mut endpoint: E) + where + E: Endpoint, + { + match self.position { + Position::Client(BackendStatus::KeepAlive(_)) + | Position::Client(BackendStatus::Disconnecting) => { + println_!("close detached client ConnectionH1"); + return; + } + Position::Client(BackendStatus::Connecting(_)) + | Position::Client(BackendStatus::Connected(_)) => {} + Position::Server => unreachable!(), + } + // reconnection is handled by the server + let StreamState::Linked(token) = context.streams[self.stream].state else {unreachable!()}; + endpoint.end_stream(token, 
self.stream, context) + } + + pub fn end_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + assert_eq!(stream, self.stream); + let stream = &mut context.streams[stream]; + let stream_context = &mut stream.context; + println_!("end H1 stream {}: {stream_context:#?}", self.stream); + match &mut self.position { + Position::Client(BackendStatus::Connected(cluster_id)) + | Position::Client(BackendStatus::Connecting(cluster_id)) => { + self.stream = usize::MAX; + if stream_context.keep_alive_backend { + self.position = + Position::Client(BackendStatus::KeepAlive(std::mem::take(cluster_id))) + } else { + self.force_disconnect(); + } + } + Position::Client(BackendStatus::KeepAlive(_)) + | Position::Client(BackendStatus::Disconnecting) => unreachable!(), + Position::Server => match (stream.front.consumed, stream.back.is_main_phase()) { + (true, true) => { + // we have a "forwardable" answer from the back + // if the answer is not terminated we send an RstStream to properly clean the stream + // if it is terminated, we finish the transfer, the backend is not necessary anymore + if !stream.back.is_terminated() { + forcefully_terminate_answer(stream, &mut self.readiness); + } else { + stream.state = StreamState::Unlinked; + self.readiness.interest.insert(Ready::WRITABLE); + } + } + (true, false) => { + // we do not have an answer, but the request has already been partially consumed + // so we can't retry, send a 502 bad gateway instead + set_default_answer(stream, &mut self.readiness, 502); + } + (false, false) => { + // we do not have an answer, but the request is untouched so we can retry + println!("H1 RECONNECT"); + stream.state = StreamState::Link; + } + (false, true) => unreachable!(), + }, + } + } + + pub fn start_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + println_!("start H1 stream {stream} {:?}", self.readiness); + self.readiness.interest.insert(Ready::ALL); + self.stream = stream; + match &mut self.position { + 
Position::Client(BackendStatus::KeepAlive(cluster_id)) => { + self.position = + Position::Client(BackendStatus::Connecting(std::mem::take(cluster_id))) + } + Position::Client(_) => {} + Position::Server => unreachable!(), + } + } +} diff --git a/lib/src/protocol/mux/h2.rs b/lib/src/protocol/mux/h2.rs new file mode 100644 index 000000000..1f40274dc --- /dev/null +++ b/lib/src/protocol/mux/h2.rs @@ -0,0 +1,773 @@ +use std::collections::HashMap; + +use rusty_ulid::Ulid; +use sozu_command::ready::Ready; + +use crate::{ + println_, + protocol::mux::{ + converter, debug_kawa, forcefully_terminate_answer, + parser::{self, error_code_to_str, Frame, FrameHeader, FrameType, H2Error, Headers}, + pkawa, serializer, set_default_answer, update_readiness_after_read, + update_readiness_after_write, BackendStatus, Context, Endpoint, GenericHttpStream, + GlobalStreamId, MuxResult, Position, StreamId, StreamState, + }, + socket::SocketHandler, + timer::TimeoutContainer, + Readiness, +}; + +#[inline(always)] +fn error_nom_to_h2(error: nom::Err) -> H2Error { + match error { + nom::Err::Error(parser::Error { + error: parser::InnerError::H2(e), + .. 
+ }) => return e, + _ => return H2Error::ProtocolError, + } +} + +#[derive(Debug)] +pub enum H2State { + ClientPreface, + ClientSettings, + ServerSettings, + Header, + Frame(FrameHeader), + ContinuationHeader(Headers), + ContinuationFrame(Headers), + GoAway, + Error, +} + +#[derive(Debug)] +pub struct H2Settings { + pub settings_header_table_size: u32, + pub settings_enable_push: bool, + pub settings_max_concurrent_streams: u32, + pub settings_initial_window_size: u32, + pub settings_max_frame_size: u32, + pub settings_max_header_list_size: u32, + /// RFC 8441 + pub settings_enable_connect_protocol: bool, + /// RFC 9218 + pub settings_no_rfc7540_priorities: bool, +} + +impl Default for H2Settings { + fn default() -> Self { + Self { + settings_header_table_size: 4096, + settings_enable_push: true, + settings_max_concurrent_streams: 100, + settings_initial_window_size: (1 << 16) - 1, + settings_max_frame_size: 1 << 14, + settings_max_header_list_size: u32::MAX, + settings_enable_connect_protocol: false, + settings_no_rfc7540_priorities: true, + } + } +} + +pub struct Prioriser {} + +impl Prioriser { + pub fn new() -> Self { + Self {} + } + pub fn push_priority(&mut self, priority: parser::Priority) { + println!("DEPRECATED: {priority:?}"); + } +} + +pub struct ConnectionH2 { + pub decoder: hpack::Decoder<'static>, + pub encoder: hpack::Encoder<'static>, + pub expect_read: Option<(H2StreamId, usize)>, + pub expect_write: Option, + pub last_stream_id: StreamId, + pub local_settings: H2Settings, + pub peer_settings: H2Settings, + pub position: Position, + pub prioriser: Prioriser, + pub readiness: Readiness, + pub socket: Front, + pub state: H2State, + pub streams: HashMap, + pub timeout_container: TimeoutContainer, + pub window: u32, + pub zero: GenericHttpStream, +} +impl std::fmt::Debug for ConnectionH2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectionH2") + .field("expect", &self.expect_read) + .field("position", 
&self.position) + .field("readiness", &self.readiness) + .field("local_settings", &self.local_settings) + .field("peer_settings", &self.peer_settings) + .field("socket", &self.socket.socket_ref()) + .field("state", &self.state) + .field("streams", &self.streams) + .field("zero", &self.zero.storage.meter(20)) + .field("window", &self.window) + .finish() + } +} + +#[derive(Debug, Clone, Copy)] +pub enum H2StreamId { + Zero, + Other(StreamId, GlobalStreamId), +} + +impl ConnectionH2 { + pub fn readable(&mut self, context: &mut Context, endpoint: E) -> MuxResult + where + E: Endpoint, + { + println_!("======= MUX H2 READABLE {:?}", self.position); + self.timeout_container.reset(); + let (stream_id, kawa) = if let Some((stream_id, amount)) = self.expect_read { + let kawa = match stream_id { + H2StreamId::Zero => &mut self.zero, + H2StreamId::Other(stream_id, global_stream_id) => { + context.streams[global_stream_id].rbuffer(&self.position) + } + }; + println_!("{:?}({stream_id:?}, {amount})", self.state); + if amount > 0 { + let (size, status) = self.socket.socket_read(&mut kawa.storage.space()[..amount]); + kawa.storage.fill(size); + if update_readiness_after_read(size, status, &mut self.readiness) { + return MuxResult::Continue; + } else { + if size == amount { + self.expect_read = None; + } else { + self.expect_read = Some((stream_id, amount - size)); + return MuxResult::Continue; + } + } + } else { + self.expect_read = None; + } + (stream_id, kawa) + } else { + self.readiness.event.remove(Ready::READABLE); + return MuxResult::Continue; + }; + match (&self.state, &self.position) { + (H2State::Error, _) + | (H2State::GoAway, _) + | (H2State::ServerSettings, Position::Server) + | (H2State::ClientPreface, Position::Client(_)) + | (H2State::ClientSettings, Position::Client(_)) => unreachable!( + "Unexpected combination: (Writable, {:?}, {:?})", + self.state, self.position + ), + (H2State::ClientPreface, Position::Server) => { + let i = kawa.storage.data(); + let i = 
match parser::preface(i) { + Ok((i, _)) => i, + Err(_) => return self.force_disconnect(), + }; + match parser::frame_header(i) { + Ok(( + _, + FrameHeader { + payload_len, + frame_type: FrameType::Settings, + flags: 0, + stream_id: 0, + }, + )) => { + kawa.storage.clear(); + self.state = H2State::ClientSettings; + self.expect_read = Some((H2StreamId::Zero, payload_len as usize)); + } + _ => return self.force_disconnect(), + }; + } + (H2State::ClientSettings, Position::Server) => { + let i = kawa.storage.data(); + let settings = match parser::settings_frame( + i, + &FrameHeader { + payload_len: i.len() as u32, + frame_type: FrameType::Settings, + flags: 0, + stream_id: 0, + }, + ) { + Ok((_, settings)) => { + kawa.storage.clear(); + settings + } + Err(_) => return self.force_disconnect(), + }; + let kawa = &mut self.zero; + match serializer::gen_settings(kawa.storage.space(), &self.local_settings) { + Ok((_, size)) => kawa.storage.fill(size), + Err(e) => { + println!("could not serialize SettingsFrame: {e:?}"); + return self.force_disconnect(); + } + }; + + self.state = H2State::ServerSettings; + self.expect_write = Some(H2StreamId::Zero); + return self.handle_frame(settings, context, endpoint); + } + (H2State::ServerSettings, Position::Client(_)) => { + let i = kawa.storage.data(); + match parser::frame_header(i) { + Ok(( + _, + header @ FrameHeader { + payload_len, + frame_type: FrameType::Settings, + flags: 0, + stream_id: 0, + }, + )) => { + kawa.storage.clear(); + self.expect_read = Some((H2StreamId::Zero, payload_len as usize)); + self.state = H2State::Frame(header) + } + _ => return self.force_disconnect(), + }; + } + (H2State::Header, _) => { + let i = kawa.storage.data(); + println_!(" header: {i:?}"); + match parser::frame_header(i) { + Ok((_, header)) => { + println_!("{header:#?}"); + kawa.storage.clear(); + let stream_id = header.stream_id; + let stream_id = + if stream_id == 0 || header.frame_type == FrameType::RstStream { + H2StreamId::Zero + } else { 
+ let global_stream_id = + if let Some(global_stream_id) = self.streams.get(&stream_id) { + *global_stream_id + } else { + match self.create_stream(stream_id, context) { + Some(global_stream_id) => global_stream_id, + None => return self.goaway(H2Error::InternalError), + } + }; + if header.frame_type == FrameType::Data { + H2StreamId::Other(stream_id, global_stream_id) + } else { + H2StreamId::Zero + } + }; + println_!("{} {stream_id:?} {:#?}", header.stream_id, self.streams); + self.expect_read = Some((stream_id, header.payload_len as usize)); + self.state = H2State::Frame(header); + } + Err(_) => return self.goaway(H2Error::ProtocolError), + }; + } + (H2State::ContinuationHeader(headers), _) => { + let i = kawa.storage.data(); + println_!(" continuation header: {i:?}"); + match parser::frame_header(i) { + Ok((_, header)) => { + println_!("{header:#?}"); + kawa.storage.end -= 9; + let stream_id = header.stream_id; + assert_eq!(stream_id, headers.stream_id); + self.expect_read = Some((H2StreamId::Zero, header.payload_len as usize)); + let mut headers = headers.clone(); + headers.end_headers = header.flags & 0x4 != 0; + headers.header_block_fragment.len += header.payload_len; + self.state = H2State::ContinuationFrame(headers); + } + Err(_) => return self.goaway(H2Error::ProtocolError), + }; + } + (H2State::Frame(header), _) => { + let i = kawa.storage.data(); + println_!(" data: {i:?}"); + let frame = match parser::frame_body( + i, + header, + self.local_settings.settings_max_frame_size, + ) { + Ok((_, frame)) => frame, + Err(e) => panic!("stream error: {:?}", error_nom_to_h2(e)), + }; + if let H2StreamId::Zero = stream_id { + kawa.storage.clear(); + } + self.state = H2State::Header; + self.expect_read = Some((H2StreamId::Zero, 9)); + return self.handle_frame(frame, context, endpoint); + } + (H2State::ContinuationFrame(headers), _) => { + let i = kawa.storage.data(); + println_!(" data: {i:?}"); + let headers = headers.clone(); + self.state = H2State::Header; + 
self.expect_read = Some((H2StreamId::Zero, 9)); + return self.handle_frame(Frame::Headers(headers), context, endpoint); + } + } + MuxResult::Continue + } + + pub fn writable(&mut self, context: &mut Context, mut endpoint: E) -> MuxResult + where + E: Endpoint, + { + println_!("======= MUX H2 WRITABLE {:?}", self.position); + self.timeout_container.reset(); + if let Some(H2StreamId::Zero) = self.expect_write { + let kawa = &mut self.zero; + println_!("{:?}", kawa.storage.data()); + while !kawa.storage.is_empty() { + let (size, status) = self.socket.socket_write(kawa.storage.data()); + kawa.storage.consume(size); + if update_readiness_after_write(size, status, &mut self.readiness) { + return MuxResult::Continue; + } + } + // when H2StreamId::Zero is used to write READABLE is disabled + // so when we finish the write we enable READABLE again + self.readiness.interest.insert(Ready::READABLE); + self.expect_write = None; + } + match (&self.state, &self.position) { + (H2State::Error, _) + | (H2State::ClientPreface, Position::Server) + | (H2State::ClientSettings, Position::Server) + | (H2State::ServerSettings, Position::Client(_)) => unreachable!( + "Unexpected combination: (Readable, {:?}, {:?})", + self.state, self.position + ), + (H2State::GoAway, _) => self.force_disconnect(), + (H2State::ClientPreface, Position::Client(_)) => { + println_!("Preparing preface and settings"); + let pri = serializer::H2_PRI.as_bytes(); + let kawa = &mut self.zero; + + kawa.storage.space()[0..pri.len()].copy_from_slice(pri); + kawa.storage.fill(pri.len()); + match serializer::gen_settings(kawa.storage.space(), &self.local_settings) { + Ok((_, size)) => kawa.storage.fill(size), + Err(e) => { + println!("could not serialize SettingsFrame: {e:?}"); + return self.force_disconnect(); + } + }; + + self.state = H2State::ClientSettings; + self.expect_write = Some(H2StreamId::Zero); + MuxResult::Continue + } + (H2State::ClientSettings, Position::Client(_)) => { + println_!("Sent preface and 
settings"); + self.state = H2State::ServerSettings; + self.readiness.interest.remove(Ready::WRITABLE); + self.expect_read = Some((H2StreamId::Zero, 9)); + MuxResult::Continue + } + (H2State::ServerSettings, Position::Server) => { + self.state = H2State::Header; + self.readiness.interest.remove(Ready::WRITABLE); + self.expect_read = Some((H2StreamId::Zero, 9)); + MuxResult::Continue + } + // Proxying states + (H2State::Header, _) + | (H2State::Frame(_), _) + | (H2State::ContinuationFrame(_), _) + | (H2State::ContinuationHeader(_), _) => { + let mut dead_streams = Vec::new(); + + if let Some(H2StreamId::Other(stream_id, global_stream_id)) = self.expect_write { + let stream = &mut context.streams[global_stream_id]; + let kawa = stream.wbuffer(&self.position); + while !kawa.out.is_empty() { + let bufs = kawa.as_io_slice(); + let (size, status) = self.socket.socket_write_vectored(&bufs); + kawa.consume(size); + if update_readiness_after_write(size, status, &mut self.readiness) { + return MuxResult::Continue; + } + } + self.expect_write = None; + if (kawa.is_terminated() || kawa.is_error()) && kawa.is_completed() { + match self.position { + Position::Client(_) => {} + Position::Server => { + // mark stream as reusable + println_!("Recycle stream: {global_stream_id}"); + let state = + std::mem::replace(&mut stream.state, StreamState::Recycle); + if let StreamState::Linked(token) = state { + endpoint.end_stream(token, global_stream_id, context); + } + dead_streams.push(stream_id); + } + } + } + } + + let mut converter = converter::H2BlockConverter { + stream_id: 0, + encoder: &mut self.encoder, + out: Vec::new(), + }; + let mut priorities = self.streams.keys().collect::>(); + priorities.sort(); + + println_!("PRIORITIES: {priorities:?}"); + 'outer: for stream_id in priorities { + let global_stream_id = *self.streams.get(stream_id).unwrap(); + let stream = &mut context.streams[global_stream_id]; + let kawa = stream.wbuffer(&self.position); + if kawa.is_main_phase() || 
kawa.is_error() { + converter.stream_id = *stream_id; + kawa.prepare(&mut converter); + debug_kawa(kawa); + } + while !kawa.out.is_empty() { + let bufs = kawa.as_io_slice(); + let (size, status) = self.socket.socket_write_vectored(&bufs); + kawa.consume(size); + if update_readiness_after_write(size, status, &mut self.readiness) { + self.expect_write = + Some(H2StreamId::Other(*stream_id, global_stream_id)); + break 'outer; + } + } + if (kawa.is_terminated() || kawa.is_error()) && kawa.is_completed() { + match self.position { + Position::Client(_) => {} + Position::Server => { + // mark stream as reusable + println_!("Recycle stream: {global_stream_id}"); + let state = + std::mem::replace(&mut stream.state, StreamState::Recycle); + if let StreamState::Linked(token) = state { + endpoint.end_stream(token, global_stream_id, context); + } + dead_streams.push(*stream_id); + } + } + } + } + for stream_id in dead_streams { + self.streams.remove(&stream_id).unwrap(); + } + + if self.expect_write.is_none() { + // We wrote everything + self.readiness.interest.remove(Ready::WRITABLE); + } + MuxResult::Continue + } + } + } + + pub fn goaway(&mut self, error: H2Error) -> MuxResult { + self.state = H2State::Error; + self.expect_read = None; + self.expect_write = Some(H2StreamId::Zero); + let kawa = &mut self.zero; + + match serializer::gen_goaway(kawa.storage.space(), self.last_stream_id, error) { + Ok((_, size)) => { + kawa.storage.fill(size); + self.state = H2State::GoAway; + self.expect_write = Some(H2StreamId::Zero); + self.readiness.interest = Ready::WRITABLE | Ready::HUP | Ready::ERROR; + MuxResult::Continue + } + Err(e) => { + println!("could not serialize GoAwayFrame: {e:?}"); + self.force_disconnect() + } + } + } + + pub fn create_stream( + &mut self, + stream_id: StreamId, + context: &mut Context, + ) -> Option { + let global_stream_id = context.create_stream( + Ulid::generate(), + self.peer_settings.settings_initial_window_size, + )?; + if stream_id > 
self.last_stream_id { + self.last_stream_id = stream_id & !1; + } + self.streams.insert(stream_id, global_stream_id); + Some(global_stream_id) + } + + pub fn new_stream_id(&mut self) -> StreamId { + self.last_stream_id += 2; + match self.position { + Position::Client(_) => self.last_stream_id + 1, + Position::Server => self.last_stream_id, + } + } + + fn handle_frame(&mut self, frame: Frame, context: &mut Context, mut endpoint: E) -> MuxResult + where + E: Endpoint, + { + println_!("{frame:#?}"); + match frame { + Frame::Data(data) => { + let mut slice = data.payload; + let global_stream_id = match self.streams.get(&data.stream_id) { + Some(global_stream_id) => *global_stream_id, + None => panic!("stream error"), + }; + let stream = &mut context.streams[global_stream_id]; + let kawa = stream.rbuffer(&self.position); + slice.start += kawa.storage.head as u32; + kawa.storage.head += slice.len(); + kawa.push_block(kawa::Block::Chunk(kawa::Chunk { + data: kawa::Store::Slice(slice), + })); + if data.end_stream { + kawa.push_block(kawa::Block::Flags(kawa::Flags { + end_body: true, + end_chunk: false, + end_header: false, + end_stream: true, + })); + kawa.parsing_phase = kawa::ParsingPhase::Terminated; + } + if let StreamState::Linked(token) = stream.state { + endpoint + .readiness_mut(token) + .interest + .insert(Ready::WRITABLE) + } + } + Frame::Headers(headers) => { + if !headers.end_headers { + self.state = H2State::ContinuationHeader(headers); + return MuxResult::Continue; + } + // can this fail? 
+ let global_stream_id = *self.streams.get(&headers.stream_id).unwrap(); + let kawa = &mut self.zero; + let buffer = headers.header_block_fragment.data(kawa.storage.buffer()); + let stream = &mut context.streams[global_stream_id]; + let parts = &mut stream.split(&self.position); + let was_initial = parts.rbuffer.is_initial(); + pkawa::handle_header( + parts.rbuffer, + buffer, + headers.end_stream, + &mut self.decoder, + parts.context, + ); + debug_kawa(parts.rbuffer); + if let StreamState::Linked(token) = stream.state { + endpoint + .readiness_mut(token) + .interest + .insert(Ready::WRITABLE) + } + if was_initial { + match self.position { + Position::Server => stream.state = StreamState::Link, + Position::Client(_) => {} + }; + } + } + Frame::PushPromise(push_promise) => match self.position { + Position::Client(_) => { + if self.local_settings.settings_enable_push { + todo!("forward the push") + } else { + return self.goaway(H2Error::ProtocolError); + } + } + Position::Server => { + println_!("A client should not push promises"); + return self.goaway(H2Error::ProtocolError); + } + }, + Frame::Priority(priority) => self.prioriser.push_priority(priority), + Frame::RstStream(rst_stream) => { + println_!( + "RstStream({} -> {})", + rst_stream.error_code, + error_code_to_str(rst_stream.error_code) + ); + if let Some(stream_id) = self.streams.remove(&rst_stream.stream_id) { + let stream = &mut context.streams[stream_id]; + if let StreamState::Linked(token) = stream.state { + endpoint.end_stream(token, stream_id, context); + } + let stream = &mut context.streams[stream_id]; + match self.position { + Position::Client(_) => {} + Position::Server => { + stream.state = StreamState::Recycle; + } + } + } + } + Frame::Settings(settings) => { + if settings.ack { + return MuxResult::Continue; + } + for setting in settings.settings { + #[rustfmt::skip] + let _ = match setting.identifier { + 1 => self.peer_settings.settings_header_table_size = setting.value, + 2 => 
self.peer_settings.settings_enable_push = setting.value == 1, + 3 => self.peer_settings.settings_max_concurrent_streams = setting.value, + 4 => self.peer_settings.settings_initial_window_size = setting.value, + 5 => self.peer_settings.settings_max_frame_size = setting.value, + 6 => self.peer_settings.settings_max_header_list_size = setting.value, + 8 => self.peer_settings.settings_enable_connect_protocol = setting.value == 1, + 9 => self.peer_settings.settings_no_rfc7540_priorities = setting.value == 1, + other => println!("unknown setting_id: {other}, we MUST ignore this"), + }; + } + println_!("{:#?}", self.peer_settings); + + let kawa = &mut self.zero; + kawa.storage.space()[0..serializer::SETTINGS_ACKNOWLEDGEMENT.len()] + .copy_from_slice(&serializer::SETTINGS_ACKNOWLEDGEMENT); + kawa.storage + .fill(serializer::SETTINGS_ACKNOWLEDGEMENT.len()); + + self.readiness.interest.insert(Ready::WRITABLE); + self.readiness.interest.remove(Ready::READABLE); + self.expect_write = Some(H2StreamId::Zero); + } + Frame::Ping(ping) => { + let kawa = &mut self.zero; + match serializer::gen_ping_acknolegment(kawa.storage.space(), &ping.payload) { + Ok((_, size)) => kawa.storage.fill(size), + Err(e) => { + println!("could not serialize PingFrame: {e:?}"); + return self.force_disconnect(); + } + }; + self.readiness.interest.insert(Ready::WRITABLE); + self.readiness.interest.remove(Ready::READABLE); + self.expect_write = Some(H2StreamId::Zero); + } + Frame::GoAway(goaway) => { + println_!( + "GoAway({} -> {})", + goaway.error_code, + error_code_to_str(goaway.error_code) + ); + return self.goaway(H2Error::NoError); + } + Frame::WindowUpdate(update) => { + if update.stream_id == 0 { + self.window += update.increment; + } else { + if let Some(global_stream_id) = self.streams.get(&update.stream_id) { + context.streams[*global_stream_id].window += update.increment as i32; + } + } + } + Frame::Continuation(_) => unreachable!(), + } + MuxResult::Continue + } + + pub fn 
force_disconnect(&mut self) -> MuxResult { + self.state = H2State::Error; + match self.position { + Position::Client(_) => { + self.position = Position::Client(BackendStatus::Disconnecting); + self.readiness.event = Ready::HUP; + MuxResult::Continue + } + Position::Server => MuxResult::CloseSession, + } + } + + pub fn close(&mut self, context: &mut Context, mut endpoint: E) + where + E: Endpoint, + { + match self.position { + Position::Client(BackendStatus::Connected(_)) + | Position::Client(BackendStatus::Connecting(_)) + | Position::Client(BackendStatus::Disconnecting) => {} + Position::Client(BackendStatus::KeepAlive(_)) => unreachable!(), + Position::Server => unreachable!(), + } + // reconnection is handled by the server for each stream separately + for global_stream_id in self.streams.values() { + println_!("end stream: {global_stream_id}"); + let StreamState::Linked(token) = context.streams[*global_stream_id].state else { unreachable!() }; + endpoint.end_stream(token, *global_stream_id, context) + } + } + + pub fn end_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + let stream_context = &mut context.streams[stream].context; + println_!("end H2 stream {stream}: {stream_context:#?}"); + match self.position { + Position::Client(_) => { + for (stream_id, global_stream_id) in &self.streams { + if *global_stream_id == stream { + let id = *stream_id; + self.streams.remove(&id); + return; + } + } + unreachable!() + } + Position::Server => { + let stream = &mut context.streams[stream]; + match (stream.front.consumed, stream.back.is_main_phase()) { + (_, true) => { + // front might not have been consumed (in case of PushPromise) + // we have a "forwardable" answer from the back + // if the answer is not terminated we send an RstStream to properly clean the stream + // if it is terminated, we finish the transfer, the backend is not necessary anymore + if !stream.back.is_terminated() { + forcefully_terminate_answer(stream, &mut self.readiness); + } 
else { + stream.state = StreamState::Unlinked; + self.readiness.interest.insert(Ready::WRITABLE); + } + } + (true, false) => { + // we do not have an answer, but the request has already been partially consumed + // so we can't retry, send a 502 bad gateway instead + // note: it might be possible to send a RstStream with an adequate error code + set_default_answer(stream, &mut self.readiness, 502); + } + (false, false) => { + // we do not have an answer, but the request is untouched so we can retry + println!("H2 RECONNECT"); + stream.state = StreamState::Link + } + } + } + } + } + + pub fn start_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + println_!("start new H2 stream {stream} {:?}", self.readiness); + let stream_id = self.new_stream_id(); + self.streams.insert(stream_id, stream); + self.readiness.interest.insert(Ready::WRITABLE); + } +} diff --git a/lib/src/protocol/mux/mod.rs b/lib/src/protocol/mux/mod.rs new file mode 100644 index 000000000..48c9afa64 --- /dev/null +++ b/lib/src/protocol/mux/mod.rs @@ -0,0 +1,1310 @@ +use std::{ + cell::RefCell, + collections::HashMap, + io::ErrorKind, + net::{Shutdown, SocketAddr}, + rc::{Rc, Weak}, +}; + +use mio::{net::TcpStream, Interest, Token}; +use rusty_ulid::Ulid; +use sozu_command::{proto::command::ListenerType, ready::Ready}; +use time::Duration; + +mod converter; +mod h1; +mod h2; +mod parser; +mod pkawa; +mod serializer; + +use crate::{ + backends::{Backend, BackendError}, + pool::{Checkout, Pool}, + protocol::{ + http::editor::HttpContext, + mux::h2::{H2Settings, H2State, H2StreamId}, + SessionState, + }, + router::Route, + server::CONN_RETRIES, + socket::{SocketHandler, SocketResult}, + timer::TimeoutContainer, + BackendConnectionError, L7ListenerHandler, L7Proxy, ProxySession, Readiness, + RetrieveClusterError, SessionIsToBeClosed, SessionMetrics, SessionResult, StateResult, +}; + +pub use crate::protocol::mux::{h1::ConnectionH1, h2::ConnectionH2}; + +use self::h2::Prioriser; + 
+#[macro_export] +macro_rules! println_ { + ($($t:expr),*) => { + // print!("{}:{} ", file!(), line!()); + println!($($t),*) + // $(let _ = &$t;)* + }; +} +fn debug_kawa(_kawa: &GenericHttpStream) { + // kawa::debug_kawa(_kawa); +} + +/// Generic Http representation using the Kawa crate using the Checkout of Sozu as buffer +type GenericHttpStream = kawa::Kawa; +type StreamId = u32; +type GlobalStreamId = usize; + +pub fn fill_default_301_answer(kawa: &mut kawa::Kawa, host: &str, uri: &str) { + kawa.detached.status_line = kawa::StatusLine::Response { + version: kawa::Version::V20, + code: 301, + status: kawa::Store::Static(b"301"), + reason: kawa::Store::Static(b"Moved Permanently"), + }; + kawa.push_block(kawa::Block::StatusLine); + kawa.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"Location"), + val: kawa::Store::from_string(format!("https://{host}{uri}")), + })); + terminate_default_answer(kawa, false); +} + +pub fn fill_default_answer(kawa: &mut kawa::Kawa, code: u16) { + kawa.detached.status_line = kawa::StatusLine::Response { + version: kawa::Version::V20, + code, + status: kawa::Store::from_string(code.to_string()), + reason: kawa::Store::Static(b"Sozu Default Answer"), + }; + kawa.push_block(kawa::Block::StatusLine); + terminate_default_answer(kawa, true); +} + +pub fn terminate_default_answer(kawa: &mut kawa::Kawa, close: bool) { + if close { + kawa.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"Cache-Control"), + val: kawa::Store::Static(b"no-cache"), + })); + kawa.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"Connection"), + val: kawa::Store::Static(b"close"), + })); + } + kawa.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"Content-Length"), + val: kawa::Store::Static(b"0"), + })); + kawa.push_block(kawa::Block::Flags(kawa::Flags { + end_body: false, + end_chunk: false, + end_header: true, + end_stream: true, + })); + kawa.parsing_phase = 
kawa::ParsingPhase::Terminated; +} + +/// Replace the content of the kawa message with a default Sozu answer for a given status code +fn set_default_answer(stream: &mut Stream, readiness: &mut Readiness, code: u16) { + let kawa = &mut stream.back; + kawa.clear(); + kawa.storage.clear(); + if code == 301 { + let host = stream.context.authority.as_deref().unwrap(); + let uri = stream.context.path.as_deref().unwrap(); + fill_default_301_answer(kawa, host, uri); + } else { + fill_default_answer(kawa, code); + } + stream.state = StreamState::Unlinked; + readiness.interest.insert(Ready::WRITABLE); +} + +/// Forcefully terminates a kawa message by setting the "end_stream" flag and setting the parsing_phase to Error. +/// An H2 converter will produce an RstStream frame. +fn forcefully_terminate_answer(stream: &mut Stream, readiness: &mut Readiness) { + let kawa = &mut stream.back; + kawa.push_block(kawa::Block::Flags(kawa::Flags { + end_body: false, + end_chunk: false, + end_header: false, + end_stream: true, + })); + kawa.parsing_phase.error("Termination".into()); + debug_kawa(kawa); + stream.state = StreamState::Unlinked; + readiness.interest.insert(Ready::WRITABLE); +} + +#[derive(Debug)] +pub enum Position { + Client(BackendStatus), + Server, +} + +#[derive(Debug)] +pub enum BackendStatus { + Connecting(String), + Connected(String), + KeepAlive(String), + Disconnecting, +} + +pub enum MuxResult { + Continue, + Upgrade, + CloseSession, +} + +pub trait Endpoint { + fn readiness(&self, token: Token) -> &Readiness; + fn readiness_mut(&mut self, token: Token) -> &mut Readiness; + /// If end_stream is called on a client it means the stream has PROPERLY finished, + /// the server has completed serving the response and informs the endpoint that this stream won't be used anymore. 
+ /// If end_stream is called on a server it means the stream was BROKEN, the client was most likely disconnected or encountered an error + /// it is for the server to decide if the stream can be retried or an error should be sent. It should be GUARANTEED that all bytes from + /// the backend were read. However it is almost certain that all bytes were not already sent to the client. + fn end_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context); + /// If start_stream is called on a client it means the stream should be attached to this endpoint, + /// the stream might be recovering from a disconnection, in any case at this point its response MUST be empty. + /// If the start_stream is called on a H2 server it means the stream is a server push and its request MUST be empty. + fn start_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context); +} + +#[derive(Debug)] +pub enum Connection { + H1(ConnectionH1), + H2(ConnectionH2), +} + +impl Connection { + pub fn new_h1_server( + front_stream: Front, + timeout_container: TimeoutContainer, + ) -> Connection { + Connection::H1(ConnectionH1 { + position: Position::Server, + readiness: Readiness { + interest: Ready::READABLE | Ready::HUP | Ready::ERROR, + event: Ready::EMPTY, + }, + requests: 0, + socket: front_stream, + stream: 0, + timeout_container, + }) + } + pub fn new_h1_client( + front_stream: Front, + cluster_id: String, + timeout_container: TimeoutContainer, + ) -> Connection { + Connection::H1(ConnectionH1 { + socket: front_stream, + position: Position::Client(BackendStatus::Connecting(cluster_id)), + readiness: Readiness { + interest: Ready::WRITABLE | Ready::READABLE | Ready::HUP | Ready::ERROR, + event: Ready::EMPTY, + }, + stream: usize::MAX - 1, + requests: 0, + timeout_container, + }) + } + + pub fn new_h2_server( + front_stream: Front, + pool: Weak>, + timeout_container: TimeoutContainer, + ) -> Option> { + let buffer = pool + .upgrade() + .and_then(|pool| 
pool.borrow_mut().checkout())?; + Some(Connection::H2(ConnectionH2 { + decoder: hpack::Decoder::new(), + encoder: hpack::Encoder::new(), + expect_read: Some((H2StreamId::Zero, 24 + 9)), + expect_write: None, + last_stream_id: 0, + local_settings: H2Settings::default(), + peer_settings: H2Settings::default(), + position: Position::Server, + prioriser: Prioriser::new(), + readiness: Readiness { + interest: Ready::READABLE | Ready::HUP | Ready::ERROR, + event: Ready::EMPTY, + }, + socket: front_stream, + state: H2State::ClientPreface, + streams: HashMap::new(), + timeout_container, + window: 1 << 16, + zero: kawa::Kawa::new(kawa::Kind::Request, kawa::Buffer::new(buffer)), + })) + } + pub fn new_h2_client( + front_stream: Front, + cluster_id: String, + pool: Weak>, + timeout_container: TimeoutContainer, + ) -> Option> { + let buffer = pool + .upgrade() + .and_then(|pool| pool.borrow_mut().checkout())?; + Some(Connection::H2(ConnectionH2 { + decoder: hpack::Decoder::new(), + encoder: hpack::Encoder::new(), + expect_read: None, + expect_write: None, + last_stream_id: 0, + local_settings: H2Settings::default(), + peer_settings: H2Settings::default(), + position: Position::Client(BackendStatus::Connecting(cluster_id)), + prioriser: Prioriser::new(), + readiness: Readiness { + interest: Ready::WRITABLE | Ready::HUP | Ready::ERROR, + event: Ready::EMPTY, + }, + socket: front_stream, + state: H2State::ClientPreface, + streams: HashMap::new(), + timeout_container, + window: 1 << 16, + zero: kawa::Kawa::new(kawa::Kind::Request, kawa::Buffer::new(buffer)), + })) + } + + pub fn readiness(&self) -> &Readiness { + match self { + Connection::H1(c) => &c.readiness, + Connection::H2(c) => &c.readiness, + } + } + pub fn readiness_mut(&mut self) -> &mut Readiness { + match self { + Connection::H1(c) => &mut c.readiness, + Connection::H2(c) => &mut c.readiness, + } + } + pub fn position(&self) -> &Position { + match self { + Connection::H1(c) => &c.position, + Connection::H2(c) => 
&c.position, + } + } + pub fn position_mut(&mut self) -> &mut Position { + match self { + Connection::H1(c) => &mut c.position, + Connection::H2(c) => &mut c.position, + } + } + pub fn socket(&self) -> &TcpStream { + match self { + Connection::H1(c) => c.socket.socket_ref(), + Connection::H2(c) => c.socket.socket_ref(), + } + } + pub fn socket_mut(&mut self) -> &mut TcpStream { + match self { + Connection::H1(c) => c.socket.socket_mut(), + Connection::H2(c) => c.socket.socket_mut(), + } + } + pub fn timeout_container(&mut self) -> &mut TimeoutContainer { + match self { + Connection::H1(c) => &mut c.timeout_container, + Connection::H2(c) => &mut c.timeout_container, + } + } + fn force_disconnect(&mut self) -> MuxResult { + match self { + Connection::H1(c) => c.force_disconnect(), + Connection::H2(c) => c.force_disconnect(), + } + } + fn readable(&mut self, context: &mut Context, endpoint: E) -> MuxResult + where + E: Endpoint, + { + match self { + Connection::H1(c) => c.readable(context, endpoint), + Connection::H2(c) => c.readable(context, endpoint), + } + } + fn writable(&mut self, context: &mut Context, endpoint: E) -> MuxResult + where + E: Endpoint, + { + match self { + Connection::H1(c) => c.writable(context, endpoint), + Connection::H2(c) => c.writable(context, endpoint), + } + } + + fn close(&mut self, context: &mut Context, endpoint: E) + where + E: Endpoint, + { + match self { + Connection::H1(c) => c.close(context, endpoint), + Connection::H2(c) => c.close(context, endpoint), + } + } + + fn end_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + match self { + Connection::H1(c) => c.end_stream(stream, context), + Connection::H2(c) => c.end_stream(stream, context), + } + } + + fn start_stream(&mut self, stream: GlobalStreamId, context: &mut Context) { + match self { + Connection::H1(c) => c.start_stream(stream, context), + Connection::H2(c) => c.start_stream(stream, context), + } + } +} + +struct EndpointServer<'a, Front: 
SocketHandler>(&'a mut Connection); +struct EndpointClient<'a>(&'a mut Router); + +// note: EndpointServer are used by client Connection, they do not know the frontend Token +// they will use the Stream's Token which is their backend token +impl<'a, Front: SocketHandler> Endpoint for EndpointServer<'a, Front> { + fn readiness(&self, _token: Token) -> &Readiness { + self.0.readiness() + } + fn readiness_mut(&mut self, _token: Token) -> &mut Readiness { + self.0.readiness_mut() + } + + fn end_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context) { + // this may be used to forward H2<->H2 RstStream + // or to handle backend hup + self.0.end_stream(stream, context); + } + + fn start_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context) { + // this may be used to forward H2<->H2 PushPromise + todo!() + } +} +impl<'a> Endpoint for EndpointClient<'a> { + fn readiness(&self, token: Token) -> &Readiness { + self.0.backends.get(&token).unwrap().readiness() + } + fn readiness_mut(&mut self, token: Token) -> &mut Readiness { + self.0.backends.get_mut(&token).unwrap().readiness_mut() + } + + fn end_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context) { + self.0 + .backends + .get_mut(&token) + .unwrap() + .end_stream(stream, context); + } + + fn start_stream(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context) { + self.0 + .backends + .get_mut(&token) + .unwrap() + .start_stream(stream, context); + } +} + +fn update_readiness_after_read( + size: usize, + status: SocketResult, + readiness: &mut Readiness, +) -> bool { + println_!(" size={size}, status={status:?}"); + match status { + SocketResult::Continue => {} + SocketResult::Closed | SocketResult::Error => { + readiness.event.remove(Ready::ALL); + } + SocketResult::WouldBlock => { + readiness.event.remove(Ready::READABLE); + } + } + if size > 0 { + false + } else { + readiness.event.remove(Ready::READABLE); + true + } +} +fn 
update_readiness_after_write( + size: usize, + status: SocketResult, + readiness: &mut Readiness, +) -> bool { + println_!(" size={size}, status={status:?}"); + match status { + SocketResult::Continue => {} + SocketResult::Closed | SocketResult::Error => { + // even if the socket closed there might be something left to read + readiness.event.remove(Ready::WRITABLE); + } + SocketResult::WouldBlock => { + readiness.event.remove(Ready::WRITABLE); + } + } + if size > 0 { + false + } else { + readiness.event.remove(Ready::WRITABLE); + true + } +} + +// enum Stream { +// Idle { +// window: i32, +// }, +// Open { +// window: i32, +// token: Token, +// front: GenericHttpStream, +// back: GenericHttpStream, +// context: HttpContext, +// }, +// Reserved { +// window: i32, +// token: Token, +// position: Position, +// buffer: GenericHttpStream, +// context: HttpContext, +// }, +// HalfClosed { +// window: i32, +// token: Token, +// position: Position, +// buffer: GenericHttpStream, +// context: HttpContext, +// }, +// Closed, +// } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamState { + Idle, + /// the Stream is asking for connection, this will trigger a call to connect + Link, + /// the Stream is linked to a Client (note that the client might not be connected) + Linked(Token), + /// the Stream was linked to a Client, but the connection closed, the client was removed + /// and this Stream could not be retried (it should be terminated) + Unlinked, + /// the Stream is unlinked and can be reused + Recycle, +} + +pub struct Stream { + // pub request_id: Ulid, + pub window: i32, + pub attempts: u8, + pub state: StreamState, + pub front: GenericHttpStream, + pub back: GenericHttpStream, + pub context: HttpContext, +} + +/// This struct allows to mutably borrow the read and write buffers (dependant on the position) +/// as well as the context of a Stream at the same time +pub struct StreamParts<'a> { + pub rbuffer: &'a mut GenericHttpStream, + pub wbuffer: &'a mut 
GenericHttpStream, + pub context: &'a mut HttpContext, +} + +fn temporary_http_context(request_id: Ulid) -> HttpContext { + HttpContext { + keep_alive_backend: true, + keep_alive_frontend: true, + sticky_session_found: None, + method: None, + authority: None, + path: None, + status: None, + reason: None, + user_agent: None, + closing: false, + id: request_id, + protocol: crate::Protocol::HTTPS, + public_address: "0.0.0.0:80".parse().unwrap(), + session_address: None, + sticky_name: "SOZUBALANCEID".to_owned(), + sticky_session: None, + } +} + +impl Stream { + pub fn new(pool: Weak>, request_id: Ulid, window: u32) -> Option { + let (front_buffer, back_buffer) = match pool.upgrade() { + Some(pool) => { + let mut pool = pool.borrow_mut(); + match (pool.checkout(), pool.checkout()) { + (Some(front_buffer), Some(back_buffer)) => (front_buffer, back_buffer), + _ => return None, + } + } + None => return None, + }; + Some(Self { + state: StreamState::Idle, + attempts: 0, + window: window as i32, + front: GenericHttpStream::new(kawa::Kind::Request, kawa::Buffer::new(front_buffer)), + back: GenericHttpStream::new(kawa::Kind::Response, kawa::Buffer::new(back_buffer)), + context: temporary_http_context(request_id), + }) + } + pub fn split(&mut self, position: &Position) -> StreamParts<'_> { + match position { + Position::Client(_) => StreamParts { + rbuffer: &mut self.back, + wbuffer: &mut self.front, + context: &mut self.context, + }, + Position::Server => StreamParts { + rbuffer: &mut self.front, + wbuffer: &mut self.back, + context: &mut self.context, + }, + } + } + pub fn rbuffer(&mut self, position: &Position) -> &mut GenericHttpStream { + match position { + Position::Client(_) => &mut self.back, + Position::Server => &mut self.front, + } + } + pub fn wbuffer(&mut self, position: &Position) -> &mut GenericHttpStream { + match position { + Position::Client(_) => &mut self.front, + Position::Server => &mut self.back, + } + } +} + +pub struct Context { + pub streams: Vec, + 
pub pool: Weak>, +} + +impl Context { + pub fn create_stream(&mut self, request_id: Ulid, window: u32) -> Option { + for (stream_id, stream) in self.streams.iter_mut().enumerate() { + if stream.state == StreamState::Recycle { + println_!("Reuse stream: {stream_id}"); + stream.state = StreamState::Idle; + stream.attempts = 0; + stream.window = window as i32; + stream.context = temporary_http_context(request_id); + stream.back.clear(); + stream.back.storage.clear(); + stream.front.clear(); + stream.front.storage.clear(); + return Some(stream_id); + } + } + self.streams + .push(Stream::new(self.pool.clone(), request_id, window)?); + Some(self.streams.len() - 1) + } + + pub fn new(pool: Weak>) -> Context { + Self { + streams: Vec::new(), + pool, + } + } +} + +pub struct Router { + pub backends: HashMap>, + pub configured_backend_timeout: Duration, + pub configured_connect_timeout: Duration, + pub listener: Rc>, +} + +impl Router { + pub fn new( + configured_backend_timeout: Duration, + configured_connect_timeout: Duration, + listener: Rc>, + ) -> Self { + Self { + backends: HashMap::new(), + configured_backend_timeout, + configured_connect_timeout, + listener, + } + } + + fn connect( + &mut self, + stream_id: GlobalStreamId, + context: &mut Context, + session: Rc>, + proxy: Rc>, + metrics: &mut SessionMetrics, + ) -> Result<(), BackendConnectionError> { + let stream = &mut context.streams[stream_id]; + // when reused, a stream should be detached from its old connection, if not we could end + // with concurrent connections on a single endpoint + assert!(matches!(stream.state, StreamState::Link)); + if stream.attempts >= CONN_RETRIES { + return Err(BackendConnectionError::MaxConnectionRetries(None)); + } + stream.attempts += 1; + + let stream_context = &mut stream.context; + let (cluster_id, h2) = self + .route_from_request(stream_context, proxy.clone()) + .map_err(BackendConnectionError::RetrieveClusterError)?; + + let (frontend_should_stick, 
frontend_should_redirect_https) = proxy + .borrow() + .clusters() + .get(&cluster_id) + .map(|cluster| (cluster.sticky_session, cluster.https_redirect)) + .unwrap_or((false, false)); + + if frontend_should_redirect_https && matches!(proxy.borrow().kind(), ListenerType::Http) { + return Err(BackendConnectionError::RetrieveClusterError( + RetrieveClusterError::HttpsRedirect, + )); + } + + let mut reuse_token = None; + // let mut priority = 0; + let mut reuse_connecting = true; + for (token, backend) in &self.backends { + match (h2, reuse_connecting, backend.position()) { + (_, _, Position::Server) => { + unreachable!("Backend connection behaves like a server") + } + (_, _, Position::Client(BackendStatus::Disconnecting)) => {} + + (true, _, Position::Client(BackendStatus::Connected(old_cluster_id))) => { + if *old_cluster_id == cluster_id { + reuse_token = Some(*token); + reuse_connecting = false; + break; + } + } + (true, true, Position::Client(BackendStatus::Connecting(old_cluster_id))) => { + if *old_cluster_id == cluster_id { + reuse_token = Some(*token) + } + } + (true, false, Position::Client(BackendStatus::Connecting(_))) => {} + (true, _, Position::Client(BackendStatus::KeepAlive(old_cluster_id))) => { + if *old_cluster_id == cluster_id { + unreachable!("ConnectionH2 behaves like H1") + } + } + + (false, _, Position::Client(BackendStatus::KeepAlive(old_cluster_id))) => { + if *old_cluster_id == cluster_id { + reuse_token = Some(*token); + reuse_connecting = false; + break; + } + } + // can't bundle H1 streams together + (false, _, Position::Client(BackendStatus::Connected(_))) + | (false, _, Position::Client(BackendStatus::Connecting(_))) => {} + } + } + println_!("connect: {cluster_id} (stick={frontend_should_stick}, h2={h2}) -> (reuse={reuse_token:?})"); + + let token = if let Some(token) = reuse_token { + println_!("reused backend: {:#?}", self.backends.get(&token).unwrap()); + token + } else { + let mut socket = self.backend_from_request( + &cluster_id, + 
frontend_should_stick, + stream_context, + proxy.clone(), + metrics, + )?; + + if let Err(e) = socket.set_nodelay(true) { + error!( + "error setting nodelay on back socket({:?}): {:?}", + socket, e + ); + } + // self.backend_readiness.interest = Ready::WRITABLE | Ready::HUP | Ready::ERROR; + // self.backend_connection_status = BackendConnectionStatus::Connecting(Instant::now()); + + let token = proxy.borrow().add_session(session); + + if let Err(e) = proxy.borrow().register_socket( + &mut socket, + token, + Interest::READABLE | Interest::WRITABLE, + ) { + error!("error registering back socket({:?}): {:?}", socket, e); + } + + let timeout_container = TimeoutContainer::new(self.configured_connect_timeout, token); + let connection = if h2 { + match Connection::new_h2_client( + socket, + cluster_id, + context.pool.clone(), + timeout_container, + ) { + Some(connection) => connection, + None => return Err(BackendConnectionError::MaxBuffers), + } + } else { + Connection::new_h1_client(socket, cluster_id, timeout_container) + }; + self.backends.insert(token, connection); + token + }; + + // link stream to backend + stream.state = StreamState::Linked(token); + // link backend to stream + self.backends + .get_mut(&token) + .unwrap() + .start_stream(stream_id, context); + Ok(()) + } + + fn route_from_request( + &mut self, + context: &mut HttpContext, + _proxy: Rc>, + ) -> Result<(String, bool), RetrieveClusterError> { + let (host, uri, method) = match context.extract_route() { + Ok(tuple) => tuple, + Err(cluster_error) => { + // we are past kawa parsing if it succeeded this can't fail + // if the request was malformed it was caught by kawa and we sent a 400 + panic!("{cluster_error}"); + } + }; + + let route_result = self + .listener + .borrow() + .frontend_from_request(host, uri, method); + + let route = match route_result { + Ok(route) => route, + Err(frontend_error) => { + println!("{}", frontend_error); + // self.set_answer(DefaultAnswerStatus::Answer404, None); + return 
Err(RetrieveClusterError::RetrieveFrontend(frontend_error)); + } + }; + + let cluster_id = match route { + Route::Cluster { id, h2 } => (id, h2), + Route::Deny => { + println!("Route::Deny"); + // self.set_answer(DefaultAnswerStatus::Answer401, None); + return Err(RetrieveClusterError::UnauthorizedRoute); + } + }; + + Ok(cluster_id) + } + + pub fn backend_from_request( + &mut self, + cluster_id: &str, + frontend_should_stick: bool, + context: &mut HttpContext, + proxy: Rc>, + _metrics: &mut SessionMetrics, + ) -> Result { + let (backend, conn) = self + .get_backend_for_sticky_session( + cluster_id, + frontend_should_stick, + context.sticky_session_found.as_deref(), + proxy, + ) + .map_err(|backend_error| { + println!("{backend_error}"); + // self.set_answer(DefaultAnswerStatus::Answer503, None); + BackendConnectionError::Backend(backend_error) + })?; + + if frontend_should_stick { + // update sticky name in case it changed I guess? + context.sticky_name = self.listener.borrow().get_sticky_name().to_string(); + + context.sticky_session = Some( + backend + .borrow() + .sticky_id + .clone() + .unwrap_or_else(|| backend.borrow().backend_id.clone()), + ); + } + + // metrics.backend_id = Some(backend.borrow().backend_id.clone()); + // metrics.backend_start(); + // self.set_backend_id(backend.borrow().backend_id.clone()); + // self.backend = Some(backend); + + Ok(conn) + } + + fn get_backend_for_sticky_session( + &self, + cluster_id: &str, + frontend_should_stick: bool, + sticky_session: Option<&str>, + proxy: Rc>, + ) -> Result<(Rc>, TcpStream), BackendError> { + match (frontend_should_stick, sticky_session) { + (true, Some(sticky_session)) => proxy + .borrow() + .backends() + .borrow_mut() + .backend_from_sticky_session(cluster_id, sticky_session), + _ => proxy + .borrow() + .backends() + .borrow_mut() + .backend_from_cluster_id(cluster_id), + } + } +} + +pub struct Mux { + pub configured_frontend_timeout: Duration, + pub frontend_token: Token, + pub frontend: 
Connection, + pub router: Router, + pub public_address: SocketAddr, + pub peer_address: Option, + pub sticky_name: String, + pub context: Context, +} + +impl Mux { + pub fn front_socket(&self) -> &TcpStream { + self.frontend.socket() + } +} + +impl SessionState for Mux { + fn ready( + &mut self, + session: Rc>, + proxy: Rc>, + metrics: &mut SessionMetrics, + ) -> SessionResult { + let mut counter = 0; + let max_loop_iterations = 100000; + + if self.frontend.readiness().event.is_hup() { + return SessionResult::Close; + } + + let start = std::time::Instant::now(); + println_!("{start:?}"); + loop { + loop { + let context = &mut self.context; + if self.frontend.readiness().filter_interest().is_readable() { + match self + .frontend + .readable(context, EndpointClient(&mut self.router)) + { + MuxResult::Continue => {} + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Upgrade => return SessionResult::Upgrade, + } + } + + let mut all_backends_readiness_are_empty = true; + let context = &mut self.context; + let mut dead_backends = Vec::new(); + for (token, backend) in self.router.backends.iter_mut() { + let readiness = backend.readiness_mut(); + println!("{token:?} -> {readiness:?}"); + let dead = readiness.filter_interest().is_hup() + || readiness.filter_interest().is_error(); + if dead { + println_!("Backend({token:?}) -> {readiness:?}"); + readiness.event.remove(Ready::WRITABLE); + } + + if backend.readiness().filter_interest().is_writable() { + let position = backend.position_mut(); + match position { + Position::Client(BackendStatus::Connecting(cluster_id)) => { + *position = Position::Client(BackendStatus::Connected( + std::mem::take(cluster_id), + )); + backend + .timeout_container() + .set_duration(self.router.configured_backend_timeout); + } + _ => {} + } + match backend.writable(context, EndpointServer(&mut self.frontend)) { + MuxResult::Continue => {} + MuxResult::Upgrade => unreachable!(), // only frontend can upgrade + 
MuxResult::CloseSession => return SessionResult::Close, + } + } + + if backend.readiness().filter_interest().is_readable() { + match backend.readable(context, EndpointServer(&mut self.frontend)) { + MuxResult::Continue => {} + MuxResult::Upgrade => unreachable!(), // only frontend can upgrade + MuxResult::CloseSession => return SessionResult::Close, + } + } + + if dead && !backend.readiness().filter_interest().is_readable() { + println_!("Closing {:#?}", backend); + backend.close(context, EndpointServer(&mut self.frontend)); + dead_backends.push(*token); + } + + if !backend.readiness().filter_interest().is_empty() { + all_backends_readiness_are_empty = false; + } + } + if !dead_backends.is_empty() { + for token in &dead_backends { + let proxy_borrow = proxy.borrow(); + if let Some(mut backend) = self.router.backends.remove(token) { + backend.timeout_container().cancel(); + let socket = backend.socket_mut(); + if let Err(e) = proxy_borrow.deregister_socket(socket) { + error!("error deregistering back socket({:?}): {:?}", socket, e); + } + if let Err(e) = socket.shutdown(Shutdown::Both) { + if e.kind() != ErrorKind::NotConnected { + error!( + "error shutting down back socket({:?}): {:?}", + socket, e + ); + } + } + } else { + error!("session {:?} has no backend!", token); + } + if !proxy_borrow.remove_session(*token) { + error!("session {:?} was already removed!", token); + } else { + println!("SUCCESS: session {token:?} was removed!"); + } + } + println_!("FRONTEND: {:#?}", self.frontend); + println_!("BACKENDS: {:#?}", self.router.backends); + } + + let context = &mut self.context; + if self.frontend.readiness().filter_interest().is_writable() { + match self + .frontend + .writable(context, EndpointClient(&mut self.router)) + { + MuxResult::Continue => {} + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Upgrade => return SessionResult::Upgrade, + } + } + + if self.frontend.readiness().filter_interest().is_empty() + && 
all_backends_readiness_are_empty + { + break; + } + + counter += 1; + if counter >= max_loop_iterations { + incr!("http.infinite_loop.error"); + return SessionResult::Close; + } + } + + let context = &mut self.context; + let mut dirty = false; + for stream_id in 0..context.streams.len() { + if context.streams[stream_id].state == StreamState::Link { + // Before the first request triggers a stream Link, the frontend timeout is set + // to a shorter request_timeout, here we switch to the longer nominal timeout + self.frontend + .timeout_container() + .set_duration(self.configured_frontend_timeout); + let front_readiness = self.frontend.readiness_mut(); + dirty = true; + match self.router.connect( + stream_id, + context, + session.clone(), + proxy.clone(), + metrics, + ) { + Ok(_) => {} + Err(error) => { + println_!("Connection error: {error}"); + let stream = &mut context.streams[stream_id]; + use BackendConnectionError as BE; + match error { + BE::Backend(BackendError::NoBackendForCluster(_)) + | BE::MaxConnectionRetries(_) + | BE::MaxSessionsMemory + | BE::MaxBuffers => { + set_default_answer(stream, front_readiness, 503); + } + BE::RetrieveClusterError( + RetrieveClusterError::RetrieveFrontend(_), + ) => { + set_default_answer(stream, front_readiness, 404); + } + BE::RetrieveClusterError( + RetrieveClusterError::UnauthorizedRoute, + ) => { + set_default_answer(stream, front_readiness, 401); + } + BE::RetrieveClusterError(RetrieveClusterError::HttpsRedirect) => { + set_default_answer(stream, front_readiness, 301); + } + + BE::Backend(_) => {} + BE::RetrieveClusterError(_) => unreachable!(), + BE::NotFound(_) => unreachable!(), + } + } + } + } + } + if !dirty { + break; + } + } + + SessionResult::Continue + } + + fn update_readiness(&mut self, token: Token, events: Ready) { + if token == self.frontend_token { + self.frontend.readiness_mut().event |= events; + } else if let Some(c) = self.router.backends.get_mut(&token) { + c.readiness_mut().event |= events; + } + } + 
+ fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult { + println_!("MuxState::timeout({token:?})"); + let front_is_h2 = match self.frontend { + Connection::H1(_) => false, + Connection::H2(_) => true, + }; + let mut should_close = true; + let mut should_write = false; + if self.frontend_token == token { + println_!("MuxState::timeout_frontend({:#?})", self.frontend); + self.frontend.timeout_container().triggered(); + let front_readiness = self.frontend.readiness_mut(); + for stream in &mut self.context.streams { + match stream.state { + StreamState::Idle => { + // In h1 an Idle stream is always the first request, so we can send a 408 + // In h2 an Idle stream doesn't necessarily hold a request yet, + // in most cases it was just reserved, so we can just ignore them. + if !front_is_h2 { + set_default_answer(stream, front_readiness, 408); + should_write = true; + } + } + StreamState::Link => { + // This is an unusual case, as we have both a complete request and no + // available backend yet. For now, we answer with 503 + set_default_answer(stream, front_readiness, 503); + should_write = true; + } + StreamState::Linked(_) => { + // A stream Linked to a backend is waiting for the response, not the request. + // For streaming or malformed requests, it is possible that the request is not + // terminated at this point. For now, we do nothing + should_close = false; + } + StreamState::Unlinked => { + // A stream Unlinked already has a response and its backend closed. + // In case it hasn't finished proxying we wait. Otherwise it is a stream + // kept alive for a new request, which can be killed. + if !stream.back.is_completed() { + should_close = false; + } + } + StreamState::Recycle => { + // A recycled stream is an h2 stream which doesn't hold a request anymore. + // We can ignore it. 
+ } + } + } + } else if let Some(backend) = self.router.backends.get_mut(&token) { + println_!("MuxState::timeout_backend({:#?})", backend); + backend.timeout_container().triggered(); + let front_readiness = self.frontend.readiness_mut(); + for stream_id in 0..self.context.streams.len() { + let stream = &mut self.context.streams[stream_id]; + if let StreamState::Linked(back_token) = stream.state { + if token == back_token { + // This stream is linked to the backend that timedout. + if stream.back.is_terminated() || stream.back.is_error() { + println!( + "Stream terminated or in error, do nothing, just wait a bit more" + ); + // Nothing to do, simply wait for the remaining bytes to be proxied + if !stream.back.is_completed() { + should_close = false; + } + } else if stream.back.is_initial() { + // The response has not started yet + println!("Stream still waiting for response, send 504"); + set_default_answer(stream, front_readiness, 504); + should_write = true; + } else { + println!("Stream waiting for end of response, forcefully terminate it"); + forcefully_terminate_answer(stream, front_readiness); + should_write = true; + } + backend.end_stream(stream_id, &mut self.context); + backend.force_disconnect(); + } + } + } + } + if should_write { + return match self + .frontend + .writable(&mut self.context, EndpointClient(&mut self.router)) + { + MuxResult::Continue => StateResult::Continue, + MuxResult::Upgrade => StateResult::Upgrade, + MuxResult::CloseSession => StateResult::CloseSession, + }; + } + if should_close { + StateResult::CloseSession + } else { + StateResult::Continue + } + } + + fn cancel_timeouts(&mut self) { + println_!("MuxState::cancel_timeouts"); + self.frontend.timeout_container().cancel(); + for backend in self.router.backends.values_mut() { + backend.timeout_container().cancel(); + } + } + + fn print_state(&self, context: &str) { + error!( + "\ +{} Session(Mux) +\tFrontend: +\t\ttoken: {:?}\treadiness: {:?} +\tBackend(s):", + context, + 
self.frontend_token, + self.frontend.readiness() + ); + for (backend_token, backend) in &self.router.backends { + error!( + "\t\ttoken: {:?}\treadiness: {:?}", + backend_token, + backend.readiness() + ) + } + } + + fn close(&mut self, proxy: Rc>, _metrics: &mut SessionMetrics) { + println_!("FRONTEND: {:#?}", self.frontend); + println_!("BACKENDS: {:#?}", self.router.backends); + + for (token, backend) in &mut self.router.backends { + let proxy_borrow = proxy.borrow(); + backend.timeout_container().cancel(); + let socket = backend.socket_mut(); + if let Err(e) = proxy_borrow.deregister_socket(socket) { + error!("error deregistering back socket({:?}): {:?}", socket, e); + } + if let Err(e) = socket.shutdown(Shutdown::Both) { + if e.kind() != ErrorKind::NotConnected { + error!("error shutting down back socket({:?}): {:?}", socket, e); + } + } + if !proxy_borrow.remove_session(*token) { + error!("session {:?} was already removed!", token); + } else { + println!("SUCCESS: session {token:?} was removed!"); + } + } + // let s = match &mut self.frontend { + // Connection::H1(c) => &mut c.socket, + // Connection::H2(c) => &mut c.socket, + // }; + // let mut b = [0; 1024]; + // let (size, status) = s.socket_read(&mut b); + // println_!("{size} {status:?} {:?}", &b[..size]); + // for stream in &mut self.context.streams { + // for kawa in [&mut stream.front, &mut stream.back] { + // debug_kawa(kawa); + // kawa.prepare(&mut kawa::h1::BlockConverter); + // let out = kawa.as_io_slice(); + // let mut writer = std::io::BufWriter::new(Vec::new()); + // let amount = writer.write_vectored(&out).unwrap(); + // println_!( + // "amount: {amount}\n{}", + // String::from_utf8_lossy(writer.buffer()) + // ); + // } + // } + } + + fn shutting_down(&mut self) -> SessionIsToBeClosed { + let mut can_stop = true; + for stream in &mut self.context.streams { + match stream.state { + StreamState::Linked(_) => { + can_stop = false; + } + StreamState::Unlinked => { + let front = &stream.front; + let 
back = &stream.back; + kawa::debug_kawa(front); + kawa::debug_kawa(back); + if front.is_initial() + && front.storage.is_empty() + && back.is_initial() + && back.storage.is_empty() + { + continue; + } + stream.context.closing = true; + can_stop = false; + } + _ => {} + } + } + can_stop + } +} diff --git a/lib/src/protocol/mux/parser.rs b/lib/src/protocol/mux/parser.rs new file mode 100644 index 000000000..b0285bc06 --- /dev/null +++ b/lib/src/protocol/mux/parser.rs @@ -0,0 +1,604 @@ +use std::convert::From; + +use kawa::repr::Slice; +use nom::{ + bytes::complete::{tag, take}, + combinator::{complete, map, map_opt}, + error::{ErrorKind, ParseError}, + multi::many0, + number::complete::{be_u16, be_u24, be_u32, be_u8}, + sequence::tuple, + Err, IResult, +}; + +#[derive(Clone, Debug, PartialEq)] +pub struct FrameHeader { + pub payload_len: u32, + pub frame_type: FrameType, + pub flags: u8, + pub stream_id: u32, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum FrameType { + Data, + Headers, + Priority, + RstStream, + Settings, + PushPromise, + Ping, + GoAway, + WindowUpdate, + Continuation, +} + +const NO_ERROR: u32 = 0x0; +const PROTOCOL_ERROR: u32 = 0x1; +const INTERNAL_ERROR: u32 = 0x2; +const FLOW_CONTROL_ERROR: u32 = 0x3; +const SETTINGS_TIMEOUT: u32 = 0x4; +const STREAM_CLOSED: u32 = 0x5; +const FRAME_SIZE_ERROR: u32 = 0x6; +const REFUSED_STREAM: u32 = 0x7; +const CANCEL: u32 = 0x8; +const COMPRESSION_ERROR: u32 = 0x9; +const CONNECT_ERROR: u32 = 0xa; +const ENHANCE_YOUR_CALM: u32 = 0xb; +const INADEQUATE_SECURITY: u32 = 0xc; +const HTTP_1_1_REQUIRED: u32 = 0xd; + +pub fn error_code_to_str(error_code: u32) -> &'static str { + match error_code { + NO_ERROR => "NO_ERROR", + PROTOCOL_ERROR => "PROTOCOL_ERROR", + INTERNAL_ERROR => "INTERNAL_ERROR", + FLOW_CONTROL_ERROR => "FLOW_CONTROL_ERROR", + SETTINGS_TIMEOUT => "SETTINGS_TIMEOUT", + STREAM_CLOSED => "STREAM_CLOSED", + FRAME_SIZE_ERROR => "FRAME_SIZE_ERROR", + REFUSED_STREAM => "REFUSED_STREAM", + CANCEL => 
"CANCEL", + COMPRESSION_ERROR => "COMPRESSION_ERROR", + CONNECT_ERROR => "CONNECT_ERROR", + ENHANCE_YOUR_CALM => "ENHANCE_YOUR_CALM", + INADEQUATE_SECURITY => "INADEQUATE_SECURITY", + HTTP_1_1_REQUIRED => "HTTP_1_1_REQUIRED", + _ => "UNKNOWN_ERROR", + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Error<'a> { + pub input: &'a [u8], + pub error: InnerError, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum InnerError { + Nom(ErrorKind), + H2(H2Error), +} + +#[derive(Clone, Debug, PartialEq)] +pub enum H2Error { + NoError, + ProtocolError, + InternalError, + FlowControlError, + SettingsTimeout, + StreamClosed, + FrameSizeError, + RefusedStream, + Cancel, + CompressionError, + ConnectError, + EnhanceYourCalm, + InadequateSecurity, + HTTP11Required, +} + +impl<'a> Error<'a> { + pub fn new(input: &'a [u8], error: InnerError) -> Error<'a> { + Error { input, error } + } + pub fn new_h2(input: &'a [u8], error: H2Error) -> Error<'a> { + Error { + input, + error: InnerError::H2(error), + } + } +} + +impl<'a> ParseError<&'a [u8]> for Error<'a> { + fn from_error_kind(input: &'a [u8], kind: ErrorKind) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } + + fn append(input: &'a [u8], kind: ErrorKind, other: Self) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } +} + +impl<'a> From<(&'a [u8], ErrorKind)> for Error<'a> { + fn from((input, kind): (&'a [u8], ErrorKind)) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } +} + +pub fn preface(i: &[u8]) -> IResult<&[u8], &[u8]> { + tag(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")(i) +} + +// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 +/*named!(pub frame_header, + do_parse!( + payload_len: dbg_dmp!(be_u24) >> + frame_type: map_opt!(be_u8, convert_frame_type) >> + flags: dbg_dmp!(be_u8) >> + stream_id: dbg_dmp!(verify!(be_u32, |id| { + match frame_type { + + } + }) >> + (FrameHeader { payload_len, frame_type, flags, stream_id }) + ) +); + */ + +pub fn frame_header(input: 
&[u8]) -> IResult<&[u8], FrameHeader, Error> { + let (i, payload_len) = be_u24(input)?; + let (i, frame_type) = map_opt(be_u8, convert_frame_type)(i)?; + let (i, flags) = be_u8(i)?; + let (i, stream_id) = be_u32(i)?; + + Ok(( + i, + FrameHeader { + payload_len, + frame_type, + flags, + stream_id, + }, + )) +} + +fn convert_frame_type(t: u8) -> Option { + info!("got frame type: {}", t); + match t { + 0 => Some(FrameType::Data), + 1 => Some(FrameType::Headers), + 2 => Some(FrameType::Priority), + 3 => Some(FrameType::RstStream), + 4 => Some(FrameType::Settings), + 5 => Some(FrameType::PushPromise), + 6 => Some(FrameType::Ping), + 7 => Some(FrameType::GoAway), + 8 => Some(FrameType::WindowUpdate), + 9 => Some(FrameType::Continuation), + _ => None, + } +} + +#[derive(Clone, Debug)] +pub enum Frame { + Data(Data), + Headers(Headers), + Priority(Priority), + RstStream(RstStream), + Settings(Settings), + PushPromise(PushPromise), + Ping(Ping), + GoAway(GoAway), + WindowUpdate(WindowUpdate), + Continuation(Continuation), +} + +impl Frame { + pub fn is_stream_specific(&self) -> bool { + match self { + Frame::Data(_) + | Frame::Headers(_) + | Frame::Priority(_) + | Frame::RstStream(_) + | Frame::PushPromise(_) + | Frame::Continuation(_) => true, + Frame::Settings(_) | Frame::Ping(_) | Frame::GoAway(_) => false, + Frame::WindowUpdate(w) => w.stream_id != 0, + } + } + + pub fn stream_id(&self) -> u32 { + match self { + Frame::Data(d) => d.stream_id, + Frame::Headers(h) => h.stream_id, + Frame::Priority(p) => p.stream_id, + Frame::RstStream(r) => r.stream_id, + Frame::PushPromise(p) => p.stream_id, + Frame::Continuation(c) => c.stream_id, + Frame::Settings(_) | Frame::Ping(_) | Frame::GoAway(_) => 0, + Frame::WindowUpdate(w) => w.stream_id, + } + } +} + +pub fn frame_body<'a>( + i: &'a [u8], + header: &FrameHeader, + max_frame_size: u32, +) -> IResult<&'a [u8], Frame, Error<'a>> { + if header.payload_len > max_frame_size { + return Err(Err::Failure(Error::new_h2(i, 
H2Error::FrameSizeError))); + } + + let valid_stream_id = match header.frame_type { + FrameType::Data + | FrameType::Headers + | FrameType::Priority + | FrameType::RstStream + | FrameType::PushPromise + | FrameType::Continuation => header.stream_id != 0, + FrameType::Settings | FrameType::Ping | FrameType::GoAway => header.stream_id == 0, + FrameType::WindowUpdate => true, + }; + + if !valid_stream_id { + return Err(Err::Failure(Error::new_h2(i, H2Error::ProtocolError))); + } + + let f = match header.frame_type { + FrameType::Data => data_frame(i, header)?, + FrameType::Headers => headers_frame(i, header)?, + FrameType::Priority => { + if header.payload_len != 5 { + return Err(Err::Failure(Error::new_h2(i, H2Error::FrameSizeError))); + } + priority_frame(i, header)? + } + FrameType::RstStream => { + if header.payload_len != 4 { + return Err(Err::Failure(Error::new_h2(i, H2Error::FrameSizeError))); + } + rst_stream_frame(i, header)? + } + FrameType::PushPromise => push_promise_frame(i, header)?, + FrameType::Continuation => continuation_frame(i, header)?, + FrameType::Settings => { + if header.payload_len % 6 != 0 { + return Err(Err::Failure(Error::new_h2(i, H2Error::FrameSizeError))); + } + settings_frame(i, header)? + } + FrameType::Ping => { + if header.payload_len != 8 { + return Err(Err::Failure(Error::new_h2(i, H2Error::FrameSizeError))); + } + ping_frame(i, header)? + } + FrameType::GoAway => goaway_frame(i, header)?, + FrameType::WindowUpdate => { + if header.payload_len != 4 { + return Err(Err::Failure(Error::new_h2(i, H2Error::FrameSizeError))); + } + window_update_frame(i, header)? 
+ } + }; + + Ok(f) +} + +#[derive(Clone, Debug)] +pub struct Data { + pub stream_id: u32, + pub payload: Slice, + pub end_stream: bool, +} + +pub fn data_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + + let (i, pad_length) = if header.flags & 0x8 != 0 { + let (i, pad_length) = be_u8(i)?; + (i, Some(pad_length)) + } else { + (i, None) + }; + + if pad_length.is_some() && i.len() <= pad_length.unwrap() as usize { + return Err(Err::Failure(Error::new_h2(input, H2Error::ProtocolError))); + } + + let (_, payload) = take(i.len() - pad_length.unwrap_or(0) as usize)(i)?; + + Ok(( + remaining, + Frame::Data(Data { + stream_id: header.stream_id, + payload: Slice::new(input, payload), + end_stream: header.flags & 0x1 != 0, + }), + )) +} + +#[derive(Clone, Debug)] +pub struct Headers { + pub stream_id: u32, + pub stream_dependency: Option, + pub weight: Option, + pub header_block_fragment: Slice, + // pub header_block_fragment: &'a [u8], + pub end_stream: bool, + pub end_headers: bool, + pub priority: bool, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct StreamDependency { + pub exclusive: bool, + pub stream_id: u32, +} + +fn stream_dependency(i: &[u8]) -> IResult<&[u8], StreamDependency, Error<'_>> { + let (i, stream) = map(be_u32, |i| StreamDependency { + exclusive: i & 0x8000 != 0, + stream_id: i & 0x7FFFFFFF, + })(i)?; + Ok((i, stream)) +} + +pub fn headers_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + + let (i, pad_length) = if header.flags & 0x8 != 0 { + let (i, pad_length) = be_u8(i)?; + (i, Some(pad_length)) + } else { + (i, None) + }; + + let (i, stream_dependency, weight) = if header.flags & 0x20 != 0 { + let (i, stream_dependency) = stream_dependency(i)?; + let (i, weight) = be_u8(i)?; + (i, Some(stream_dependency), Some(weight)) + } else { + 
(i, None, None) + }; + + if pad_length.is_some() && i.len() <= pad_length.unwrap() as usize { + return Err(Err::Failure(Error::new_h2(input, H2Error::ProtocolError))); + } + + let (_, header_block_fragment) = take(i.len() - pad_length.unwrap_or(0) as usize)(i)?; + + Ok(( + remaining, + Frame::Headers(Headers { + stream_id: header.stream_id, + stream_dependency, + weight, + header_block_fragment: Slice::new(input, header_block_fragment), + end_stream: header.flags & 0x1 != 0, + end_headers: header.flags & 0x4 != 0, + priority: header.flags & 0x20 != 0, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Priority { + pub stream_id: u32, + pub stream_dependency: StreamDependency, + pub weight: u8, +} + +pub fn priority_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, stream_dependency) = stream_dependency(input)?; + let (i, weight) = be_u8(i)?; + Ok(( + i, + Frame::Priority(Priority { + stream_dependency, + stream_id: header.stream_id, + weight, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct RstStream { + pub stream_id: u32, + pub error_code: u32, +} + +pub fn rst_stream_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, error_code) = be_u32(input)?; + Ok(( + i, + Frame::RstStream(RstStream { + stream_id: header.stream_id, + error_code, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Settings { + pub settings: Vec, + pub ack: bool, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Setting { + pub identifier: u16, + pub value: u32, +} + +pub fn settings_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, data) = take(header.payload_len)(input)?; + + let (_, settings) = many0(map( + complete(tuple((be_u16, be_u32))), + |(identifier, value)| Setting { identifier, value }, + ))(data)?; + + Ok(( + i, + Frame::Settings(Settings { + settings, + ack: 
header.flags & 0x1 != 0, + }), + )) +} + +#[derive(Clone, Debug)] +pub struct PushPromise { + pub stream_id: u32, + pub promised_stream_id: u32, + pub header_block_fragment: Slice, + pub end_headers: bool, +} + +pub fn push_promise_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + + let (i, pad_length) = if header.flags & 0x8 != 0 { + let (i, pad_length) = be_u8(i)?; + (i, Some(pad_length)) + } else { + (i, None) + }; + + if pad_length.is_some() && i.len() <= pad_length.unwrap() as usize { + return Err(Err::Failure(Error::new_h2(input, H2Error::ProtocolError))); + } + + let (i, promised_stream_id) = be_u32(i)?; + let (_, header_block_fragment) = take(i.len() - pad_length.unwrap_or(0) as usize)(i)?; + + Ok(( + remaining, + Frame::PushPromise(PushPromise { + stream_id: header.stream_id, + promised_stream_id, + header_block_fragment: Slice::new(input, header_block_fragment), + end_headers: header.flags & 0x4 != 0, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Ping { + pub payload: [u8; 8], +} + +pub fn ping_frame<'a>( + input: &'a [u8], + _header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, data) = take(8usize)(input)?; + + let mut p = Ping { payload: [0; 8] }; + p.payload[..8].copy_from_slice(&data[..8]); + + Ok((i, Frame::Ping(p))) +} + +#[derive(Clone, Debug)] +pub struct GoAway { + pub last_stream_id: u32, + pub error_code: u32, + pub additional_debug_data: Slice, +} + +pub fn goaway_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + let (i, last_stream_id) = be_u32(i)?; + let (additional_debug_data, error_code) = be_u32(i)?; + Ok(( + remaining, + Frame::GoAway(GoAway { + last_stream_id, + error_code, + additional_debug_data: Slice::new(input, additional_debug_data), + }), + )) +} + +#[derive(Clone, Debug, 
PartialEq)] +pub struct WindowUpdate { + pub stream_id: u32, + pub increment: u32, +} + +pub fn window_update_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, increment) = be_u32(input)?; + let increment = increment & 0x7FFFFFFF; + + //FIXME: if stream id is 0, treat it as connection error? + if increment == 0 { + return Err(Err::Failure(Error::new_h2(input, H2Error::ProtocolError))); + } + + Ok(( + i, + Frame::WindowUpdate(WindowUpdate { + stream_id: header.stream_id, + increment, + }), + )) +} + +#[derive(Clone, Debug)] +pub struct Continuation { + pub stream_id: u32, + pub header_block_fragment: Slice, + pub end_headers: bool, +} + +pub fn continuation_frame<'a>( + input: &'a [u8], + header: &FrameHeader, +) -> IResult<&'a [u8], Frame, Error<'a>> { + let (i, header_block_fragment) = take(header.payload_len)(input)?; + Ok(( + i, + Frame::Continuation(Continuation { + stream_id: header.stream_id, + header_block_fragment: Slice::new(input, header_block_fragment), + end_headers: header.flags & 0x4 != 0, + }), + )) +} diff --git a/lib/src/protocol/mux/pkawa.rs b/lib/src/protocol/mux/pkawa.rs new file mode 100644 index 000000000..d00a430ef --- /dev/null +++ b/lib/src/protocol/mux/pkawa.rs @@ -0,0 +1,192 @@ +use std::{io::Write, str::from_utf8_unchecked}; + +use kawa::{ + h1::ParserCallbacks, repr::Slice, Block, BodySize, Flags, Kind, Pair, ParsingPhase, StatusLine, + Store, Version, +}; + +use crate::{pool::Checkout, protocol::http::parser::compare_no_case}; + +use super::GenericHttpStream; + +pub fn handle_header( + kawa: &mut GenericHttpStream, + input: &[u8], + end_stream: bool, + decoder: &mut hpack::Decoder, + callbacks: &mut C, +) where + C: ParserCallbacks, +{ + if !kawa.is_initial() { + return handle_trailer(kawa, input, end_stream, decoder); + } + kawa.push_block(Block::StatusLine); + kawa.detached.status_line = match kawa.kind { + Kind::Request => { + let mut method = Store::Empty; + let mut authority 
= Store::Empty; + let mut path = Store::Empty; + let mut scheme = Store::Empty; + decoder + .decode_with_cb(input, |k, v| { + let start = kawa.storage.end as u32; + kawa.storage.write_all(&v).unwrap(); + let len_key = k.len() as u32; + let len_val = v.len() as u32; + let val = Store::Slice(Slice { + start, + len: len_val, + }); + + if compare_no_case(&k, b":method") { + method = val; + } else if compare_no_case(&k, b":authority") { + authority = val; + } else if compare_no_case(&k, b":path") { + path = val; + } else if compare_no_case(&k, b":scheme") { + scheme = val; + } else if compare_no_case(&k, b"cookie") { + todo!("cookies should be split in pairs"); + } else if compare_no_case(&k, b"priority") { + unimplemented!(); + } else { + if compare_no_case(&k, b"content-length") { + let length = + unsafe { from_utf8_unchecked(&v).parse::().unwrap() }; + kawa.body_size = BodySize::Length(length); + } + kawa.storage.write_all(&k).unwrap(); + let key = Store::Slice(Slice { + start: start + len_val, + len: len_key, + }); + kawa.push_block(Block::Header(Pair { key, val })); + } + }) + .unwrap(); + // uri is only used by H1 statusline, in most cases it only consists of the path + // a better algorithm should be used though + // let buffer = kawa.storage.data(); + // let uri = unsafe { + // format!( + // "{}://{}{}", + // from_utf8_unchecked(scheme.data(buffer)), + // from_utf8_unchecked(authority.data(buffer)), + // from_utf8_unchecked(path.data(buffer)) + // ) + // }; + // println!("Reconstructed URI: {uri}"); + StatusLine::Request { + version: Version::V20, + method, + uri: path.clone(), //Store::from_string(uri), + authority, + path, + } + } + Kind::Response => { + let mut code = 0; + let mut status = Store::Empty; + decoder + .decode_with_cb(input, |k, v| { + let start = kawa.storage.end as u32; + kawa.storage.write_all(&v).unwrap(); + let len_key = k.len() as u32; + let len_val = v.len() as u32; + let val = Store::Slice(Slice { + start, + len: len_val, + }); + + if 
compare_no_case(&k, b":status") { + status = val; + code = unsafe { from_utf8_unchecked(&v).parse::().ok().unwrap() } + } else { + kawa.storage.write_all(&k).unwrap(); + let key = Store::Slice(Slice { + start: start + len_val, + len: len_key, + }); + kawa.push_block(Block::Header(Pair { key, val })); + } + }) + .unwrap(); + StatusLine::Response { + version: Version::V20, + code, + status, + reason: Store::Static(b"FromH2"), + } + } + }; + + // everything has been parsed + kawa.storage.head = kawa.storage.end; + + callbacks.on_headers(kawa); + + if end_stream { + if let BodySize::Empty = kawa.body_size { + kawa.body_size = BodySize::Length(0); + kawa.push_block(Block::Header(Pair { + key: Store::Static(b"Content-Length"), + val: Store::Static(b"0"), + })); + } + } + + kawa.push_block(Block::Flags(Flags { + end_body: end_stream, + end_chunk: false, + end_header: true, + end_stream, + })); + + if kawa.parsing_phase == ParsingPhase::Terminated { + return; + } + + kawa.parsing_phase = match kawa.body_size { + BodySize::Chunked => ParsingPhase::Chunks { first: true }, + BodySize::Length(0) => ParsingPhase::Terminated, + BodySize::Length(_) => ParsingPhase::Body, + BodySize::Empty => ParsingPhase::Chunks { first: true }, + }; +} + +pub fn handle_trailer( + kawa: &mut GenericHttpStream, + input: &[u8], + end_stream: bool, + decoder: &mut hpack::Decoder, +) { + decoder + .decode_with_cb(input, |k, v| { + let start = kawa.storage.end as u32; + kawa.storage.write_all(&k).unwrap(); + kawa.storage.write_all(&v).unwrap(); + let len_key = k.len() as u32; + let len_val = v.len() as u32; + let key = Store::Slice(Slice { + start, + len: len_key, + }); + let val = Store::Slice(Slice { + start: start + len_key, + len: len_val, + }); + kawa.push_block(Block::Header(Pair { key, val })); + }) + .unwrap(); + + assert!(end_stream); + kawa.push_block(Block::Flags(Flags { + end_body: end_stream, + end_chunk: false, + end_header: true, + end_stream, + })); + kawa.parsing_phase = 
ParsingPhase::Terminated; +} diff --git a/lib/src/protocol/mux/serializer.rs b/lib/src/protocol/mux/serializer.rs new file mode 100644 index 000000000..fa94ead4d --- /dev/null +++ b/lib/src/protocol/mux/serializer.rs @@ -0,0 +1,139 @@ +use cookie_factory::{ + bytes::{be_u16, be_u24, be_u32, be_u8}, + combinator::slice, + gen, + sequence::tuple, + GenError, +}; + +use super::{ + h2::H2Settings, + parser::{FrameHeader, FrameType, H2Error}, +}; + +pub const H2_PRI: &str = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; +pub const SETTINGS_ACKNOWLEDGEMENT: [u8; 9] = [0, 0, 0, 4, 1, 0, 0, 0, 0]; +pub const PING_ACKNOWLEDGEMENT_HEADER: [u8; 9] = [0, 0, 0, 6, 1, 0, 0, 0, 0]; + +pub fn gen_frame_header<'a, 'b>( + buf: &'a mut [u8], + frame: &'b FrameHeader, +) -> Result<(&'a mut [u8], usize), GenError> { + let serializer = tuple(( + be_u24(frame.payload_len), + be_u8(serialize_frame_type(&frame.frame_type)), + be_u8(frame.flags), + be_u32(frame.stream_id), + )); + + gen(serializer, buf).map(|(buf, size)| (buf, size as usize)) +} + +pub fn serialize_frame_type(f: &FrameType) -> u8 { + match *f { + FrameType::Data => 0, + FrameType::Headers => 1, + FrameType::Priority => 2, + FrameType::RstStream => 3, + FrameType::Settings => 4, + FrameType::PushPromise => 5, + FrameType::Ping => 6, + FrameType::GoAway => 7, + FrameType::WindowUpdate => 8, + FrameType::Continuation => 9, + } +} + +// pub fn gen_settings_acknoledgement<'a>(buf: &'a mut [u8]) { +// for (i, b) in SETTINGS_ACKNOWLEDGEMENT.iter().enumerate() { +// buf[i] = *b; +// } +// } + +pub fn gen_ping_acknolegment<'a>( + buf: &'a mut [u8], + payload: &[u8], +) -> Result<(&'a mut [u8], usize), GenError> { + gen( + tuple((slice(PING_ACKNOWLEDGEMENT_HEADER), slice(payload))), + buf, + ) + .map(|(buf, size)| (buf, size as usize)) +} + +pub fn gen_settings<'a>( + buf: &'a mut [u8], + settings: &H2Settings, +) -> Result<(&'a mut [u8], usize), GenError> { + gen_frame_header( + buf, + &FrameHeader { + payload_len: 6 * 6, + frame_type: 
FrameType::Settings, + flags: 0, + stream_id: 0, + }, + ) + .and_then(|(buf, old_size)| { + gen( + tuple(( + be_u16(1), + be_u32(settings.settings_header_table_size), + be_u16(2), + be_u32(settings.settings_enable_push as u32), + be_u16(3), + be_u32(settings.settings_max_concurrent_streams), + be_u16(4), + be_u32(settings.settings_initial_window_size), + be_u16(5), + be_u32(settings.settings_max_frame_size), + be_u16(6), + be_u32(settings.settings_max_header_list_size), + )), + buf, + ) + .map(|(buf, size)| (buf, (old_size + size as usize))) + }) +} + +pub fn gen_rst_stream<'a>( + buf: &'a mut [u8], + stream_id: u32, + error_code: H2Error, +) -> Result<(&'a mut [u8], usize), GenError> { + gen_frame_header( + buf, + &FrameHeader { + payload_len: 4, + frame_type: FrameType::RstStream, + flags: 0, + stream_id, + }, + ) + .and_then(|(buf, old_size)| { + gen(be_u32(error_code as u32), buf).map(|(buf, size)| (buf, (old_size + size as usize))) + }) +} + +pub fn gen_goaway<'a>( + buf: &'a mut [u8], + last_stream_id: u32, + error_code: H2Error, +) -> Result<(&'a mut [u8], usize), GenError> { + gen_frame_header( + buf, + &FrameHeader { + payload_len: 4, + frame_type: FrameType::GoAway, + flags: 0, + stream_id: 0, + }, + ) + .and_then(|(buf, old_size)| { + gen( + tuple((be_u32(last_stream_id), be_u32(error_code as u32))), + buf, + ) + .map(|(buf, size)| (buf, (old_size + size as usize))) + }) +} diff --git a/lib/src/router/mod.rs b/lib/src/router/mod.rs index 1001c72a8..9a8fd594a 100644 --- a/lib/src/router/mod.rs +++ b/lib/src/router/mod.rs @@ -134,7 +134,10 @@ impl Router { let method_rule = MethodRule::new(front.method.clone()); let route = match &front.cluster_id { - Some(cluster_id) => Route::ClusterId(cluster_id.clone()), + Some(cluster_id) => Route::Cluster { + id: cluster_id.clone(), + h2: front.h2, + }, None => Route::Deny, }; @@ -162,7 +165,7 @@ impl Router { } }; if !success { - return Err(RouterError::AddRoute(format!("{:?}", front))); + return 
Err(RouterError::AddRoute(format!("{front:?}"))); } Ok(()) } @@ -197,7 +200,7 @@ impl Router { } }; if !remove_success { - return Err(RouterError::RemoveRoute(format!("{:?}", front))); + return Err(RouterError::RemoveRoute(format!("{front:?}"))); } Ok(()) } @@ -598,7 +601,7 @@ pub enum Route { /// send a 401 default answer Deny, /// the cluster to which the frontend belongs - ClusterId(ClusterId), + Cluster { id: ClusterId, h2: bool }, } #[cfg(test)] @@ -717,27 +720,42 @@ mod tests { b"*.sozu.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Route::Cluster { + id: "base".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert_eq!( router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + Ok(Route::Cluster { + id: "base".to_string(), + h2: false + }) ); assert!(router.add_tree_rule( b"*.sozu.io", &PathRule::Prefix("/api".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("api".to_string()) + &Route::Cluster { + id: "api".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert_eq!( router.lookup("www.sozu.io", "/ap", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + Ok(Route::Cluster { + id: "base".to_string(), + h2: false + }) ); assert_eq!( router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("api".to_string())) + Ok(Route::Cluster { + id: "api".to_string(), + h2: false + }) ); } @@ -756,27 +774,42 @@ mod tests { b"*.sozu.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Route::Cluster { + id: "base".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert_eq!( router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + Ok(Route::Cluster { + id: "base".to_string(), + h2: false + }) ); assert!(router.add_tree_rule( b"api.sozu.io", 
&PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("api".to_string()) + &Route::Cluster { + id: "api".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert_eq!( router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + Ok(Route::Cluster { + id: "base".to_string(), + h2: false + }) ); assert_eq!( router.lookup("api.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("api".to_string())) + Ok(Route::Cluster { + id: "api".to_string(), + h2: false + }) ); } @@ -788,23 +821,35 @@ mod tests { b"www./.*/.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Route::Cluster { + id: "base".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert!(router.add_tree_rule( b"www.doc./.*/.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("doc".to_string()) + &Route::Cluster { + id: "doc".to_string(), + h2: false + } )); println!("{:#?}", router.tree); assert_eq!( router.lookup("www.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + Ok(Route::Cluster { + id: "base".to_string(), + h2: false + }) ); assert_eq!( router.lookup("www.doc.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("doc".to_string())) + Ok(Route::Cluster { + id: "doc".to_string(), + h2: false + }) ); assert!(router.remove_tree_rule( b"www./.*/.io", @@ -815,7 +860,10 @@ mod tests { assert!(router.lookup("www.sozu.io", "/", &Method::Get).is_err()); assert_eq!( router.lookup("www.doc.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("doc".to_string())) + Ok(Route::Cluster { + id: "doc".to_string(), + h2: false + }) ); } @@ -827,30 +875,45 @@ mod tests { &"*".parse::().unwrap(), &PathRule::Prefix("/.well-known/acme-challenge".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("acme".to_string()) + &Route::Cluster { + id: 
"acme".to_string(), + h2: false + } )); assert!(router.add_tree_rule( "www.example.com".as_bytes(), &PathRule::Prefix("/".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("example".to_string()) + &Route::Cluster { + id: "example".to_string(), + h2: false + } )); assert!(router.add_tree_rule( "*.test.example.com".as_bytes(), &PathRule::Regex(Regex::new("/hello[A-Z]+/").unwrap()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("examplewildcard".to_string()) + &Route::Cluster { + id: "examplewildcard".to_string(), + h2: false + } )); assert!(router.add_tree_rule( "/test[0-9]/.example.com".as_bytes(), &PathRule::Prefix("/".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("exampleregex".to_string()) + &Route::Cluster { + id: "exampleregex".to_string(), + h2: false + } )); assert_eq!( router.lookup("www.example.com", "/helloA", &Method::new(&b"GET"[..])), - Ok(Route::ClusterId("example".to_string())) + Ok(Route::Cluster { + id: "example".to_string(), + h2: false + }) ); assert_eq!( router.lookup( @@ -858,7 +921,10 @@ mod tests { "/.well-known/acme-challenge", &Method::new(&b"GET"[..]) ), - Ok(Route::ClusterId("acme".to_string())) + Ok(Route::Cluster { + id: "acme".to_string(), + h2: false + }) ); assert!(router .lookup("www.test.example.com", "/", &Method::new(&b"GET"[..])) @@ -869,11 +935,28 @@ mod tests { "/helloAB/", &Method::new(&b"GET"[..]) ), - Ok(Route::ClusterId("examplewildcard".to_string())) + Ok(Route::Cluster { + id: "examplewildcard".to_string(), + h2: false + }) + ); + assert_eq!( + router.lookup( + "www.test.example.com", + "/helloAB/", + &Method::new(&b"GET"[..]) + ), + Ok(Route::Cluster { + id: "examplewildcard".to_string(), + h2: false + }) ); assert_eq!( router.lookup("test1.example.com", "/helloAB/", &Method::new(&b"GET"[..])), - Ok(Route::ClusterId("exampleregex".to_string())) + Ok(Route::Cluster { + id: "exampleregex".to_string(), + h2: false + }) ); } } diff --git 
a/lib/src/server.rs b/lib/src/server.rs index 0461dbe41..0a9926e1a 100644 --- a/lib/src/server.rs +++ b/lib/src/server.rs @@ -1305,8 +1305,7 @@ impl Server { None => { let error = format!( - "Couldn't deactivate HTTPS listener at address {:?}", - address + "Couldn't deactivate HTTPS listener at address {address:?}", ); error!("{}", error); return WorkerResponse::error(req_id, error); @@ -1345,7 +1344,7 @@ impl Server { Some((token, listener)) => (token, listener), None => { let error = - format!("Couldn't deactivate TCP listener at address {:?}", address); + format!("Couldn't deactivate TCP listener at address {address:?}"); error!("{}", error); return WorkerResponse::error(req_id, error); } diff --git a/lib/src/socket.rs b/lib/src/socket.rs index 9440fb054..057798492 100644 --- a/lib/src/socket.rs +++ b/lib/src/socket.rs @@ -171,6 +171,12 @@ pub struct FrontRustls { pub session: ServerConnection, } +impl std::fmt::Debug for FrontRustls { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("FrontRustls") + } +} + impl SocketHandler for FrontRustls { fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) { let mut size = 0usize; @@ -190,7 +196,7 @@ impl SocketHandler for FrontRustls { break; } - if !can_read | is_error | is_closed { + if !can_read || is_error || is_closed { break; }