From d8a8f2e701e700ecb0b5d4725cedcc64933e30c5 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 3 Jun 2025 18:32:09 -0600 Subject: [PATCH 1/2] placeholder --- gel-pg-protocol/examples/connect.rs | 271 ++++++ gel-pg-protocol/src/conn/conn.rs | 953 ++++++++++++++++++++ gel-pg-protocol/src/conn/flow.rs | 1224 ++++++++++++++++++++++++++ gel-pg-protocol/src/conn/mod.rs | 94 ++ gel-pg-protocol/src/conn/queue.rs | 166 ++++ gel-pg-protocol/src/conn/raw_conn.rs | 232 +++++ 6 files changed, 2940 insertions(+) create mode 100644 gel-pg-protocol/examples/connect.rs create mode 100644 gel-pg-protocol/src/conn/conn.rs create mode 100644 gel-pg-protocol/src/conn/flow.rs create mode 100644 gel-pg-protocol/src/conn/mod.rs create mode 100644 gel-pg-protocol/src/conn/queue.rs create mode 100644 gel-pg-protocol/src/conn/raw_conn.rs diff --git a/gel-pg-protocol/examples/connect.rs b/gel-pg-protocol/examples/connect.rs new file mode 100644 index 00000000..c9f87305 --- /dev/null +++ b/gel-pg-protocol/examples/connect.rs @@ -0,0 +1,271 @@ +use captive_postgres::{ + setup_postgres, ListenAddress, Mode, DEFAULT_DATABASE, DEFAULT_PASSWORD, DEFAULT_USERNAME, +}; +use clap::Parser; +use clap_derive::Parser; +use gel_auth::AuthType; +use gel_dsn::postgres::*; +use gel_pg_protocol::protocol::*; +use gel_stream::{Connector, ResolvedTarget, Target}; +use pgrust::connection::{ + Client, Credentials, ExecuteSink, Format, MaxRows, PipelineBuilder, Portal, QuerySink, + Statement, +}; +use std::net::SocketAddr; +use tokio::task::LocalSet; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + /// Use an ephemeral database + #[clap(short = 'e', long = "ephemeral", conflicts_with_all = &["dsn", "unix", "tcp", "username", "password", "database"])] + ephemeral: bool, + + #[clap(short = 'D', long = "dsn", value_parser, conflicts_with_all = &["unix", "tcp", "username", "password", "database"])] + dsn: Option, + + /// Network socket address and port + #[clap(short = 't', long = "tcp", value_parser, conflicts_with = "unix")] + tcp: Option, + + /// Unix socket path + #[clap(short = 'u', long = "unix", value_parser, conflicts_with = "tcp")] + unix: Option, + + /// Username to use for the connection + #[clap( + short = 'U', + long = "username", + value_parser, + default_value = "postgres" + )] + username: String, + + /// Username to use for the connection + #[clap(short = 'P', long = "password", value_parser, default_value = "")] + password: String, + + /// Database to use for the connection + #[clap( + short = 'd', + long = "database", + value_parser, + default_value = "postgres" + )] + database: String, + + /// Use extended query syntax + #[clap(short = 'x', long = "extended")] + extended: bool, + + /// SQL statements to run + #[clap( + name = "statements", + trailing_var_arg = true, + allow_hyphen_values = true, + help = "Zero or more SQL statements to run (defaults to 'select 1')" + )] + statements: Option>, +} + +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + let mut args = Args::parse(); + eprintln!("{args:?}"); + + let mut socket_address: Option = None; + + let _ephemeral = if args.ephemeral { + let process = setup_postgres(AuthType::Trust, Mode::Unix)?; + let Some(process) = process else { + eprintln!("Failed to start ephemeral database"); + return Err("Failed to start ephemeral database".into()); + }; + socket_address = Some(address(&process.socket_address)); + args.username = DEFAULT_USERNAME.to_string(); + args.password = DEFAULT_PASSWORD.to_string(); + args.database = DEFAULT_DATABASE.to_string(); + Some(process) + } else { + None + }; + + if let Some(dsn) = args.dsn { + // TODO + let mut conn = parse_postgres_dsn_env(&dsn, ())?; + #[allow(deprecated)] + let home = std::env::home_dir().unwrap(); + conn.password + .resolve(Some(&home), &conn.hosts, &conn.database, &conn.database)?; + args.database = conn.database; + args.username = conn.user; + args.password = conn.password.password().unwrap_or_default().to_string(); + if let Some(host) = conn.hosts.first() { + socket_address = host.target_name()?.to_addrs_sync()?.into_iter().next(); + } + } + + let socket_address = socket_address.unwrap_or_else(|| match (args.tcp, args.unix) { + (Some(addr), None) => ResolvedTarget::SocketAddr(addr), + (None, Some(path)) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + _ => panic!("Must specify either a TCP address or a Unix socket path"), + }); + + eprintln!("Connecting to {socket_address:?}"); + + let credentials = Credentials { + username: args.username, + password: args.password, + database: args.database, + server_settings: Default::default(), + }; + + let statements = args + .statements + .unwrap_or_else(|| vec!["select 1;".to_string()]); + let socket_address = Target::new_resolved(socket_address); + + let local = LocalSet::new(); + local + .run_until(run_queries( + socket_address, + credentials, + statements, + args.extended, + )) + .await?; + + Ok(()) +} + +fn logging_sink() -> impl QuerySink { + ( + |rows: RowDescription<'_>| { + eprintln!("\nFields:"); + for field in rows.fields() { + eprint!(" {:?}", field.name()); + } + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + +fn logging_sink_execute() -> impl ExecuteSink { + ( + || { + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + +async fn run_queries( + target: Target, + credentials: Credentials, + statements: Vec, + extended: bool, +) -> Result<(), Box> { + let connector = Connector::new(target)?; + let (conn, task) = Client::new(credentials, connector); + tokio::task::spawn_local(task); + conn.ready().await?; + + eprintln!("Statements: {statements:?}"); + + for statement in statements { + if extended { + let conn = conn.clone(); + tokio::task::spawn_local(async move { + let pipeline = PipelineBuilder::default() + .parse(Statement::default(), &statement, &[], ()) + .describe_statement(Statement::default(), ()) + .bind( + Portal::default(), + Statement::default(), + &[], + &[Format::text()], + (), + ) + .describe_portal(Portal::default(), ()) + .execute( + Portal::default(), + MaxRows::Unlimited, + logging_sink_execute(), + ) + .build(); + conn.pipeline_sync(pipeline).await + }) + .await??; + } else { + tokio::task::spawn_local(conn.query(&statement, logging_sink())).await??; + } + } + + Ok(()) +} diff --git a/gel-pg-protocol/src/conn/conn.rs b/gel-pg-protocol/src/conn/conn.rs new file mode 100644 index 00000000..c37da39a --- /dev/null +++ b/gel-pg-protocol/src/conn/conn.rs @@ -0,0 +1,953 @@ +use super::{ + flow::{MessageHandler, MessageResult, Pipeline, QuerySink}, + raw_conn::RawClient, + Credentials, PGConnectionError, +}; +use crate::{ + connection::flow::{QueryMessageHandler, SyncMessageHandler}, + handshake::ConnectionSslRequirement, +}; +use futures::{future::Either, FutureExt}; +use gel_db_protocol::prelude::*; +use gel_pg_protocol::protocol::*; +use gel_stream::{Connector, Stream}; +use std::{ + cell::RefCell, + future::ready, + pin::Pin, + sync::Arc, + task::{ready, Poll}, +}; +use std::{ + collections::VecDeque, + future::{poll_fn, Future}, + rc::Rc, +}; +use tokio::io::ReadBuf; +use tracing::{error, trace, warn, Level}; + +#[derive(Debug, thiserror::Error)] +pub enum PGConnError { + #[error("Invalid state")] + InvalidState, + #[error("Postgres error: {0}")] + PgError(#[from] crate::errors::PgServerError), + #[error("Connection failed: {0}")] + Connection(#[from] PGConnectionError), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + /// If an operation in a pipeline group fails, all operations up to + /// the next sync are skipped. + #[error("Operation skipped because of previous pipeline failure: {0}")] + Skipped(crate::errors::PgServerError), + #[error("Connection was closed")] + Closed, +} + +/// A client for a PostgreSQL connection. +/// +/// ``` +/// # use pgrust::connection::*; +/// # use gel_stream::{Target, Connector}; +/// # _ = async { +/// # let credentials = Credentials::default(); +/// # let connector = Connector::new(Target::new_tcp(("localhost", 1234))).unwrap(); +/// let (client, task) = Client::new(credentials, connector); +/// ::tokio::task::spawn_local(task); +/// +/// // Run a basic query +/// client.query("SELECT 1", ()).await?; +/// +/// // Run a pipelined extended query +/// client.pipeline_sync(PipelineBuilder::default() +/// .parse(Statement("stmt1"), "SELECT 1", &[], ()) +/// .bind(Portal("portal1"), Statement("stmt1"), &[], &[Format::text()], ()) +/// .execute(Portal("portal1"), MaxRows::Unlimited, ()) +/// .build()).await?; +/// # Ok::<(), PGConnError>(()) +/// # } +/// ``` +pub struct Client { + conn: Rc, +} + +impl Clone for Client { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + +impl Client { + pub fn new( + credentials: Credentials, + connector: Connector, + ) -> (Self, impl Future>) { + let conn = Rc::new(PGConn::new_connection(async move { + let ssl_mode = ConnectionSslRequirement::Optional; + let raw = RawClient::connect(credentials, ssl_mode, connector).await?; + Ok(raw) + })); + let task = conn.clone().task(); + (Self { conn }, task) + } + + /// Create a new PostgreSQL client and a background task. + pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { + let conn = Rc::new(PGConn::new_raw(stm)); + let task = conn.clone().task(); + (Self { conn }, task) + } + + pub async fn ready(&self) -> Result<(), PGConnError> { + self.conn.ready().await + } + + /// Performs a bare `Query` operation. The sink handles the following messages: + /// + /// - `RowDescription` + /// - `DataRow` + /// - `CopyOutResponse` + /// - `CopyData` + /// - `CopyDone` + /// - `EmptyQueryResponse` + /// - `ErrorResponse` + /// + /// `CopyInResponse` is not currently supported and will result in a `CopyFail` being + /// sent to the server. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to callany callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. + pub fn query( + &self, + query: &str, + f: impl QuerySink + 'static, + ) -> impl Future> { + match self.conn.clone().query(query, f) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } + } + + /// Performs a set of pipelined steps as a `Sync` group. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to callany callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. + pub fn pipeline_sync( + &self, + pipeline: Pipeline, + ) -> impl Future> { + match self.conn.clone().pipeline_sync(pipeline) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } + } +} + +#[derive(derive_more::Debug)] +#[allow(clippy::type_complexity)] +enum ConnState { + #[debug("Connecting(..)")] + #[allow(clippy::type_complexity)] + Connecting(Pin>>>), + #[debug("Ready(..)")] + Ready { + stream: Pin>, + handlers: VecDeque<( + Box, + Option>, + )>, + }, + Error(PGConnError), + Closed, +} + +struct PGConn { + state: RefCell, + queue: RefCell>>, + ready_lock: Arc>, +} + +impl PGConn { + pub fn new_connection( + future: impl Future> + 'static, + ) -> Self { + Self { + state: ConnState::Connecting(future.boxed_local()).into(), + queue: Default::default(), + ready_lock: Default::default(), + } + } + + pub fn new_raw(stm: RawClient) -> Self { + let (stream, _params) = stm.into_parts(); + Self { + state: ConnState::Ready { + stream, + handlers: Default::default(), + } + .into(), + queue: Default::default(), + ready_lock: Default::default(), + } + } + + fn check_error(&self) -> Result<(), PGConnError> { + let state = &mut *self.state.borrow_mut(); + match state { + ConnState::Error(..) => { + let ConnState::Error(e) = std::mem::replace(state, ConnState::Closed) else { + unreachable!(); + }; + error!("Connection failed: {e:?}"); + Err(e) + } + ConnState::Closed => Err(PGConnError::Closed), + _ => Ok(()), + } + } + + #[inline(always)] + async fn ready(&self) -> Result<(), PGConnError> { + let _ = self.ready_lock.lock().await; + self.check_error() + } + + fn with_stream(&self, f: F) -> Result + where + F: FnOnce(Pin<&mut dyn Stream>) -> T, + { + match &mut *self.state.borrow_mut() { + ConnState::Ready { ref mut stream, .. } => Ok(f(stream.as_mut())), + _ => Err(PGConnError::InvalidState), + } + } + + fn write( + self: Rc, + message_handlers: Vec>, + buf: Vec, + ) -> Result, PGConnError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.clone().queue.borrow_mut().submit(async move { + // If the future was dropped before the first poll, we don't submit the operation + if tx.is_closed() { + return Ok(()); + } + + // Once we're polled the first time, we can add the handlers + match &mut *self.state.borrow_mut() { + ConnState::Ready { handlers, .. } => { + let mut handlers_iter = message_handlers.into_iter(); + let mut tx = Some(tx); + while let Some(handler) = handlers_iter.next() { + if handlers_iter.len() == 0 { + handlers.push_back((handler, tx.take())); + } else { + handlers.push_back((handler, None)); + } + } + } + x => { + warn!("Connection state was not ready: {x:?}"); + return Err(PGConnError::InvalidState); + } + } + + if tracing::enabled!(Level::TRACE) { + trace!("Write:"); + for s in hexdump::hexdump_iter(&buf) { + trace!("{}", s); + } + } + + let mut buf = &buf[..]; + + loop { + let n = poll_fn(|cx| { + self.with_stream(|stm| { + let n = match ready!(stm.poll_write(cx, buf)) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(PGConnError::Io(e))), + }; + Poll::Ready(Ok(n)) + })? + }) + .await?; + if n == buf.len() { + break; + } + buf = &buf[n..]; + } + + Ok(()) + }); + + Ok(rx) + } + + fn process_message(&self, message: Option) -> Result<(), PGConnError> { + let state = &mut *self.state.borrow_mut(); + match state { + ConnState::Ready { handlers, .. } => { + let message = message.ok_or(PGConnError::InvalidState)?; + if NotificationResponse::new(&message.as_ref()).is_ok() { + warn!("Notification: {:?}", message); + return Ok(()); + } + if ParameterStatus::new(&message.as_ref()).is_ok() { + warn!("ParameterStatus: {:?}", message); + return Ok(()); + } + if let Some((handler, _tx)) = handlers.front_mut() { + match handler.handle(message) { + MessageResult::SkipUntilSync => { + let mut found_sync = false; + let name = handler.name(); + while let Some((handler, _)) = handlers.front() { + if handler.is_sync() { + found_sync = true; + break; + } + trace!("skipping {}", handler.name()); + handlers.pop_front(); + } + if !found_sync { + warn!("Unexpected state in {name}: No sync handler found"); + } + } + MessageResult::Continue => {} + MessageResult::Done => { + handlers.pop_front(); + } + MessageResult::Unknown => { + // TODO: Should the be exposed to the API consumer? + warn!( + "Unknown message in {} ({:?})", + handler.name(), + message.mtype() as char + ); + } + MessageResult::UnexpectedState { complaint } => { + // TODO: Should the be exposed to the API consumer? + warn!( + "Unexpected state in {} while handling message ({:?}): {complaint}", + handler.name(), + message.mtype() as char + ); + } + }; + }; + } + ConnState::Connecting(..) => { + return Err(PGConnError::InvalidState); + } + ConnState::Error(..) | ConnState::Closed => self.check_error()?, + } + + Ok(()) + } + + pub fn task(self: Rc) -> impl Future> { + let ready_lock = self.ready_lock.clone().try_lock_owned().unwrap(); + async move { + poll_fn(|cx| { + let mut state = self.state.borrow_mut(); + match &mut *state { + ConnState::Connecting(fut) => match fut.poll_unpin(cx) { + Poll::Ready(result) => { + let raw = match result { + Ok(raw) => raw, + Err(e) => { + let error = PGConnError::Connection(e); + *state = ConnState::Error(error); + return Poll::Ready(Ok::<_, PGConnError>(())); + } + }; + let (stream, _params) = raw.into_parts(); + *state = ConnState::Ready { + stream, + handlers: Default::default(), + }; + Poll::Ready(Ok::<_, PGConnError>(())) + } + Poll::Pending => Poll::Pending, + }, + ConnState::Ready { .. } => Poll::Ready(Ok(())), + ConnState::Error(..) | ConnState::Closed => Poll::Ready(self.check_error()), + } + }) + .await?; + + drop(ready_lock); + + let mut buffer = StructBuffer::>::default(); + loop { + let mut read_buffer = [0; 1024]; + let n = poll_fn(|cx| { + // Poll the queue before we poll the read stream. Note that we toss + // the result here. Either we'll make progress or there's nothing to + // do. + while self.queue.borrow_mut().poll_next_unpin(cx).is_ready() {} + + self.with_stream(|stm| { + let mut buf = ReadBuf::new(&mut read_buffer); + let res = ready!(stm.poll_read(cx, &mut buf)); + Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGConnError::Io) + })? + }) + .await?; + + if tracing::enabled!(Level::TRACE) { + trace!("Read:"); + for s in hexdump::hexdump_iter(&read_buffer[..n]) { + trace!("{}", s); + } + } + + buffer.push_fallible(&read_buffer[..n], |message| { + if let Ok(message) = &message { + if tracing::enabled!(Level::TRACE) { + trace!("Message ({:?})", message.mtype() as char); + for s in hexdump::hexdump_iter(message.as_ref()) { + trace!("{}", s); + } + } + }; + self.process_message(Some(message.map_err(PGConnectionError::ParseError)?)) + })?; + + if n == 0 { + break; + } + } + Ok(()) + } + } + + pub fn query( + self: Rc, + query: &str, + f: impl QuerySink + 'static, + ) -> Result>, PGConnError> { + trace!("Query task started: {query}"); + let message = QueryBuilder { query }.to_vec(); + let rx = self.write( + vec![Box::new(QueryMessageHandler { + sink: f, + data: None, + copy: None, + })], + message, + )?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } + + pub fn pipeline_sync( + self: Rc, + pipeline: Pipeline, + ) -> Result>, PGConnError> { + trace!("Pipeline task started"); + let Pipeline { + mut messages, + mut handlers, + } = pipeline; + handlers.push(Box::new(SyncMessageHandler)); + messages.extend_from_slice(&SyncBuilder::default().to_vec()); + + let rx = self.write(handlers, messages)?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use hex_literal::hex; + use std::{fmt::Write, time::Duration}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task::LocalSet, + time::timeout, + }; + + use crate::connection::{ + flow::{CopyDataSink, DataSink, DoneHandling}, + raw_conn::ConnectionParams, + }; + + use super::*; + + impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, rows: RowDescription) -> Self { + write!(self.borrow_mut(), "[table=[").unwrap(); + for field in rows.fields() { + write!(self.borrow_mut(), "{},", field.name().to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + self.clone() + } + fn copy(&mut self, copy: CopyOutResponse) -> Self { + write!( + self.borrow_mut(), + "[copy={:?} {:?}", + copy.format(), + copy.format_codes() + ) + .unwrap(); + self.clone() + } + fn error(&mut self, error: ErrorResponse) { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]", + field.value().to_string_lossy() + ) + .unwrap(); + return; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]", error).unwrap(); + } + } + + impl DataSink for Rc> { + fn row(&mut self, row: DataRow) { + write!(self.borrow_mut(), "[").unwrap(); + for value in row.values() { + write!(self.borrow_mut(), "{},", value.to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]]", error).unwrap(); + } + } + DoneHandling::Handled + } + } + + impl CopyDataSink for Rc> { + fn data(&mut self, data: CopyData) { + write!( + self.borrow_mut(), + "[{}]", + String::from_utf8_lossy(data.data().as_ref()) + ) + .unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]]", error).unwrap(); + } + } + DoneHandling::Handled + } + } + + async fn read_expect(stream: &mut S, expected: &[u8]) { + let mut buf = vec![0u8; expected.len()]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, expected); + } + + /// Perform a test using captured binary protocol data from a real server. + async fn run_expect( + query_task: impl FnOnce(Client, Rc>) -> F + 'static, + expect: &'static [(&[u8], &[u8], &str)], + ) { + let f = async move { + let (mut s1, s2) = tokio::io::duplex(1024 * 1024); + + let (client, task) = Client::new_raw(RawClient::new(s2, ConnectionParams::default())); + let task_handle = tokio::task::spawn_local(task); + + let handle = tokio::task::spawn_local(async move { + let log = Rc::new(RefCell::new(String::new())); + query_task(client, log.clone()).await; + Rc::try_unwrap(log).unwrap().into_inner() + }); + + let mut log_expect = String::new(); + for (read, write, expect) in expect { + // Query[text=""] + eprintln!("read {read:?}"); + read_expect(&mut s1, read).await; + eprintln!("write {write:?}"); + s1.write_all(write).await.unwrap(); + log_expect.push_str(expect); + } + + let log = handle.await.unwrap(); + + assert_eq!(log, log_expect); + + // EOF to trigger the task to exit + drop(s1); + + task_handle.await.unwrap().unwrap(); + }; + + let local = LocalSet::new(); + let task = local.spawn_local(f); + + timeout(Duration::from_secs(1), local).await.unwrap(); + + // Ensure we detect panics inside the task + task.await.unwrap(); + } + + #[test_log::test(tokio::test)] + async fn query_select_1() { + run_expect( + |client, log| async move { + client.query("SELECT 1", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 0d53454c 45435420 3100"), + // T, D, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004400 00000b00 01000000 01314300 00000d53 454c4543 54203100 5a000000 0549"), + "[table=[?column?,][1,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_select_1_limit_0() { + run_expect( + |client, log| async move { + client.query("SELECT 1 LIMIT 0", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 1553454c 45435420 31204c49 4d495420 3000"), + // T, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004300 00000d53 454c4543 54203000 5a000000 0549"), + "[table=[?column?,] done=SELECT 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1() { + run_expect( + |client, log| async move { + client.query("copy (select 1) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 1f636f70 79202873 656c6563 74203129 20746f20 7374646f 75743b00"), + // H, d, c, C, Z + &hex!("48000000 09000001 00006400 00000631 0a630000 00044300 00000b43 4f505920 31005a00 00000549"), + "[copy=0 [0][1\n] done=COPY 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1_limit_0() { + run_expect( + |client, log| async move { + client.query("copy (select 1 limit 0) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 27636f70 79202873 656c6563 74203120 6c696d69 74203029 20746f20 7374646f 75743b00"), + // H, c, C, Z + &hex!("48000000 09000001 00006300 00000443 0000000b 434f5059 2030005a 00000005 49"), + "[copy=0 [0] done=COPY 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_with_error_rows() { + run_expect( + |client, log| async move { + client.query("copy (select case when id = 2 then id/(id-2) else id end from (select generate_series(1,2) as id)) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 72636f70 79202873 656c6563 + 74206361 73652077 68656e20 6964203d + 20322074 68656e20 69642f28 69642d32 + 2920656c 73652069 6420656e 64206672 + 6f6d2028 73656c65 63742067 656e6572 + 6174655f 73657269 65732831 2c322920 + 61732069 64292920 746f2073 74646f75 + 743b00 + """), + // H, d, E, Z + &hex!(""" + 48000000 09000001 00006400 00000631 + 0a450000 00415345 52524f52 00564552 + 524f5200 43323230 3132004d 64697669 + 73696f6e 20627920 7a65726f 0046696e + 742e6300 4c383431 0052696e 74346469 + 7600005a 00000005 49 + """), + "[copy=0 [0][1\n][error 22012]]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error() { + run_expect( + |client, log| async move { + client.query("do $$begin raise exception 'hi'; end$$;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 2c646f20 24246265 67696e20 72616973 65206578 63657074 696f6e20 27686927 3b20656e 6424243b 00"), + // E, Z + &hex!(""" + 45000000 75534552 524f5200 56455252 + 4f520043 50303030 31004d68 69005750 + 4c2f7067 53514c20 66756e63 74696f6e + 20696e6c 696e655f 636f6465 5f626c6f + 636b206c 696e6520 31206174 20524149 + 53450046 706c5f65 7865632e 63004c33 + 39313100 52657865 635f7374 6d745f72 + 61697365 00005a00 00000549 + """), + "[error P0001]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_empty_do() { + run_expect( + |client, log| async move { + client + .query("do $$begin end$$;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 16646f20 24246265 67696e20 656e6424 243b00"), + // C, Z + &hex!(""" + 43000000 07444f00 5a000000 0549 + """), + "", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error_with_rows() { + run_expect( + |client, log| async move { + client.query("select case when id = 2 then id/(id-2) else 1 end from (select 1 as id union all select 2 as id);", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 6673656c 65637420 63617365 + 20776865 6e206964 203d2032 20746865 + 6e206964 2f286964 2d322920 656c7365 + 20312065 6e642066 726f6d20 2873656c + 65637420 31206173 20696420 756e696f + 6e20616c 6c207365 6c656374 20322061 + 73206964 293b00 + """), + // T, D, E, Z + &hex!(""" + 54000000 1d000163 61736500 00000000 + 00000000 00170004 ffffffff 00004400 + 00000b00 01000000 01314500 00004153 + 4552524f 52005645 52524f52 00433232 + 30313200 4d646976 6973696f 6e206279 + 207a6572 6f004669 6e742e63 004c3834 + 31005269 6e743464 69760000 5a000000 + 0549 + """), + "[table=[case,][1,][error 22012]]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_second_errors() { + run_expect( + |client, log| async move { + client + .query("select; select 1/0;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 1873656c 6563743b 2073656c 65637420 312f303b 00"), + // T, D, C, E, Z + &hex!(""" + 54000000 06000044 00000006 00004300 + 00000d53 454c4543 54203100 45000000 + 41534552 524f5200 56455252 4f520043 + 32323031 32004d64 69766973 696f6e20 + 6279207a 65726f00 46696e74 2e63004c + 38343100 52696e74 34646976 00005a00 + 00000549 + """), + "[table=[][] done=SELECT 1][error 22012]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_notification() { + run_expect( + |client, log| async move { + client + .query("listen a; select pg_notify('a','b')", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!( + " + 51000000 286c6973 74656e20 613b2073 + 656c6563 74207067 5f6e6f74 69667928 + 2761272c 27622729 00 + " + ), + // C, T, D, C, A, Z + &hex!( + " + 43000000 0b4c4953 54454e00 54000000 + 22000170 675f6e6f 74696679 00000000 + 00000000 0008e600 04ffffff ff000044 + 0000000a 00010000 00004300 00000d53 + 454c4543 54203100 41000000 0c002cba + 5f610062 005a0000 000549 + " + ), + "[table=[pg_notify,][,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_empty() { + run_expect( + |client, log| async move { + client.query("", log.clone()).await.unwrap(); + client.query("", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_error() { + run_expect( + |client, log| async move { + client.query(".", log.clone()).await.unwrap(); + client.query(".", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ], + ) + .await; + } +} diff --git a/gel-pg-protocol/src/conn/flow.rs b/gel-pg-protocol/src/conn/flow.rs new file mode 100644 index 00000000..df255363 --- /dev/null +++ b/gel-pg-protocol/src/conn/flow.rs @@ -0,0 +1,1224 @@ +//! Postgres flow notes: +//! +//! +//! +//! +//! +//! Extended query messages Parse, Bind, Describe, Execute, Close put the server +//! into a "skip-til-sync" mode when erroring. All messages other than Terminate (including +//! those not part of the extended query protocol) are skipped until an explicit Sync message is received. +//! +//! Sync closes _implicit_ but not _explicit_ transactions. +//! +//! Both Query and Execute may return COPY responses rather than rows. In the case of Query, +//! RowDescription + DataRow is replaced by CopyOutResponse + CopyData + CopyDone. In the case +//! of Execute, describing the portal will return NoData, but Execute will return CopyOutResponse + +//! CopyData + CopyDone. + +use std::{cell::RefCell, num::NonZeroU32, rc::Rc}; + +use gel_db_protocol::{match_message, Encoded}; +use gel_pg_protocol::protocol::*; + +#[derive(Debug, Clone, Copy)] +pub enum Param<'a> { + Null, + Text(&'a str), + Binary(&'a [u8]), +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Oid(u32); + +impl Oid { + pub fn unspecified() -> Self { + Self(0) + } + + pub fn from(oid: NonZeroU32) -> Self { + Self(oid.get()) + } +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Format(i16); + +impl Format { + pub fn text() -> Self { + Self(0) + } + + pub fn binary() -> Self { + Self(1) + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(i32)] +pub enum MaxRows { + Unlimited, + Limited(NonZeroU32), +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Portal<'a>(pub &'a str); + +#[derive(Debug, Clone, Copy, Default)] +pub struct Statement<'a>(pub &'a str); + +pub trait Flow { + fn to_vec(&self) -> Vec; +} + +/// Performs a prepared statement parse operation. +/// +/// Handles: +/// - `ParseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ParseFlow<'a> { + pub name: Statement<'a>, + pub query: &'a str, + pub param_types: &'a [Oid], +} + +/// Performs a prepared statement bind operation. +/// +/// Handles: +/// - `BindComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct BindFlow<'a> { + pub portal: Portal<'a>, + pub statement: Statement<'a>, + pub params: &'a [Param<'a>], + pub result_format_codes: &'a [Format], +} + +/// Performs a prepared statement execute operation. +/// +/// Handles: +/// - `CommandComplete` +/// - `DataRow` +/// - `PortalSuspended` +/// - `CopyOutResponse` +/// - `CopyData` +/// - `CopyDone` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ExecuteFlow<'a> { + pub portal: Portal<'a>, + pub max_rows: MaxRows, +} + +/// Performs a portal describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ParameterDescription` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribeStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a portal close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ClosePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct CloseStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a query operation. +/// +/// Handles: +/// - `EmptyQueryResponse`: If no queries were specified in the text +/// - `CommandComplete`: For each fully-completed query +/// - `RowDescription`: For each query that returns data +/// - `DataRow`: For each row returned by a query +/// - `CopyOutResponse`: For each query that returns copy data +/// - `CopyData`: For each chunk of copy data returned by a query +/// - `CopyDone`: For each query that returns copy data +/// - `ErrorResponse`: For the first failed query +#[derive(Debug, Clone, Copy)] +struct QueryFlow<'a> { + pub query: &'a str, +} + +impl Flow for ParseFlow<'_> { + fn to_vec(&self) -> Vec { + let param_types = bytemuck::cast_slice(self.param_types); + ParseBuilder { + statement: self.name.0, + query: self.query, + param_types, + } + .to_vec() + } +} + +impl Flow for BindFlow<'_> { + fn to_vec(&self) -> Vec { + let mut format_codes = Vec::with_capacity(self.params.len()); + let mut values = Vec::with_capacity(self.params.len()); + + for param in self.params { + match param { + Param::Null => { + format_codes.push(0); + values.push(Encoded::Null); + } + Param::Text(value) => { + format_codes.push(0); + values.push(Encoded::Value(value.as_bytes())); + } + Param::Binary(value) => { + format_codes.push(1); + values.push(Encoded::Value(value)); + } + } + } + + let result_format_codes = bytemuck::cast_slice(self.result_format_codes); + + BindBuilder { + portal: self.portal.0, + statement: self.statement.0, + format_codes: &format_codes, + values: &values, + result_format_codes, + } + .to_vec() + } +} + +impl Flow for ExecuteFlow<'_> { + fn to_vec(&self) -> Vec { + let max_rows = match self.max_rows { + MaxRows::Unlimited => 0, + MaxRows::Limited(n) => n.get() as i32, + }; + ExecuteBuilder { + portal: self.portal.0, + max_rows, + } + .to_vec() + } +} + +impl Flow for DescribePortalFlow<'_> { + fn to_vec(&self) -> Vec { + DescribeBuilder { + name: self.name.0, + dtype: DescribeType::Portal, + } + .to_vec() + } +} + +impl Flow for DescribeStatementFlow<'_> { + fn to_vec(&self) -> Vec { + DescribeBuilder { + name: self.name.0, + dtype: DescribeType::Statement, + } + .to_vec() + } +} + +impl Flow for ClosePortalFlow<'_> { + fn to_vec(&self) -> Vec { + CloseBuilder { + name: self.name.0, + ctype: CloseType::Portal, + } + .to_vec() + } +} + +impl Flow for CloseStatementFlow<'_> { + fn to_vec(&self) -> Vec { + CloseBuilder { + name: self.name.0, + ctype: CloseType::Statement, + } + .to_vec() + } +} + +impl Flow for QueryFlow<'_> { + fn to_vec(&self) -> Vec { + QueryBuilder { query: self.query }.to_vec() + } +} + +pub(crate) enum MessageResult { + Continue, + Done, + SkipUntilSync, + Unknown, + UnexpectedState { complaint: &'static str }, +} + +pub(crate) trait MessageHandler { + fn handle(&mut self, message: Message) -> MessageResult; + fn name(&self) -> &'static str; + fn is_sync(&self) -> bool { + false + } +} + +pub(crate) struct SyncMessageHandler; + +impl MessageHandler for SyncMessageHandler { + fn handle(&mut self, message: Message) -> MessageResult { + if ReadyForQuery::new(&message.as_ref()).is_ok() { + return MessageResult::Done; + } + MessageResult::Unknown + } + fn name(&self) -> &'static str { + "Sync" + } + fn is_sync(&self) -> bool { + true + } +} + +impl MessageHandler for (&'static str, F) +where + F: for<'a> FnMut(Message<'a>) -> MessageResult, +{ + fn handle(&mut self, message: Message) -> MessageResult { + (self.1)(message) + } + fn name(&self) -> &'static str { + self.0 + } +} + +pub trait FlowWithSink { + fn visit_flow(&self, f: impl FnMut(&dyn Flow)); + fn make_handler(self) -> Box; +} + +pub trait SimpleFlowSink { + fn handle(&mut self, result: Result<(), ErrorResponse>); +} + +impl SimpleFlowSink for () { + fn handle(&mut self, _: Result<(), ErrorResponse>) {} +} + +impl FnMut(Result<(), ErrorResponse>)> SimpleFlowSink for F { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + (self)(result) + } +} + +impl FlowWithSink for (ParseFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Parse", move |message: Message<'_>| { + if ParseComplete::new(&message.as_ref()).is_ok() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Ok(msg) = ErrorResponse::new(&message.as_ref()) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (BindFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Bind", move |message: Message<'_>| { + if BindComplete::new(&message.as_ref()).is_ok() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Ok(msg) = ErrorResponse::new(&message.as_ref()) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ClosePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("ClosePortal", move |message: Message<'_>| { + if CloseComplete::new(&message.as_ref()).is_ok() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Ok(msg) = ErrorResponse::new(&message.as_ref()) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (CloseStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("CloseStatement", move |message: Message<'_>| { + if CloseComplete::new(&message.as_ref()).is_ok() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Ok(msg) = ErrorResponse::new(&message.as_ref()) { + self.1.handle(Err(msg)); + return MessageResult::Done; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ExecuteFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(ExecuteMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (QueryFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(QueryMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (DescribePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +impl FlowWithSink for (DescribeStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +pub trait DescribeSink { + fn params(&mut self, params: ParameterDescription); + fn rows(&mut self, rows: RowDescription); + fn error(&mut self, error: ErrorResponse); +} + +impl DescribeSink for () { + fn params(&mut self, _: ParameterDescription) {} + fn rows(&mut self, _: RowDescription) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl DescribeSink for F +where + F: for<'a> FnMut(RowDescription<'a>), +{ + fn rows(&mut self, rows: RowDescription) { + (self)(rows) + } + fn params(&mut self, _params: ParameterDescription) {} + fn error(&mut self, _error: ErrorResponse) {} +} + +impl DescribeSink for (F1, F2) +where + F1: for<'a> FnMut(ParameterDescription<'a>), + F2: for<'a> FnMut(RowDescription<'a>), +{ + fn params(&mut self, params: ParameterDescription) { + (self.0)(params) + } + fn rows(&mut self, rows: RowDescription) { + (self.1)(rows) + } + fn error(&mut self, _error: ErrorResponse) {} +} + +struct DescribeMessageHandler { + sink: S, +} + +impl MessageHandler for DescribeMessageHandler { + fn name(&self) -> &'static str { + "Describe" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (ParameterDescription as params) => { + self.sink.params(params); + return MessageResult::Continue; + }, + (RowDescription as rows) => { + self.sink.rows(rows); + return MessageResult::Done; + }, + (NoData) => { + return MessageResult::Done; + }, + (ErrorResponse as err) => { + self.sink.error(err); + return MessageResult::SkipUntilSync; + }, + _unknown => { + return MessageResult::Unknown; + } + }) + } +} + +pub trait ExecuteSink { + type Output: ExecuteDataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: ExecuteCompletion) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +pub enum ExecuteCompletion<'a> { + PortalSuspended(PortalSuspended<'a>), + CommandComplete(CommandComplete<'a>), +} + +impl ExecuteSink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl ExecuteSink for (F1, F2) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl ExecuteSink for (F1, F2, F3) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +pub trait ExecuteDataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. + #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl ExecuteDataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl ExecuteDataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +/// A sink capable of handling standard query and COPY (out direction) messages. +pub trait QuerySink { + type Output: DataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self, rows: RowDescription) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: CommandComplete) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DoneHandling { + Handled, + RedirectToParent, +} + +pub trait DataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. + #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +pub trait CopyDataSink { + /// Sink a chunk of COPY data. + fn data(&mut self, values: CopyData); + /// Handle the completion of a COPY operation. If unimplemented, will be redirected to the parent. + #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl QuerySink for Box +where + Q: QuerySink + 'static, +{ + type Output = Box; + type CopyOutput = Box; + fn rows(&mut self, rows: RowDescription) -> Self::Output { + Box::new(self.as_mut().rows(rows)) + } + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput { + Box::new(self.as_mut().copy(copy)) + } + fn complete(&mut self, _complete: CommandComplete) { + self.as_mut().complete(_complete) + } + fn error(&mut self, error: ErrorResponse) { + self.as_mut().error(error) + } +} + +impl QuerySink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self, _: RowDescription) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl QuerySink for (F1, F2) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl QuerySink for (F1, F2, F3) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +impl DataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl DataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +impl DataSink for Box { + fn row(&mut self, values: DataRow) { + self.as_mut().row(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + self.as_mut().done(result) + } +} + +impl CopyDataSink for () { + fn data(&mut self, _: CopyData) {} +} + +impl CopyDataSink for F +where + F: for<'a> FnMut(CopyData<'a>), +{ + fn data(&mut self, values: CopyData) { + (self)(values) + } +} + +impl CopyDataSink for Box { + fn data(&mut self, values: CopyData) { + self.as_mut().data(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + self.as_mut().done(result) + } +} + +pub(crate) struct ExecuteMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for ExecuteMessageHandler { + fn name(&self) -> &'static str { + "Execute" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (DataRow as row) => { + if self.data.is_none() { + self.data = Some(self.sink.rows()); + } + let Some(sink) = &mut self.data else { + unreachable!() + }; + sink.row(row) + }, + (PortalSuspended as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Ok(ExecuteCompletion::PortalSuspended(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::PortalSuspended(complete)); + } + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + return MessageResult::Done; + }, + (CommandComplete as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the COPY sink. + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else { + // Otherwise, create a new data sink and route to there. + if self.sink.rows().done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } + return MessageResult::Done; + }, + (EmptyQueryResponse) => { + // TODO: This should be exposed to the sink + return MessageResult::Done; + }, + + (ErrorResponse as err) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the COPY sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Otherwise, create a new data sink and route to there. + if self.sink.rows().done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } + + return MessageResult::SkipUntilSync; + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +pub(crate) struct QueryMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for QueryMessageHandler { + fn name(&self) -> &'static str { + "Query" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + + (RowDescription as row) => { + let sink = std::mem::replace(&mut self.data, Some(self.sink.rows(row))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } + }, + (DataRow as row) => { + if let Some(sink) = &mut self.data { + sink.row(row) + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + }, + (CommandComplete as complete) => { + let sink = std::mem::take(&mut self.data); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + let sink = std::mem::take(&mut self.copy); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + self.sink.complete(complete); + } + } + }, + + (EmptyQueryResponse) => { + // Equivalent to CommandComplete, but no data was provided + let sink = std::mem::take(&mut self.data); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } else { + let sink = std::mem::take(&mut self.copy); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + } + }, + + (ErrorResponse as err) => { + // Depending on the state of the sink, we direct the error to + // the appropriate handler. + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.copy) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Top level errors must complete this operation + self.sink.error(err); + } + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + (ReadyForQuery) => { + // All operations are complete at this point. + if std::mem::take(&mut self.data).is_some() || std::mem::take(&mut self.copy).is_some() { + return MessageResult::UnexpectedState { complaint: "sink exists" }; + } + return MessageResult::Done; + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +#[derive(Default)] +pub struct PipelineBuilder { + handlers: Vec>, + messages: Vec, +} + +impl PipelineBuilder { + fn push_flow_with_sink(mut self, flow: impl FlowWithSink) -> Self { + flow.visit_flow(|flow| self.messages.extend_from_slice(&flow.to_vec())); + self.handlers.push(flow.make_handler()); + self + } + + /// Add a bind flow to the pipeline. + pub fn bind( + self, + portal: Portal, + statement: Statement, + params: &[Param], + result_format_codes: &[Format], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + BindFlow { + portal, + statement, + params, + result_format_codes, + }, + handler, + )) + } + + /// Add a parse flow to the pipeline. + pub fn parse( + self, + name: Statement, + query: &str, + param_types: &[Oid], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + ParseFlow { + name, + query, + param_types, + }, + handler, + )) + } + + /// Add an execute flow to the pipeline. + /// + /// Note that this may be a COPY statement. In that case, the description of the portal + /// will not show any data returned, and this will use the `CopySink` of the provided + /// sink. In addition, COPY operations do not respect the `max_rows` parameter. + pub fn execute( + self, + portal: Portal, + max_rows: MaxRows, + handler: impl ExecuteSink + 'static, + ) -> Self { + self.push_flow_with_sink((ExecuteFlow { portal, max_rows }, handler)) + } + + /// Add a close portal flow to the pipeline. + pub fn close_portal(self, name: Portal, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((ClosePortalFlow { name }, handler)) + } + + /// Add a close statement flow to the pipeline. + pub fn close_statement(self, name: Statement, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((CloseStatementFlow { name }, handler)) + } + + /// Add a describe portal flow to the pipeline. Note that this will describe + /// both parameters and rows. + pub fn describe_portal(self, name: Portal, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribePortalFlow { name }, handler)) + } + + /// Add a describe statement flow to the pipeline. Note that this will describe + /// only the rows of the portal. + pub fn describe_statement(self, name: Statement, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribeStatementFlow { name }, handler)) + } + + /// Add a query flow to the pipeline. + /// + /// Note that if a query fails, the pipeline will continue executing until it + /// completes or a non-query pipeline element fails. If a previous non-query + /// element of this pipeline failed, the query will not be executed. + pub fn query(self, query: &str, handler: impl QuerySink + 'static) -> Self { + self.push_flow_with_sink((QueryFlow { query }, handler)) + } + + pub fn build(self) -> Pipeline { + Pipeline { + handlers: self.handlers, + messages: self.messages, + } + } +} + +pub struct Pipeline { + pub(crate) handlers: Vec>, + pub(crate) messages: Vec, +} + +#[derive(Default)] +/// Accumulate raw messages from a flow. Useful mainly for testing. +pub struct FlowAccumulator { + data: Vec, + messages: Vec, +} + +impl FlowAccumulator { + pub fn push(&mut self, message: impl AsRef<[u8]>) { + self.messages.push(self.data.len()); + self.data.extend_from_slice(message.as_ref()); + } + + pub fn with_messages(&self, mut f: impl FnMut(Message)) { + for &offset in &self.messages { + // First get the message header + let message = Message::new(&self.data[offset..]).unwrap(); + let len = message.mlen(); + // Then resize the message to the correct length + let message = Message::new(&self.data[offset..offset + (len.0 as usize) + 1]).unwrap(); + f(message); + } + } +} + +impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, message: RowDescription) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: CommandComplete) { + self.borrow_mut().push(complete); + } + fn notice(&mut self, message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl ExecuteSink for Rc> { + type Output = Self; + type CopyOutput = Self; + + fn rows(&mut self) -> Self { + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: ExecuteCompletion) { + match complete { + ExecuteCompletion::PortalSuspended(suspended) => self.borrow_mut().push(suspended), + ExecuteCompletion::CommandComplete(complete) => self.borrow_mut().push(complete), + } + } + fn notice(&mut self, message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl DataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl ExecuteDataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(ExecuteCompletion::PortalSuspended(suspended)) => self.borrow_mut().push(suspended), + Ok(ExecuteCompletion::CommandComplete(complete)) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl CopyDataSink for Rc> { + fn data(&mut self, message: CopyData) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl SimpleFlowSink for Rc> { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + match result { + Ok(()) => (), + Err(err) => self.borrow_mut().push(err), + } + } +} + +impl DescribeSink for Rc> { + fn params(&mut self, params: ParameterDescription) { + self.borrow_mut().push(params); + } + fn rows(&mut self, rows: RowDescription) { + self.borrow_mut().push(rows); + } + fn error(&mut self, error: ErrorResponse) { + self.borrow_mut().push(error); + } +} diff --git a/gel-pg-protocol/src/conn/mod.rs b/gel-pg-protocol/src/conn/mod.rs new file mode 100644 index 00000000..67595718 --- /dev/null +++ b/gel-pg-protocol/src/conn/mod.rs @@ -0,0 +1,94 @@ +use std::collections::HashMap; + +use crate::errors::{edgedb::EdbError, PgServerError}; +mod conn; +mod flow; +pub(crate) mod queue; +mod raw_conn; + +pub use conn::{Client, PGConnError}; +pub use flow::{ + CopyDataSink, DataSink, DoneHandling, ExecuteSink, FlowAccumulator, Format, MaxRows, Oid, + Param, Pipeline, PipelineBuilder, Portal, QuerySink, Statement, +}; +use gel_db_protocol::prelude::ParseError; +use gel_stream::ConnectionError; +pub use raw_conn::RawClient; + +macro_rules! __invalid_state { + ($error:literal) => {{ + eprintln!( + "Invalid connection state: {}\n{}", + $error, + ::std::backtrace::Backtrace::capture() + ); + #[allow(deprecated)] + $crate::connection::PGConnectionError::__InvalidState + }}; +} +pub(crate) use __invalid_state as invalid_state; + +#[derive(Debug, thiserror::Error)] +pub enum PGConnectionError { + /// Invalid state error, suggesting a logic error in code rather than a server or client failure. + /// Use the `invalid_state!` macro instead which will print a backtrace. + #[error("Invalid state")] + #[deprecated = "Use invalid_state!"] + __InvalidState, + + /// Error during connection setup. + #[error("Connection error: {0}")] + ConnectionError(#[from] ConnectionError), + + /// Error returned by the server. + #[error("Server error: {0}")] + ServerError(#[from] PgServerError), + + /// Error returned by the server. + #[error("Server error: {0}")] + EdbServerError(#[from] EdbError), + + /// The server sent something we didn't expect + #[error("Unexpected server response: {0}")] + UnexpectedResponse(String), + + /// Error related to SCRAM authentication. + #[error("SCRAM: {0}")] + Scram(#[from] gel_auth::scram::SCRAMError), + + /// I/O error encountered during connection operations. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// UTF-8 decoding error. + #[error("UTF8 error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + + /// SSL-related error. + #[error("SSL error: {0}")] + SslError(#[from] SslError), + + #[error("Protocol error: {0}")] + ParseError(#[from] ParseError), +} + +#[derive(Debug, thiserror::Error)] +pub enum SslError { + #[error("SSL is not supported by this client transport")] + SslUnsupportedByClient, + #[error("SSL was required by the client, but not offered by server (rejected SSL)")] + SslRequiredByClient, + #[error("OpenSSL error: {0}")] + OpenSslError(#[from] ::openssl::ssl::Error), + #[error("OpenSSL error: {0}")] + OpenSslErrorStack(#[from] ::openssl::error::ErrorStack), +} + +/// A sufficient set of required parameters to connect to a given transport. +#[derive(Clone, Default, derive_more::Debug)] +pub struct Credentials { + pub username: String, + pub password: String, + pub database: String, + pub server_settings: HashMap, +} diff --git a/gel-pg-protocol/src/conn/queue.rs b/gel-pg-protocol/src/conn/queue.rs new file mode 100644 index 00000000..211beed5 --- /dev/null +++ b/gel-pg-protocol/src/conn/queue.rs @@ -0,0 +1,166 @@ +use std::future::Future; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A queue of futures that can be polled in order. +/// +/// Only one future will be active at a time. If no futures are active, the +/// waker will be triggered when the next future is submitted to the queue. +pub struct FutureQueue { + queue: tokio::sync::mpsc::UnboundedReceiver>>>, + sender: tokio::sync::mpsc::UnboundedSender>>>, + current: Option>>>, +} + +#[cfg(test)] +#[derive(Clone)] +pub struct FutureQueueSender { + sender: tokio::sync::mpsc::UnboundedSender>>>, +} + +#[cfg(test)] +impl FutureQueueSender { + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because the receiver still exists + self.sender.send(Box::pin(future)).unwrap(); + } +} + +impl FutureQueue { + #[cfg(test)] + pub fn sender(&self) -> FutureQueueSender { + FutureQueueSender { + sender: self.sender.clone(), + } + } + + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because we hold both ends of the channel. + self.sender.send(Box::pin(future)).unwrap(); + } + + /// Poll the current future, or no current future, poll for the next item + /// from the queue (and then poll that future). + pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(future) = self.current.as_mut() { + match future.as_mut().poll(cx) { + Poll::Ready(output) => { + self.current = None; + return Poll::Ready(Some(output)); + } + Poll::Pending => return Poll::Pending, + } + } + + // If there is no current future, try to receive the next one from the queue. + let next = match self.queue.poll_recv(cx) { + Poll::Ready(Some(next)) => next, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; + + // Note that we loop around to poll this future until we get a Pending + // result. + self.current = Some(next); + } + } +} + +impl Default for FutureQueue { + fn default() -> Self { + let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); + Self { + queue: receiver, + sender, + current: None, + } + } +} + +impl futures::Stream for FutureQueue { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We're Unpin + let this = self.deref_mut(); + this.poll_next_unpin(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::StreamExt; + use tokio::{ + task::LocalSet, + time::{sleep, Duration}, + }; + + #[tokio::test] + async fn test_basic_queue() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn a task that sends some futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { 1 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 2 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 3 }); + }); + + // Collect results + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 3 { + break; + } + } + + assert_eq!(results, vec![1, 2, 3]); + }) + .await; + } + + #[tokio::test] + async fn test_delayed_futures() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn task with delayed futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { + sleep(Duration::from_millis(50)).await; + 1 + }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { + sleep(Duration::from_millis(10)).await; + 2 + }); + }); + + // Even though second future completes first, results should be in order of sending + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 2 { + break; + } + } + + assert_eq!(results, vec![1, 2]); + }) + .await; + } +} diff --git a/gel-pg-protocol/src/conn/raw_conn.rs b/gel-pg-protocol/src/conn/raw_conn.rs new file mode 100644 index 00000000..7f5c5901 --- /dev/null +++ b/gel-pg-protocol/src/conn/raw_conn.rs @@ -0,0 +1,232 @@ +use super::{invalid_state, Credentials, PGConnectionError}; +use crate::handshake::{ + client::{ + ConnectionDrive, ConnectionState, ConnectionStateSend, ConnectionStateType, + ConnectionStateUpdate, + }, + ConnectionSslRequirement, +}; +use gel_auth::AuthType; +use gel_db_protocol::StructBuffer; +use gel_pg_protocol::protocol::{FrontendBuilder, InitialBuilder, Message, SSLResponse}; +use gel_stream::{ConnectionError, Connector, Stream, StreamUpgrade}; +use std::collections::HashMap; +use std::pin::Pin; +use tokio::io::AsyncWriteExt; +use tracing::{trace, Level}; + +#[derive(Clone, Default, Debug)] +pub struct ConnectionParams { + pub ssl: bool, + pub params: HashMap, + pub cancellation_key: (i32, i32), + pub auth: AuthType, +} + +pub struct ConnectionDriver { + send_buffer: Vec, + upgrade: bool, + params: ConnectionParams, +} + +impl ConnectionStateSend for ConnectionDriver { + fn send_initial(&mut self, message: InitialBuilder) -> Result<(), std::io::Error> { + self.send_buffer.extend(message.to_vec()); + Ok(()) + } + fn send(&mut self, message: FrontendBuilder) -> Result<(), std::io::Error> { + self.send_buffer.extend(message.to_vec()); + Ok(()) + } + fn upgrade(&mut self) -> Result<(), std::io::Error> { + self.upgrade = true; + self.params.ssl = true; + Ok(()) + } +} + +impl ConnectionStateUpdate for ConnectionDriver { + fn state_changed(&mut self, state: ConnectionStateType) { + trace!("State: {state:?}"); + } + fn cancellation_key(&mut self, pid: i32, key: i32) { + self.params.cancellation_key = (pid, key); + } + fn parameter(&mut self, name: &str, value: &str) { + self.params.params.insert(name.to_owned(), value.to_owned()); + } + fn auth(&mut self, auth: AuthType) { + trace!("Auth: {auth:?}"); + self.params.auth = auth; + } +} + +impl ConnectionDriver { + pub fn new() -> Self { + Self { + send_buffer: Vec::new(), + upgrade: false, + params: ConnectionParams::default(), + } + } + + async fn drive_bytes( + &mut self, + state: &mut ConnectionState, + drive: &[u8], + message_buffer: &mut StructBuffer>, + stream: &mut S, + ) -> Result<(), PGConnectionError> { + message_buffer.push_fallible(drive, |msg| { + state.drive(ConnectionDrive::Message(msg), self) + })?; + loop { + if !self.send_buffer.is_empty() { + if tracing::enabled!(Level::TRACE) { + trace!("Write:"); + for s in hexdump::hexdump_iter(&self.send_buffer) { + trace!("{}", s); + } + } + stream.write_all(&self.send_buffer).await?; + self.send_buffer.clear(); + } + if self.upgrade { + self.upgrade = false; + stream + .secure_upgrade() + .await + .map_err(ConnectionError::from)?; + state.drive(ConnectionDrive::SslReady, self)?; + } else { + break; + } + } + Ok(()) + } + + async fn drive( + &mut self, + state: &mut ConnectionState, + drive: ConnectionDrive<'_>, + stream: &mut S, + ) -> Result<(), PGConnectionError> { + state.drive(drive, self)?; + loop { + if !self.send_buffer.is_empty() { + if tracing::enabled!(Level::TRACE) { + trace!("Write:"); + for s in hexdump::hexdump_iter(&self.send_buffer) { + trace!("{}", s); + } + } + stream.write_all(&self.send_buffer).await?; + self.send_buffer.clear(); + } + if self.upgrade { + self.upgrade = false; + stream + .secure_upgrade() + .await + .map_err(ConnectionError::from)?; + state.drive(ConnectionDrive::SslReady, self)?; + } else { + break; + } + } + Ok(()) + } +} + +/// A raw client connection stream to a Postgres server, fully authenticated and +/// ready to send queries. +/// +/// This can be connected to a remote server using `connect`, or can be created +/// with a pre-existing, pre-authenticated stream. +#[derive(derive_more::Debug)] +pub struct RawClient { + #[debug(skip)] + stream: Pin>, + params: ConnectionParams, +} + +impl RawClient { + /// Create a new `RawClient` from a given fully-authenticated stream. + #[inline] + pub fn new(stream: S, params: ConnectionParams) -> Self { + Self { + stream: Box::pin(stream), + params, + } + } + + /// Create a new `RawClient` from a given fully-authenticated and boxed stream. + #[inline] + pub fn new_boxed(stream: Box, params: ConnectionParams) -> Self { + Self { + stream: Box::into_pin(stream), + params, + } + } + + /// Attempt to connect to a Postgres server using a given connector and SSL requirement. + pub async fn connect( + credentials: Credentials, + ssl_mode: ConnectionSslRequirement, + connector: Connector, + ) -> Result { + let mut state = ConnectionState::new(credentials, ssl_mode); + let mut stream = connector.connect().await?; + + let mut update = ConnectionDriver::new(); + update + .drive(&mut state, ConnectionDrive::Initial, &mut stream) + .await?; + + let mut struct_buffer: StructBuffer> = + StructBuffer::>::default(); + + while !state.is_ready() { + let mut buffer = [0; 1024]; + let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await?; + if n == 0 { + Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; + } + if tracing::enabled!(Level::TRACE) { + trace!("Read:"); + let bytes: &[u8] = &buffer[..n]; + for s in hexdump::hexdump_iter(bytes) { + trace!("{}", s); + } + } + if state.read_ssl_response() { + let ssl_response = SSLResponse::new(&buffer)?; + update + .drive( + &mut state, + ConnectionDrive::SslResponse(ssl_response), + &mut stream, + ) + .await?; + continue; + } + + update + .drive_bytes(&mut state, &buffer[..n], &mut struct_buffer, &mut stream) + .await?; + } + + // This should not be possible -- we've fully upgraded the stream by now + let Ok(stream) = stream.into_boxed() else { + return Err(invalid_state!("Connection was not ready")); + }; + + Ok(RawClient::new_boxed(stream, update.params)) + } + + /// Consume the `RawClient` and return the underlying stream and connection parameters. + #[inline] + pub fn into_parts(self) -> (Pin>, ConnectionParams) { + (self.stream, self.params) + } +} From 0c4b08d93b9435313de094ab7eb934f8fb222f9b Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 5 Jun 2025 14:54:56 -0600 Subject: [PATCH 2/2] . --- gel-pg-protocol/src/conn/flow.rs | 2 +- gel-pg-protocol/src/lib.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gel-pg-protocol/src/conn/flow.rs b/gel-pg-protocol/src/conn/flow.rs index df255363..8ef0ccdd 100644 --- a/gel-pg-protocol/src/conn/flow.rs +++ b/gel-pg-protocol/src/conn/flow.rs @@ -18,7 +18,7 @@ use std::{cell::RefCell, num::NonZeroU32, rc::Rc}; use gel_db_protocol::{match_message, Encoded}; -use gel_pg_protocol::protocol::*; +use crate::protocol::*; #[derive(Debug, Clone, Copy)] pub enum Param<'a> { diff --git a/gel-pg-protocol/src/lib.rs b/gel-pg-protocol/src/lib.rs index 89942246..5e802b70 100644 --- a/gel-pg-protocol/src/lib.rs +++ b/gel-pg-protocol/src/lib.rs @@ -1,4 +1,5 @@ pub mod errors; pub mod protocol; +pub mod conn; pub use gel_db_protocol::prelude;