From 8686ee83567cee6c4322cd5d415113ce09b60305 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Wed, 11 Jun 2025 21:31:31 -0400 Subject: [PATCH 1/8] feat(transactor): add websocket backend --- Cargo.toml | 16 +- src/lib.rs | 3 + src/services/mod.rs | 1 + src/services/rpc/mod.rs | 105 ++++++++ src/services/rpc/util.rs | 42 +++ src/services/transactor/backend/http.rs | 117 ++++++++ src/services/transactor/backend/mod.rs | 26 ++ src/services/transactor/backend/ws.rs | 343 ++++++++++++++++++++++++ src/services/transactor/document.rs | 64 +++-- src/services/transactor/methods.rs | 30 +++ src/services/transactor/mod.rs | 115 ++++++-- src/services/transactor/person.rs | 14 +- 12 files changed, 808 insertions(+), 68 deletions(-) create mode 100644 src/services/rpc/mod.rs create mode 100644 src/services/rpc/util.rs create mode 100644 src/services/transactor/backend/http.rs create mode 100644 src/services/transactor/backend/mod.rs create mode 100644 src/services/transactor/backend/ws.rs create mode 100644 src/services/transactor/methods.rs diff --git a/Cargo.toml b/Cargo.toml index db50f63..e32c0cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,8 @@ reqwest = { version = "0.12.15", default-features = false, features = [ "json", "rustls-tls", ] } -reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] } -reqwest-retry = { version = "0.7.0" } -reqwest-ratelimit = "0.4.1" governor = { version = "0.10.0", features = ["std"] } +reqwest-websocket = { version = "0.5.0", features = ["json"] } serde = "1.0.219" serde_json = "1.0.140" thiserror = "2.0.12" @@ -27,6 +25,8 @@ config = "0.15.11" secrecy = { version = "0.10.3", features = ["serde"] } serde_with = "3.12.0" rand = "0.9.1" +futures = "0.3.31" +tokio_with_wasm = { version = "0.8.6", features = ["rt", "sync", "macros"] } actix-web = { version = "4.10.2", optional = true, features = ["rustls"] } rdkafka = { version = "0.38.0", optional = true, features = [ @@ -38,6 +38,14 @@ num-traits = "0.2.19" itoa = "1.0.15" ryu = "1.0.20" +# Middleware +reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] } +reqwest-retry = { version = "0.7.0", optional = true } +reqwest-ratelimit = { version = "0.4.1", optional = true } + +[target.'cfg(target_family = "wasm")'.dependencies] +wasmtimer = { version = "0.4.1" } + [dev-dependencies] anyhow = "1.0.98" tokio = { version = "1", features = ["full"] } @@ -46,4 +54,4 @@ tokio = { version = "1", features = ["full"] } default = ["reqwest_middleware"] actix = ["dep:actix-web"] kafka = ["dep:rdkafka"] -reqwest_middleware = [] +reqwest_middleware = ["dep:reqwest-retry", "dep:reqwest-ratelimit"] diff --git a/src/lib.rs b/src/lib.rs index 660a98b..19f4952 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,9 @@ pub enum Error { #[error(transparent)] Reqwest(#[from] reqwest::Error), + #[error(transparent)] + Ws(#[from] reqwest_websocket::Error), + #[error(transparent)] ReqwestMiddleware(#[from] reqwest_middleware::Error), diff --git a/src/services/mod.rs b/src/services/mod.rs index 648b63f..a71a989 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -17,6 +17,7 @@ pub mod account; pub mod core; pub mod jwt; pub mod kvs; +mod rpc; pub mod transactor; pub use reqwest_middleware::{ClientWithMiddleware as HttpClient, RequestBuilder}; diff --git a/src/services/rpc/mod.rs b/src/services/rpc/mod.rs new file mode 100644 index 0000000..d03b8a4 --- /dev/null +++ b/src/services/rpc/mod.rs @@ -0,0 +1,105 @@ +pub mod util; + +use crate::services::Status; +use crate::services::core::Account; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[serde(untagged, rename_all = "camelCase")] +pub enum ReqId { + Str(String), + Num(i32), +} + +impl From for ReqId { + fn from(s: String) -> Self { + ReqId::Str(s) + } +} + +impl From for ReqId { + fn from(i: i32) -> Self { + ReqId::Num(i) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct RateLimitInfo { + pub remaining: u32, + pub limit: u32, + pub current: u32, + pub reset: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub retry_after: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Chunk { + pub index: u32, + pub r#final: bool, +} + +#[derive(Serialize, Deserialize, Default, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Response { + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub terminate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bfst: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub queue: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Request

{ + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub method: String, + pub params: Vec

, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct HelloRequest { + #[serde(flatten)] + pub request: Request<()>, + #[serde(skip_serializing_if = "Option::is_none")] + pub binary: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub compression: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct HelloResponse { + #[serde(flatten)] + pub response: Response, + pub binary: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub reconnect: Option, + pub server_version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub last_tx: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub last_hash: Option, + pub account: Account, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_compression: Option, +} diff --git a/src/services/rpc/util.rs b/src/services/rpc/util.rs new file mode 100644 index 0000000..3cdd852 --- /dev/null +++ b/src/services/rpc/util.rs @@ -0,0 +1,42 @@ +use crate::services::Status; +use crate::services::rpc::{Chunk, RateLimitInfo, ReqId, Response}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Default, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct OkResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub terminate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bfst: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub queue: Option, +} + +impl Response { + pub fn into_result(self) -> Result, Status> { + match self.error { + Some(e) => Err(e), + None => Ok(OkResponse { + result: self.result, + id: self.id, + terminate: self.terminate, + rate_limit: self.rate_limit, + chunk: self.chunk, + time: self.time, + bfst: self.bfst, + queue: self.queue, + }), + } + } +} diff --git a/src/services/transactor/backend/http.rs b/src/services/transactor/backend/http.rs new file mode 100644 index 0000000..5931bc0 --- /dev/null +++ b/src/services/transactor/backend/http.rs @@ -0,0 +1,117 @@ +use crate::Result; +use crate::services::core::WorkspaceUuid; +use crate::services::transactor::methods::Method; +use crate::services::{JsonClient, TokenProvider}; +use reqwest_middleware::ClientWithMiddleware; +use secrecy::{ExposeSecret, SecretString}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::sync::Arc; +use url::Url; +use crate::services::transactor::backend::Backend; + +pub type HttpClient = ClientWithMiddleware; + +struct HttpBackendInner { + workspace: WorkspaceUuid, + base: Url, + client: HttpClient, + token: SecretString, +} + +#[derive(Clone)] +pub struct HttpBackend { + inner: Arc, +} + +impl HttpBackend { + pub fn new( + client: HttpClient, + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + ) -> Self { + Self { + inner: Arc::new(HttpBackendInner { + workspace, + base, + client, + token: token.into(), + }), + } + } + + pub(crate) async fn post_path( + &self, + path: &str, + body: &Q, + ) -> Result { + let url = self.base().join(path)?; + ::post(&self.inner.client, self, url, body).await + } +} + +impl JsonClient for HttpBackend { + fn get( + &self, + user: &U, + url: Url, + ) -> impl Future> { + JsonClient::get(&self.inner.client, user, url) + } + + fn post( + &self, + user: &U, + url: Url, + body: &Q, + ) -> impl Future> { + JsonClient::post(&self.inner.client, user, url, body) + } +} + +impl TokenProvider for HttpBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + +impl TokenProvider for &'_ HttpBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + +impl super::Backend for HttpBackend { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result { + let mut url = self.base().join(&format!("/api/v1/{}", method.kebab()))?; + let mut qp = url.query_pairs_mut(); + for (name, value) in params { + qp.append_pair(&name, &value.to_string()); + } + drop(qp); + + ::get(&self.inner.client, self, url).await + } + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + self.post_path(&format!("/api/v1/{}", method.kebab()), body).await + } + + fn base(&self) -> &Url { + &self.inner.base + } + + fn workspace(&self) -> WorkspaceUuid { + self.inner.workspace + } +} diff --git a/src/services/transactor/backend/mod.rs b/src/services/transactor/backend/mod.rs new file mode 100644 index 0000000..9680a31 --- /dev/null +++ b/src/services/transactor/backend/mod.rs @@ -0,0 +1,26 @@ +use crate::Result; +use crate::services::transactor::methods::Method; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use url::Url; + +pub mod http; +pub mod ws; + +#[allow(async_fn_in_trait)] +pub trait Backend { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result; + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result; + + fn base(&self) -> &Url; +} diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs new file mode 100644 index 0000000..b7ddccb --- /dev/null +++ b/src/services/transactor/backend/ws.rs @@ -0,0 +1,343 @@ +use crate::services::Status; +use crate::services::rpc::util::OkResponse; +use crate::services::rpc::{HelloRequest, HelloResponse, ReqId, Request, Response}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::{Error, Result}; +use bytes::Bytes; +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use reqwest::Client; +use reqwest_websocket::{Message, RequestBuilderExt, WebSocket}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::time::Duration; +use tokio::sync::mpsc::{self, UnboundedSender}; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tokio_with_wasm::alias as tokio; +use tracing::{error, trace, warn}; +use url::Url; +#[cfg(target_family = "wasm")] +pub use wasmtimer::{std::Instant, tokio::sleep, tokio::timeout}; +#[cfg(not(target_family = "wasm"))] +use {std::time::Instant, tokio::time::sleep, tokio::time::timeout}; + +const PONG: &str = "pong!"; + +enum Command { + Call { + payload: Value, + reply_tx: oneshot::Sender, Status>>, + }, + Close, +} + +async fn socket_task( + mut write: SplitSink, + mut read: SplitStream, + mut cmd_rx: mpsc::UnboundedReceiver, + opts: WsBackendOpts, + hello_tx: oneshot::Sender>, +) -> Result<()> { + let mut pending = + HashMap::, Status>>>::new(); + let mut binary_mode = opts.binary; + let mut use_compression = opts.compression; + let next_id = AtomicI32::new(1); + + let hello = HelloRequest { + request: Request { + id: Some(ReqId::Num(-1)), + method: Method::Hello.camel().to_string(), + params: Vec::new(), + time: None, + }, + binary: Some(binary_mode), + compression: Some(use_compression), + }; + trace!(target: "ws", ?hello, "sending HELLO"); + write.send(encode_message(&hello, binary_mode)?).await?; + + let mut hello_tx = Some(hello_tx); + loop { + tokio::select! { + Some(cmd) = cmd_rx.recv() => match cmd { + Command::Call { mut payload, reply_tx } => { + let id = next_id.fetch_add(1, Ordering::Relaxed); + payload["id"] = Value::Number(id.into()); + + pending.insert(id.into(), reply_tx); + write.send(encode_message(&payload, binary_mode)?).await?; + } + Command::Close => break, + }, + + Some(message) = read.next() => { + trace!(target: "ws", ?message, "Got message"); + + let response: Response; + let payload: Bytes; + match message? { + Message::Text(resp) => { + // Ping responses don't follow the same structure + if resp == PONG { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + } + } else { + response = serde_json::from_str(&resp)?; + } + + payload = resp.into(); + }, + Message::Binary(resp) => { + if resp == PONG.as_bytes() { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + } + } else { + response = serde_json::from_slice(&resp)?; + } + + payload = resp; + }, + Message::Ping(payload) => { + trace!(target: "ws", ?payload, "Received ping, replying..."); + write.send(encode_message(&Method::Ping.camel(), binary_mode)?).await?; + continue; + }, + Message::Close { .. } => break, + _ => continue, + } + + if response.result.as_ref().is_some_and(|v| v == "ping") { + trace!(target: "ws", ?payload, "Received ping, replying..."); + write.send(encode_message(&Method::Ping.camel(), binary_mode)?).await?; + continue; + } + + if matches!(response.id, Some(ReqId::Num(-1))) { + if response.result.is_none() && response.error.is_some() { + let result = response.into_result(); + error!(target: "ws", ?result); + continue; + } + + if response.result.is_some_and(|result| result == "hello") { + // Just ignore any extra HELLOs + let Some(hello_tx) = hello_tx.take() else { + continue; + }; + + let hello = serde_json::from_slice::(&payload)?; + binary_mode = hello.binary; + use_compression = hello.use_compression.unwrap_or(false); + let _ = hello_tx.send(Ok(())); + continue; + } + + continue; + } + + trace!(target: "ws", ?response, "Full response"); + if let Some(id) = &response.id { + if let Some(tx) = pending.remove(id) { + let _ = tx.send(response.into_result()).ok(); + continue; + } + } + } + } + } + + Ok(()) +} + +async fn ping_task(cmd_tx: UnboundedSender) -> Result<()> { + const PING_TIMEOUT: Duration = Duration::from_secs(10); + const HANG_TIMEOUT: Duration = Duration::from_secs(60 * 5); + + let mut last_ping_response = None; + + loop { + sleep(PING_TIMEOUT).await; + + let Some(ping_response_time) = last_ping_response.take() else { + trace!(target: "ws", "Pinging server"); + + let payload = Request { + id: None, + method: Method::Ping.camel().to_string(), + params: Vec::<()>::new(), + time: None, + }; + + let _response: Value = send_and_wait(&cmd_tx, payload).await?; + last_ping_response = Some(Instant::now()); + continue; + }; + + if ping_response_time.elapsed() > HANG_TIMEOUT { + error!("No ping response from server, closing socket"); + } + + last_ping_response = None; + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct WsBackendOpts { + pub binary: bool, + pub compression: bool, + /// How long to wait for the server's HELLO response before timing out + pub hello_timeout: Duration, +} + +impl Default for WsBackendOpts { + fn default() -> Self { + Self { + binary: false, + compression: false, + hello_timeout: Duration::from_secs(10), + } + } +} + +pub struct WsBackend { + cmd_tx: UnboundedSender, + base: Url, + _handle: JoinHandle<()>, +} + +impl WsBackend { + pub(in crate::services::transactor) async fn connect( + base: Url, + token: &str, + opts: WsBackendOpts, + ) -> Result { + let url = base.join(token)?; + let resp = Client::default() + .get(url) + .bearer_auth(token) + .upgrade() + .send() + .await?; + let ws = resp.into_websocket().await?; + + let (write, read) = ws.split(); + let (hello_tx, hello_rx) = oneshot::channel(); + + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::(); + let socket_handle = async move { + if let Err(e) = socket_task(write, read, cmd_rx, opts, hello_tx).await { + warn!(target:"ws", ?e, "socket task crashed"); + } + }; + + let cmd_tx2 = cmd_tx.clone(); + let ping_handle = async move { + if let Err(e) = ping_task(cmd_tx2).await { + warn!(target:"ws", ?e, "ping task ended"); + } + }; + + let handle = tokio::task::spawn(async move { + tokio::select! { + _ = socket_handle => {}, + _ = ping_handle => {}, + } + }); + + match timeout(opts.hello_timeout, hello_rx).await { + Ok(Ok(Ok(()))) => {} + Ok(Ok(Err(e))) => return Err(e), + Err(_) => return Err(Error::Other("timed out waiting for HELLO")), + _ => return Err(Error::Other("HELLO channel closed unexpectedly")), + } + + Ok(Self { + base, + cmd_tx, + _handle: handle, + }) + } +} + +fn encode_message(value: &Q, binary_mode: bool) -> Result { + if binary_mode { + Ok(Message::Binary(serde_json::to_vec(value)?.into())) + } else { + Ok(Message::Text(serde_json::to_string(value)?)) + } +} + +impl Backend for WsBackend { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result { + let param_values = params.into_iter().map(|(_k, v)| v).collect::>(); + + let payload = Request { + id: None, + method: method.camel().to_string(), + params: param_values, + time: None, + }; + + send_and_wait(&self.cmd_tx, payload).await + } + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + let Value::Object(body_json) = serde_json::to_value(body)? else { + return Err(Error::Other("Expected a JSON object")); + }; + + let payload = Request { + id: None, + method: method.camel().to_string(), + params: body_json.values().collect(), + time: None, + }; + + send_and_wait(&self.cmd_tx, payload).await + } + + fn base(&self) -> &Url { + &self.base + } +} + +async fn send_and_wait( + cmd_tx: &UnboundedSender, + payload: Request, +) -> Result { + let payload = serde_json::to_value(&payload)?; + trace!(target: "ws", %payload, "Sending message"); + + let (reply_tx, reply_rx) = oneshot::channel(); + cmd_tx.send(Command::Call { payload, reply_tx }).ok(); + + let Ok(reply) = reply_rx.await else { + return Err(Error::Other("connection closed before reply")); + }; + + let reply = reply?; + let Some(result) = reply.result else { + return Err(Error::Other("server didn't return a result")); + }; + + serde_json::from_value(result).map_err(|e| e.into()) +} diff --git a/src/services/transactor/document.rs b/src/services/transactor/document.rs index 11589f9..85bfe1e 100644 --- a/src/services/transactor/document.rs +++ b/src/services/transactor/document.rs @@ -27,10 +27,9 @@ use super::{ use crate::services::core::ser::Data; use crate::services::core::{Account, PersonId, Ref, Timestamp}; -use crate::{ - Error, Result, - services::{HttpClient, JsonClient}, -}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::{Error, Result}; static COUNT: AtomicUsize = AtomicUsize::new(0); static RANDOM: LazyLock = LazyLock::new(|| { @@ -202,14 +201,6 @@ impl FindOptionsBuilder { } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FindResult { - pub data_type: String, - pub total: i64, - pub value: Vec, -} - pub trait DocumentClient { fn get_account(&self) -> impl Future>; @@ -228,12 +219,9 @@ pub trait DocumentClient { ) -> impl Future>>; } -impl DocumentClient for super::TransactorClient { +impl DocumentClient for super::TransactorClient { async fn get_account(&self) -> Result { - let path = format!("/api/v1/account/{}", self.workspace); - let url = self.base.join(&path)?; - - ::get(&self.http, self, url).await + self.get(Method::Account, []).await } async fn find_all( @@ -242,9 +230,6 @@ impl DocumentClient for super::TransactorClient { query: Q, options: &FindOptions, ) -> Result> { - let path = format!("/api/v1/find-all/{}", self.workspace); - let mut url = self.base.join(&path)?; - let query = json::to_value(query)?; if !query.is_object() { @@ -253,13 +238,16 @@ impl DocumentClient for super::TransactorClient { let query = query.as_object().unwrap(); - url.query_pairs_mut() - .append_pair("class", class) - .append_pair("query", &json::to_string(&query)?) - .append_pair("options", &json::to_string(&options)?); - - let mut result: FindResult = - ::get(&self.http, self, url).await?; + let mut result: FindResult = self + .get( + Method::FindAll, + [ + ("class", class.into()), + ("query", json::to_value(&query)?), + ("options", json::to_value(&options)?), + ], + ) + .await?; // TODO? /* api-client/src/rest.ts @@ -294,7 +282,6 @@ impl DocumentClient for super::TransactorClient { } let result = FindResult { - data_type: result.data_type, total: result.total, value: { let mut value = Vec::new(); @@ -305,6 +292,20 @@ impl DocumentClient for super::TransactorClient { value }, + lookup_map: match result.lookup_map { + Some(lookup_map) => { + let new_map = lookup_map + .into_iter() + .map(|(k, v)| match json::from_value(v) { + Ok(val) => Ok((k, val)), + Err(e) => Err(e.into()), + }) + .collect::>()?; + + Some(new_map) + } + None => None, + }, }; Ok(result) @@ -330,4 +331,11 @@ impl DocumentClient for super::TransactorClient { .into_iter() .next()) } + + async fn tx(&self, tx: T) -> Result + where + T: Transaction, + { + self.post(Method::Tx, &tx.transaction()).await + } } diff --git a/src/services/transactor/methods.rs b/src/services/transactor/methods.rs new file mode 100644 index 0000000..fb7d2fc --- /dev/null +++ b/src/services/transactor/methods.rs @@ -0,0 +1,30 @@ +macro_rules! api_methods { + ($($Variant:ident: $kebab:literal, $camel:literal),+ $(,)?) => { + #[derive(Debug, Clone, Copy)] + pub enum Method { $($Variant),+ } + + impl Method { + pub const fn kebab(self) -> &'static str { + match self { + $( Self::$Variant => $kebab ),+ + } + } + + pub const fn camel(self) -> &'static str { + match self { + $( Self::$Variant => $camel ),+ + } + } + } + }; +} + +api_methods!( + Account: "account", "account", + FindAll: "find-all", "findAll", + EnsurePerson: "ensure-person", "ensurePerson", + Tx: "tx", "tx", + Event: "event", "event", + Ping: "ping", "ping", + Hello: "hello", "hello", +); diff --git a/src/services/transactor/mod.rs b/src/services/transactor/mod.rs index 2774e53..7119edb 100644 --- a/src/services/transactor/mod.rs +++ b/src/services/transactor/mod.rs @@ -12,23 +12,27 @@ // limitations under the License. // +use crate::Result; +use crate::services::core::WorkspaceUuid; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::backend::http::{HttpBackend, HttpClient}; +use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; +use crate::services::transactor::methods::Method; +use crate::services::{ForceScheme, JsonClient}; use secrecy::{ExposeSecret, SecretString}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; -use tracing::*; +use subscription::SubscribedQuery; use url::Url; +pub mod backend; pub mod comm; pub mod document; +pub mod methods; pub mod person; +mod subscription; pub mod tx; -use crate::services::core::WorkspaceUuid; -use crate::{ - Result, - services::{ForceHttpScheme, HttpClient, JsonClient}, -}; - pub trait Transaction { fn to_value(self) -> crate::Result; } @@ -36,7 +40,6 @@ pub trait Transaction { pub trait TransactionValue { fn matches(&self, class: Option<&str>, domain: Option<&str>) -> bool; } - impl TransactionValue for Value { fn matches(&self, class: Option<&str>, domain: Option<&str>) -> bool { let class = class.is_none() || self["_class"].as_str() == class; @@ -47,28 +50,55 @@ impl TransactionValue for Value { } #[derive(Clone)] -pub struct TransactorClient { - pub workspace: WorkspaceUuid, - pub base: Url, - token: SecretString, - http: HttpClient, +pub struct TransactorClient { + backend: B, } -impl PartialEq for TransactorClient { +impl PartialEq for TransactorClient { fn eq(&self, other: &Self) -> bool { - self.workspace == other.workspace - && self.token.expose_secret() == other.token.expose_secret() - && self.base == other.base + self.backend.workspace() == other.backend.workspace() + && self.backend.provide_token() == other.backend.provide_token() + && self.base() == other.base() } } -impl super::TokenProvider for &TransactorClient { +impl super::TokenProvider for &TransactorClient { fn provide_token(&self) -> Option<&str> { - Some(self.token.expose_secret()) + self.backend.provide_token() + } +} + +impl TransactorClient { + pub fn base(&self) -> &Url { + self.backend.base() + } + + pub fn workspace(&self) -> WorkspaceUuid { + self.backend.workspace() + } + + pub async fn get( + &self, + method: Method, + params: impl IntoIterator + Send, + ) -> Result { + self.backend.get(method, params).await + } + + pub async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + self.backend.post(method, body).await + } + + pub(in crate::services::transactor) fn backend(&self) -> &B { + &self.backend } } -impl TransactorClient { +impl TransactorClient { pub fn new( http: HttpClient, base: Url, @@ -77,18 +107,17 @@ impl TransactorClient { ) -> Result { let base = base.force_http_scheme(); Ok(Self { - workspace, - http, - base, - token: token.into(), + backend: HttpBackend::new(http, base, workspace, token), }) } +} +impl TransactorClient { pub async fn tx_raw(&self, tx: T) -> Result { - let path = format!("/api/v1/tx/{}", self.workspace); - let url = self.base.join(&path)?; + let path = format!("/api/v1/tx/{}", self.workspace()); + let url = self.base().join(&path)?; - ::post(&self.http, self, url, &tx).await + ::post(self.backend(), &self, url, &tx).await } pub async fn tx(&self, tx: T) -> Result { @@ -96,6 +125,37 @@ impl TransactorClient { } } +impl TransactorClient { + pub async fn new_ws( + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + opts: WsBackendOpts, + ) -> Result { + let base = base.force_ws_scheme(); + let token = token.into(); + let backend = WsBackend::connect(base, workspace, token.expose_secret(), opts).await?; + + Ok(Self { backend }) + } + + pub async fn subscribe( + &self, + ) -> SubscribedQuery { + SubscribedQuery::new(self.clone()) + } +} + +impl TransactorClient { + pub async fn tx_raw(&self, tx: T) -> Result { + self.backend.post(Method::Tx, &tx).await + } + + pub async fn tx(&self, tx: T) -> Result { + self.tx_raw(tx.to_value()?).await + } +} + #[cfg(feature = "kafka")] pub mod kafka { use super::*; @@ -108,6 +168,7 @@ pub mod kafka { }; use serde_json::{self as json, Value}; use std::time::Duration; + use tracing::{debug, warn}; use uuid::Uuid; pub struct KafkaProducer { diff --git a/src/services/transactor/person.rs b/src/services/transactor/person.rs index 788a373..c64abee 100644 --- a/src/services/transactor/person.rs +++ b/src/services/transactor/person.rs @@ -16,10 +16,9 @@ use serde::{Deserialize, Serialize}; use crate::services::core::{PersonId, PersonUuid}; -use crate::{ - Result, - services::{HttpClient, JsonClient, core::SocialIdType}, -}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::{Result, services::core::SocialIdType}; #[derive(Serialize, Debug, derive_builder::Builder)] #[serde(rename_all = "camelCase")] @@ -49,11 +48,8 @@ pub trait EnsurePerson { ) -> impl Future>; } -impl EnsurePerson for super::TransactorClient { +impl EnsurePerson for super::TransactorClient { async fn ensure_person(&self, request: &EnsurePersonRequest) -> Result { - let path = format!("/api/v1/ensure-person/{}", self.workspace); - let url = self.base.join(&path)?; - - ::post(&self.http, self, url, request).await + self.post(Method::EnsurePerson, request).await } } From 4cfd14dccc31e9f4f2134358239146b192e5a296 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Tue, 8 Jul 2025 08:42:10 -0400 Subject: [PATCH 2/8] feat(ws): event subscriptions --- src/lib.rs | 3 + src/services/transactor/backend/mod.rs | 6 +- src/services/transactor/backend/ws.rs | 86 ++++++++++++--- src/services/transactor/subscription.rs | 139 ++++++++++++++++++++++++ 4 files changed, 218 insertions(+), 16 deletions(-) create mode 100644 src/services/transactor/subscription.rs diff --git a/src/lib.rs b/src/lib.rs index 19f4952..dadfb96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,9 @@ pub enum Error { #[cfg(feature = "kafka")] #[error(transparent)] Kafka(#[from] rdkafka::error::KafkaError), + + #[error("Subscription task panicked")] + SubscriptionFailed, #[error(transparent)] Url(#[from] url::ParseError), diff --git a/src/services/transactor/backend/mod.rs b/src/services/transactor/backend/mod.rs index 9680a31..2d30ee9 100644 --- a/src/services/transactor/backend/mod.rs +++ b/src/services/transactor/backend/mod.rs @@ -1,4 +1,6 @@ use crate::Result; +use crate::services::TokenProvider; +use crate::services::core::WorkspaceUuid; use crate::services::transactor::methods::Method; use serde::Serialize; use serde::de::DeserializeOwned; @@ -9,7 +11,7 @@ pub mod http; pub mod ws; #[allow(async_fn_in_trait)] -pub trait Backend { +pub trait Backend: Clone + TokenProvider { async fn get( &self, method: Method, @@ -23,4 +25,6 @@ pub trait Backend { ) -> Result; fn base(&self) -> &Url; + + fn workspace(&self) -> WorkspaceUuid; } diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs index b7ddccb..d49d0d0 100644 --- a/src/services/transactor/backend/ws.rs +++ b/src/services/transactor/backend/ws.rs @@ -1,23 +1,27 @@ -use crate::services::Status; +use crate::services::core::WorkspaceUuid; +use crate::services::core::tx::Tx; use crate::services::rpc::util::OkResponse; use crate::services::rpc::{HelloRequest, HelloResponse, ReqId, Request, Response}; use crate::services::transactor::backend::Backend; use crate::services::transactor::methods::Method; +use crate::services::{Status, TokenProvider}; use crate::{Error, Result}; use bytes::Bytes; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; use reqwest::Client; use reqwest_websocket::{Message, RequestBuilderExt, WebSocket}; +use secrecy::{ExposeSecret, SecretString}; use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; use std::collections::HashMap; use std::fmt::Debug; +use std::sync::Arc; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Duration; use tokio::sync::mpsc::{self, UnboundedSender}; -use tokio::sync::oneshot; +use tokio::sync::{broadcast, oneshot}; use tokio::task::JoinHandle; use tokio_with_wasm::alias as tokio; use tracing::{error, trace, warn}; @@ -43,6 +47,7 @@ async fn socket_task( mut cmd_rx: mpsc::UnboundedReceiver, opts: WsBackendOpts, hello_tx: oneshot::Sender>, + tx_broadcast: broadcast::Sender, ) -> Result<()> { let mut pending = HashMap::, Status>>>::new(); @@ -129,7 +134,7 @@ async fn socket_task( error!(target: "ws", ?result); continue; } - + if response.result.is_some_and(|result| result == "hello") { // Just ignore any extra HELLOs let Some(hello_tx) = hello_tx.take() else { @@ -142,7 +147,7 @@ async fn socket_task( let _ = hello_tx.send(Ok(())); continue; } - + continue; } @@ -153,6 +158,21 @@ async fn socket_task( continue; } } + + if response.id.is_none() { + if let Some(result) = response.result { + match serde_json::from_value::>(result) { + Ok(tx_array) => { + for tx in tx_array { + let _ = tx_broadcast.send(tx); + } + } + Err(e) => { + warn!(target: "ws", "Failed to deserialize transaction array: {}", e); + } + } + } + } } } } @@ -210,22 +230,34 @@ impl Default for WsBackendOpts { } } -pub struct WsBackend { +struct WsBackendInner { + workspace: WorkspaceUuid, + token: SecretString, + cmd_tx: UnboundedSender, base: Url, + tx_broadcast: broadcast::Sender, _handle: JoinHandle<()>, } +#[derive(Clone)] +pub struct WsBackend { + inner: Arc, +} + impl WsBackend { pub(in crate::services::transactor) async fn connect( base: Url, - token: &str, + workspace: WorkspaceUuid, + token: impl Into, opts: WsBackendOpts, ) -> Result { - let url = base.join(token)?; + let token = token.into(); + + let url = base.join(token.expose_secret())?; let resp = Client::default() .get(url) - .bearer_auth(token) + .bearer_auth(token.expose_secret()) .upgrade() .send() .await?; @@ -234,9 +266,14 @@ impl WsBackend { let (write, read) = ws.split(); let (hello_tx, hello_rx) = oneshot::channel(); + let (tx_broadcast, _) = broadcast::channel::(128); + + let tx_broadcast_clone = tx_broadcast.clone(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::(); let socket_handle = async move { - if let Err(e) = socket_task(write, read, cmd_rx, opts, hello_tx).await { + if let Err(e) = + socket_task(write, read, cmd_rx, opts, hello_tx, tx_broadcast_clone).await + { warn!(target:"ws", ?e, "socket task crashed"); } }; @@ -263,11 +300,20 @@ impl WsBackend { } Ok(Self { - base, - cmd_tx, - _handle: handle, + inner: Arc::new(WsBackendInner { + workspace, + base, + cmd_tx, + tx_broadcast, + _handle: handle, + token, + }), }) } + + pub(in crate::services::transactor) fn tx_stream(&self) -> broadcast::Receiver { + self.inner.tx_broadcast.subscribe() + } } fn encode_message(value: &Q, binary_mode: bool) -> Result { @@ -278,6 +324,12 @@ fn encode_message(value: &Q, binary_mode: bool) -> Result } } +impl TokenProvider for WsBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + impl Backend for WsBackend { async fn get( &self, @@ -293,7 +345,7 @@ impl Backend for WsBackend { time: None, }; - send_and_wait(&self.cmd_tx, payload).await + send_and_wait(&self.inner.cmd_tx, payload).await } async fn post( @@ -312,11 +364,15 @@ impl Backend for WsBackend { time: None, }; - send_and_wait(&self.cmd_tx, payload).await + send_and_wait(&self.inner.cmd_tx, payload).await } fn base(&self) -> &Url { - &self.base + &self.inner.base + } + + fn workspace(&self) -> WorkspaceUuid { + self.inner.workspace } } diff --git a/src/services/transactor/subscription.rs b/src/services/transactor/subscription.rs new file mode 100644 index 0000000..bc3930c --- /dev/null +++ b/src/services/transactor/subscription.rs @@ -0,0 +1,139 @@ +use crate::{Error, Result}; +use crate::services::core::FindResult; +use crate::services::core::tx::Tx; +use crate::services::transactor::TransactorClient; +use crate::services::transactor::backend::ws::WsBackend; +use crate::services::transactor::document::FindOptions; +use crate::services::transactor::methods::Method; +use futures::Stream; +use serde::Serialize; +use serde::de::DeserializeOwned; +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_with_wasm::alias::sync::broadcast::Receiver; +use tokio_with_wasm::alias::sync::broadcast::error::TryRecvError; +use tokio_with_wasm::alias::task::{self, JoinHandle}; + +enum SubscriptionState { + Initial, + Fetching(JoinHandle>>), + Draining, + Waiting, +} + +pub struct SubscribedQuery { + class: String, + query: Q, + options: FindOptions, + client: TransactorClient, + + state: SubscriptionState, + items: VecDeque, + tx_rx: Receiver, +} + +impl SubscribedQuery { + pub fn new( + client: TransactorClient, + class: &str, + query: Q, + options: FindOptions, + ) -> Self { + let tx_rx = client.backend().tx_stream(); + + Self { + client, + class: class.to_string(), + query, + options, + state: SubscriptionState::Initial, + items: VecDeque::new(), + tx_rx, + } + } +} + +impl< + Q: Serialize + Clone + Unpin + Send + Sync + 'static, + T: DeserializeOwned + Send + Unpin + 'static, +> Stream for SubscribedQuery +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + SubscriptionState::Initial => { + let client = self.client.clone(); + let class = self.class.clone(); + let query = self.query.clone(); + let options = self.options.clone(); + + let handle = task::spawn(async move { + client + .get( + Method::FindAll, + vec![ + ("class", class.into()), + ("query", serde_json::to_value(query)?), + ("options", serde_json::to_value(options)?), + ], + ) + .await + }); + + self.state = SubscriptionState::Fetching(handle); + } + SubscriptionState::Fetching(ref mut handle) => { + match Pin::new(handle).poll(cx) { + Poll::Ready(Ok(Ok(find_result))) => { + self.items = find_result.value.into(); + self.state = SubscriptionState::Draining; + continue; + } + Poll::Ready(Ok(Err(e))) => { + self.state = SubscriptionState::Waiting; + return Poll::Ready(Some(Err(e))); + } + // Task panic + Poll::Ready(Err(_join_err)) => { + self.state = SubscriptionState::Waiting; + return Poll::Ready(Some(Err(Error::SubscriptionFailed))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + SubscriptionState::Draining => { + let Some(item) = self.items.pop_front() else { + self.state = SubscriptionState::Waiting; + continue; + }; + + return Poll::Ready(Some(Ok(item))); + } + SubscriptionState::Waiting => match self.tx_rx.try_recv() { + Ok(tx) => { + if tx.parent.obj.class != self.class { + continue; + } + + self.state = SubscriptionState::Initial; + } + Err(TryRecvError::Lagged(_)) => { + self.state = SubscriptionState::Initial; + continue; + } + Err(TryRecvError::Closed) => { + return Poll::Ready(None); + } + Err(TryRecvError::Empty) => { + return Poll::Pending; + } + }, + } + } + } +} From b563af1e04877aa9ccad8a0da19127207ab9faa0 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Tue, 8 Jul 2025 09:48:30 -0400 Subject: [PATCH 3/8] fix: errors from rebase --- src/lib.rs | 2 +- src/services/mod.rs | 95 ++++++++++++++++++++----- src/services/transactor/backend/mod.rs | 4 +- src/services/transactor/backend/ws.rs | 19 +++-- src/services/transactor/comm/mod.rs | 15 ++-- src/services/transactor/document.rs | 15 ++-- src/services/transactor/subscription.rs | 21 +++--- src/services/transactor/tx.rs | 68 ++++++++++++++++-- 8 files changed, 177 insertions(+), 62 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dadfb96..85af7ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ pub enum Error { #[cfg(feature = "kafka")] #[error(transparent)] Kafka(#[from] rdkafka::error::KafkaError), - + #[error("Subscription task panicked")] SubscriptionFailed, diff --git a/src/services/mod.rs b/src/services/mod.rs index a71a989..97b307d 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -28,20 +28,28 @@ use std::time::Duration; use reqwest::{self, Response, Url}; use reqwest::{StatusCode, header::HeaderValue}; use reqwest_middleware::ClientBuilder; -use reqwest_retry::{ - RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure, - policies::ExponentialBackoff, -}; use secrecy::SecretString; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{self as json, Value}; use tracing::*; use crate::services::core::{AccountUuid, WorkspaceUuid}; +use crate::services::transactor::backend::http::HttpBackend; +use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; +use crate::{Error, Result, config::Config}; +use account::AccountClient; +use jwt::Claims; +use kvs::KvsClient; +use transactor::TransactorClient; + #[cfg(feature = "kafka")] use crate::services::transactor::kafka; -use crate::{Error, Result, config::Config}; -use {account::AccountClient, jwt::Claims, kvs::KvsClient, transactor::TransactorClient}; + +#[cfg(feature = "reqwest_middleware")] +use reqwest_retry::{ + RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure, + policies::ExponentialBackoff, +}; pub trait RequestBuilderExt { fn send_ext(self) -> impl Future>; @@ -55,11 +63,12 @@ pub trait BasePathProvider { fn provide_base_path(&self) -> &Url; } -pub trait ForceHttpScheme { +pub trait ForceScheme { fn force_http_scheme(self) -> Url; + fn force_ws_scheme(self) -> Url; } -impl ForceHttpScheme for Url { +impl ForceScheme for Url { fn force_http_scheme(mut self) -> Url { match self.scheme() { "ws" => { @@ -75,6 +84,24 @@ impl ForceHttpScheme for Url { self } + + fn force_ws_scheme(mut self) -> Url { + match self.scheme() { + "http" => { + self.set_scheme("ws").unwrap(); + } + + "https" => { + self.set_scheme("wss").unwrap(); + } + + "ws" | "wss" => {} + + _ => panic!(), + }; + + self + } } impl RequestBuilderExt for RequestBuilder { @@ -117,13 +144,13 @@ fn from_value(value: Value) -> Result { pub trait JsonClient { fn get( &self, - user: U, + user: &U, url: Url, ) -> impl Future>; fn post( &self, - user: U, + user: &U, url: Url, body: &Q, ) -> impl Future>; @@ -135,7 +162,7 @@ impl JsonClient for HttpClient { skip(self, user, url), fields(%url, method = "get", type = "json") )] - async fn get(&self, user: U, url: Url) -> Result { + async fn get(&self, user: &U, url: Url) -> Result { trace!("request"); let mut request = self.get(url.clone()); @@ -149,7 +176,7 @@ impl JsonClient for HttpClient { async fn post( &self, - user: U, + user: &U, url: Url, body: &Q, ) -> Result { @@ -171,7 +198,7 @@ impl JsonClient for HttpClient { } } -#[derive(Deserialize, Debug, Clone, strum::Display)] +#[derive(Serialize, Deserialize, Debug, Clone, strum::Display)] #[serde(rename_all = "UPPERCASE")] pub enum Severity { Ok, @@ -180,11 +207,11 @@ pub enum Severity { Error, } -#[derive(Deserialize, Debug, Clone, thiserror::Error)] +#[derive(Serialize, Deserialize, Debug, Clone, thiserror::Error)] pub struct Status { pub severity: Severity, pub code: String, - pub params: HashMap, + pub params: HashMap, } impl std::fmt::Display for Status { @@ -432,7 +459,11 @@ impl ServiceFactory { ) } - pub fn new_transactor_client(&self, base: Url, claims: &Claims) -> Result { + pub fn new_transactor_client( + &self, + base: Url, + claims: &Claims, + ) -> Result> { TransactorClient::new( self.transactor_http.clone(), base, @@ -446,15 +477,45 @@ impl ServiceFactory { ) } + pub async fn new_transactor_client_ws( + &self, + base: Url, + claims: &Claims, + opts: WsBackendOpts, + ) -> Result> { + TransactorClient::new_ws( + base, + claims.workspace()?, + claims.encode( + self.config + .token_secret + .as_ref() + .ok_or(Error::Other("NoSecret"))?, + )?, + opts, + ) + .await + } + pub fn new_transactor_client_from_token( &self, base: Url, workspace: WorkspaceUuid, token: impl Into, - ) -> Result { + ) -> Result> { TransactorClient::new(self.transactor_http.clone(), base, workspace, token) } + pub async fn new_transactor_client_ws_from_token( + &self, + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + opts: WsBackendOpts, + ) -> Result> { + TransactorClient::new_ws(base, workspace, token, opts).await + } + #[cfg(feature = "kafka")] pub fn new_kafka_publisher(&self, topic: &str) -> Result { kafka::KafkaProducer::new(&self.config, topic) diff --git a/src/services/transactor/backend/mod.rs b/src/services/transactor/backend/mod.rs index 2d30ee9..19ae68d 100644 --- a/src/services/transactor/backend/mod.rs +++ b/src/services/transactor/backend/mod.rs @@ -11,11 +11,11 @@ pub mod http; pub mod ws; #[allow(async_fn_in_trait)] -pub trait Backend: Clone + TokenProvider { +pub trait Backend: Clone + TokenProvider + 'static { async fn get( &self, method: Method, - params: impl IntoIterator, + params: impl IntoIterator, ) -> Result; async fn post( diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs index d49d0d0..546d93a 100644 --- a/src/services/transactor/backend/ws.rs +++ b/src/services/transactor/backend/ws.rs @@ -1,9 +1,9 @@ use crate::services::core::WorkspaceUuid; -use crate::services::core::tx::Tx; use crate::services::rpc::util::OkResponse; use crate::services::rpc::{HelloRequest, HelloResponse, ReqId, Request, Response}; use crate::services::transactor::backend::Backend; use crate::services::transactor::methods::Method; +use crate::services::transactor::tx::Tx; use crate::services::{Status, TokenProvider}; use crate::{Error, Result}; use bytes::Bytes; @@ -38,6 +38,8 @@ enum Command { payload: Value, reply_tx: oneshot::Sender, Status>>, }, + // TODO: Manual close + #[allow(dead_code)] Close, } @@ -143,7 +145,13 @@ async fn socket_task( let hello = serde_json::from_slice::(&payload)?; binary_mode = hello.binary; - use_compression = hello.use_compression.unwrap_or(false); + + // TODO: compression support + #[allow(unused_assignments)] + { + use_compression = hello.use_compression.unwrap_or(false); + } + let _ = hello_tx.send(Ok(())); continue; } @@ -152,12 +160,11 @@ async fn socket_task( } trace!(target: "ws", ?response, "Full response"); - if let Some(id) = &response.id { - if let Some(tx) = pending.remove(id) { + if let Some(id) = &response.id + && let Some(tx) = pending.remove(id) { let _ = tx.send(response.into_result()).ok(); continue; } - } if response.id.is_none() { if let Some(result) = response.result { @@ -334,7 +341,7 @@ impl Backend for WsBackend { async fn get( &self, method: Method, - params: impl IntoIterator, + params: impl IntoIterator, ) -> Result { let param_values = params.into_iter().map(|(_k, v)| v).collect::>(); diff --git a/src/services/transactor/comm/mod.rs b/src/services/transactor/comm/mod.rs index 66f94a9..7242a41 100644 --- a/src/services/transactor/comm/mod.rs +++ b/src/services/transactor/comm/mod.rs @@ -16,13 +16,12 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{self as json, Value}; +use crate::Result; use super::tx::{Doc, Obj, Tx, TxDomainEvent}; use crate::services::core::Ref; +use crate::services::JsonClient; +use crate::services::transactor::backend::http::HttpBackend; use crate::services::transactor::document::generate_object_id; -use crate::{ - Result, - services::{HttpClient, JsonClient}, -}; mod message; pub use message::*; @@ -132,11 +131,11 @@ pub trait EventClient { } } -impl EventClient for super::TransactorClient { +impl EventClient for super::TransactorClient { async fn request_raw(&self, envelope: &T) -> Result { - let path = format!("/api/v1/event/{}", self.workspace); - let url = self.base.join(&path)?; + let path = format!("/api/v1/event/{}", self.workspace()); + let url = self.base().join(&path)?; - ::post(&self.http, self, url, envelope).await + ::post(self.backend(), &self, url, envelope).await } } diff --git a/src/services/transactor/document.rs b/src/services/transactor/document.rs index 85bfe1e..f83049b 100644 --- a/src/services/transactor/document.rs +++ b/src/services/transactor/document.rs @@ -26,7 +26,7 @@ use super::{ }; use crate::services::core::ser::Data; -use crate::services::core::{Account, PersonId, Ref, Timestamp}; +use crate::services::core::{Account, FindResult, PersonId, Ref, Timestamp}; use crate::services::transactor::backend::Backend; use crate::services::transactor::methods::Method; use crate::{Error, Result}; @@ -242,9 +242,9 @@ impl DocumentClient for super::TransactorClient { .get( Method::FindAll, [ - ("class", class.into()), - ("query", json::to_value(&query)?), - ("options", json::to_value(&options)?), + (String::from("class"), class.into()), + (String::from("query"), json::to_value(query)?), + (String::from("options"), json::to_value(options)?), ], ) .await?; @@ -331,11 +331,4 @@ impl DocumentClient for super::TransactorClient { .into_iter() .next()) } - - async fn tx(&self, tx: T) -> Result - where - T: Transaction, - { - self.post(Method::Tx, &tx.transaction()).await - } } diff --git a/src/services/transactor/subscription.rs b/src/services/transactor/subscription.rs index bc3930c..4ec941a 100644 --- a/src/services/transactor/subscription.rs +++ b/src/services/transactor/subscription.rs @@ -1,13 +1,14 @@ -use crate::{Error, Result}; use crate::services::core::FindResult; -use crate::services::core::tx::Tx; use crate::services::transactor::TransactorClient; use crate::services::transactor::backend::ws::WsBackend; use crate::services::transactor::document::FindOptions; use crate::services::transactor::methods::Method; +use crate::services::transactor::tx::Tx; +use crate::{Error, Result}; use futures::Stream; use serde::Serialize; use serde::de::DeserializeOwned; +use serde_json::Value; use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll}; @@ -66,18 +67,18 @@ impl< match self.state { SubscriptionState::Initial => { let client = self.client.clone(); - let class = self.class.clone(); - let query = self.query.clone(); - let options = self.options.clone(); + let class = Value::from(self.class.clone()); + let query = serde_json::to_value(self.query.clone())?; + let options = serde_json::to_value(self.options.clone())?; let handle = task::spawn(async move { client .get( Method::FindAll, - vec![ - ("class", class.into()), - ("query", serde_json::to_value(query)?), - ("options", serde_json::to_value(options)?), + [ + (String::from("class"), class), + (String::from("query"), query), + (String::from("options"), options), ], ) .await @@ -116,7 +117,7 @@ impl< } SubscriptionState::Waiting => match self.tx_rx.try_recv() { Ok(tx) => { - if tx.parent.obj.class != self.class { + if tx.doc.obj.class != self.class { continue; } diff --git a/src/services/transactor/tx.rs b/src/services/transactor/tx.rs index 2eb24e9..ebc2e6e 100644 --- a/src/services/transactor/tx.rs +++ b/src/services/transactor/tx.rs @@ -15,15 +15,19 @@ use crate::services::core::ser::Data; use crate::services::core::{PersonId, Ref, Timestamp}; -use serde::{Deserialize, Serialize}; +use crate::services::event::{Class, Event}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::Value; +use std::fmt::Debug; -#[derive(Deserialize, Serialize, Debug, Clone)] +#[derive(Deserialize, Serialize, Debug, Clone, Default)] pub struct Obj { #[serde(rename = "_class")] pub class: Ref, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct Doc { #[serde(flatten)] @@ -48,7 +52,7 @@ pub struct Doc { pub created_on: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct Tx { #[serde(flatten)] @@ -56,7 +60,7 @@ pub struct Tx { pub object_space: Ref, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct TxCUD { #[serde(flatten)] @@ -74,7 +78,7 @@ pub struct TxCUD { pub collection: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "camelCase")] pub struct TxCreateDoc { #[serde(flatten)] @@ -82,6 +86,37 @@ pub struct TxCreateDoc { pub attributes: Data, } +impl Class for TxCreateDoc { + const CLASS: &'static str = crate::services::core::class::TxCreateDoc; +} + +impl Event for TxCreateDoc { + fn matches(value: &Value) -> bool { + if value.get("_class").and_then(|v| v.as_str()) != Some(Self::CLASS) { + return false; + } + value.get("objectClass").and_then(|v| v.as_str()) == Some(T::CLASS) + } +} + +impl<'de, T> Deserialize<'de> for TxCreateDoc +where + T: Serialize + DeserializeOwned, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + let txcud = serde_json::from_value(value.clone()).map_err(serde::de::Error::custom)?; + + let attributes = serde_json::from_value(value).map_err(serde::de::Error::custom)?; + + Ok(TxCreateDoc { txcud, attributes }) + } +} + #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "camelCase")] pub struct TxRemoveDoc { @@ -89,13 +124,32 @@ pub struct TxRemoveDoc { pub txcud: TxCUD, } +impl Class for TxRemoveDoc { + const CLASS: &'static str = crate::services::core::class::TxRemoveDoc; +} + +impl Event for TxRemoveDoc {} + pub type OperationDomain = String; #[derive(Serialize, Deserialize, Debug)] -pub struct TxDomainEvent { +pub struct TxDomainEvent { #[serde(flatten)] pub tx: Tx, pub domain: OperationDomain, pub event: T, } + +impl Class for TxDomainEvent { + const CLASS: &'static str = crate::services::core::class::TxDomainEvent; +} + +impl Event for TxDomainEvent { + fn matches(value: &Value) -> bool { + if value.get("_class").and_then(|v| v.as_str()) != Some(Self::CLASS) { + return false; + } + value.get("objectClass").and_then(|v| v.as_str()) == Some(T::CLASS) + } +} From 4e3eb342909afd0ae480d4056e0462cdfb0c348b Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Wed, 16 Jul 2025 19:12:48 -0400 Subject: [PATCH 4/8] refactor: simplify event subscriptions --- Cargo.toml | 1 + src/lib.rs | 2 + src/services/event.rs | 11 ++ src/services/mod.rs | 1 + src/services/transactor/backend/http.rs | 5 +- src/services/transactor/backend/ws.rs | 31 +++--- src/services/transactor/mod.rs | 2 +- src/services/transactor/subscription.rs | 134 +++++------------------- src/services/transactor/tx.rs | 3 +- 9 files changed, 59 insertions(+), 131 deletions(-) create mode 100644 src/services/event.rs diff --git a/Cargo.toml b/Cargo.toml index e32c0cf..d89290e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ serde_with = "3.12.0" rand = "0.9.1" futures = "0.3.31" tokio_with_wasm = { version = "0.8.6", features = ["rt", "sync", "macros"] } +tokio-stream = { version = "0.1.17", features = ["sync"] } actix-web = { version = "4.10.2", optional = true, features = ["rustls"] } rdkafka = { version = "0.38.0", optional = true, features = [ diff --git a/src/lib.rs b/src/lib.rs index 85af7ba..f3c79d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,8 @@ pub enum Error { #[error("Subscription task panicked")] SubscriptionFailed, + #[error("Subscription task lagged and was forcibly disconnected")] + SubscriptionLagged, #[error(transparent)] Url(#[from] url::ParseError), diff --git a/src/services/event.rs b/src/services/event.rs new file mode 100644 index 0000000..b89c072 --- /dev/null +++ b/src/services/event.rs @@ -0,0 +1,11 @@ +use serde_json::Value; + +pub trait Class { + const CLASS: &'static str; +} + +pub trait Event: Class { + fn matches(value: &Value) -> bool { + value.get("_class").and_then(|v| v.as_str()) == Some(Self::CLASS) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index 97b307d..65a2f88 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -15,6 +15,7 @@ pub mod account; pub mod core; +pub mod event; pub mod jwt; pub mod kvs; mod rpc; diff --git a/src/services/transactor/backend/http.rs b/src/services/transactor/backend/http.rs index 5931bc0..95d0fa1 100644 --- a/src/services/transactor/backend/http.rs +++ b/src/services/transactor/backend/http.rs @@ -1,5 +1,6 @@ use crate::Result; use crate::services::core::WorkspaceUuid; +use crate::services::transactor::backend::Backend; use crate::services::transactor::methods::Method; use crate::services::{JsonClient, TokenProvider}; use reqwest_middleware::ClientWithMiddleware; @@ -9,7 +10,6 @@ use serde::de::DeserializeOwned; use serde_json::Value; use std::sync::Arc; use url::Url; -use crate::services::transactor::backend::Backend; pub type HttpClient = ClientWithMiddleware; @@ -104,7 +104,8 @@ impl super::Backend for HttpBackend { method: Method, body: &Q, ) -> Result { - self.post_path(&format!("/api/v1/{}", method.kebab()), body).await + self.post_path(&format!("/api/v1/{}", method.kebab()), body) + .await } fn base(&self) -> &Url { diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs index 546d93a..85e508e 100644 --- a/src/services/transactor/backend/ws.rs +++ b/src/services/transactor/backend/ws.rs @@ -3,7 +3,6 @@ use crate::services::rpc::util::OkResponse; use crate::services::rpc::{HelloRequest, HelloResponse, ReqId, Request, Response}; use crate::services::transactor::backend::Backend; use crate::services::transactor::methods::Method; -use crate::services::transactor::tx::Tx; use crate::services::{Status, TokenProvider}; use crate::{Error, Result}; use bytes::Bytes; @@ -49,7 +48,7 @@ async fn socket_task( mut cmd_rx: mpsc::UnboundedReceiver, opts: WsBackendOpts, hello_tx: oneshot::Sender>, - tx_broadcast: broadcast::Sender, + tx_broadcast: broadcast::Sender, ) -> Result<()> { let mut pending = HashMap::, Status>>>::new(); @@ -166,18 +165,16 @@ async fn socket_task( continue; } - if response.id.is_none() { - if let Some(result) = response.result { - match serde_json::from_value::>(result) { - Ok(tx_array) => { - for tx in tx_array { - let _ = tx_broadcast.send(tx); - } - } - Err(e) => { - warn!(target: "ws", "Failed to deserialize transaction array: {}", e); + if let Some(result) = response.result { + match serde_json::from_value::>(result) { + Ok(tx_array) => { + for tx in tx_array { + let _ = tx_broadcast.send(tx); } } + Err(e) => { + warn!(target: "ws", "Failed to deserialize transaction array: {}", e); + } } } } @@ -243,7 +240,7 @@ struct WsBackendInner { cmd_tx: UnboundedSender, base: Url, - tx_broadcast: broadcast::Sender, + tx_broadcast: broadcast::Sender, _handle: JoinHandle<()>, } @@ -273,7 +270,7 @@ impl WsBackend { let (write, read) = ws.split(); let (hello_tx, hello_rx) = oneshot::channel(); - let (tx_broadcast, _) = broadcast::channel::(128); + let (tx_broadcast, _) = broadcast::channel::(128); let tx_broadcast_clone = tx_broadcast.clone(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::(); @@ -318,8 +315,10 @@ impl WsBackend { }) } - pub(in crate::services::transactor) fn tx_stream(&self) -> broadcast::Receiver { - self.inner.tx_broadcast.subscribe() + pub(in crate::services::transactor) fn tx_stream( + &self, + ) -> tokio_stream::wrappers::BroadcastStream { + self.inner.tx_broadcast.subscribe().into() } } diff --git a/src/services/transactor/mod.rs b/src/services/transactor/mod.rs index 7119edb..4f3cb79 100644 --- a/src/services/transactor/mod.rs +++ b/src/services/transactor/mod.rs @@ -13,12 +13,12 @@ // use crate::Result; +use crate::services::ForceScheme; use crate::services::core::WorkspaceUuid; use crate::services::transactor::backend::Backend; use crate::services::transactor::backend::http::{HttpBackend, HttpClient}; use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; use crate::services::transactor::methods::Method; -use crate::services::{ForceScheme, JsonClient}; use secrecy::{ExposeSecret, SecretString}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; diff --git a/src/services/transactor/subscription.rs b/src/services/transactor/subscription.rs index 4ec941a..34e0f78 100644 --- a/src/services/transactor/subscription.rs +++ b/src/services/transactor/subscription.rs @@ -1,139 +1,53 @@ -use crate::services::core::FindResult; +use crate::services::event::Event; use crate::services::transactor::TransactorClient; use crate::services::transactor::backend::ws::WsBackend; -use crate::services::transactor::document::FindOptions; -use crate::services::transactor::methods::Method; -use crate::services::transactor::tx::Tx; use crate::{Error, Result}; -use futures::Stream; -use serde::Serialize; +use futures::{Stream, TryStreamExt}; use serde::de::DeserializeOwned; use serde_json::Value; -use std::collections::VecDeque; +use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio_with_wasm::alias::sync::broadcast::Receiver; -use tokio_with_wasm::alias::sync::broadcast::error::TryRecvError; -use tokio_with_wasm::alias::task::{self, JoinHandle}; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; -enum SubscriptionState { - Initial, - Fetching(JoinHandle>>), - Draining, - Waiting, +pub struct SubscribedQuery { + tx_rx: BroadcastStream, + _phantom: PhantomData, } -pub struct SubscribedQuery { - class: String, - query: Q, - options: FindOptions, - client: TransactorClient, - - state: SubscriptionState, - items: VecDeque, - tx_rx: Receiver, -} - -impl SubscribedQuery { - pub fn new( - client: TransactorClient, - class: &str, - query: Q, - options: FindOptions, - ) -> Self { +impl SubscribedQuery { + pub fn new(client: TransactorClient) -> Self { let tx_rx = client.backend().tx_stream(); Self { - client, - class: class.to_string(), - query, - options, - state: SubscriptionState::Initial, - items: VecDeque::new(), tx_rx, + _phantom: PhantomData, } } } -impl< - Q: Serialize + Clone + Unpin + Send + Sync + 'static, - T: DeserializeOwned + Send + Unpin + 'static, -> Stream for SubscribedQuery -{ +impl Stream for SubscribedQuery { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - match self.state { - SubscriptionState::Initial => { - let client = self.client.clone(); - let class = Value::from(self.class.clone()); - let query = serde_json::to_value(self.query.clone())?; - let options = serde_json::to_value(self.options.clone())?; - - let handle = task::spawn(async move { - client - .get( - Method::FindAll, - [ - (String::from("class"), class), - (String::from("query"), query), - (String::from("options"), options), - ], - ) - .await - }); + match self.tx_rx.try_poll_next_unpin(cx) { + Poll::Ready(Some(Ok(value))) => { + if T::matches(&value) { + let event = serde_json::from_value(value).map_err(Error::from); + return Poll::Ready(Some(event)); + } - self.state = SubscriptionState::Fetching(handle); + continue; } - SubscriptionState::Fetching(ref mut handle) => { - match Pin::new(handle).poll(cx) { - Poll::Ready(Ok(Ok(find_result))) => { - self.items = find_result.value.into(); - self.state = SubscriptionState::Draining; - continue; - } - Poll::Ready(Ok(Err(e))) => { - self.state = SubscriptionState::Waiting; - return Poll::Ready(Some(Err(e))); - } - // Task panic - Poll::Ready(Err(_join_err)) => { - self.state = SubscriptionState::Waiting; - return Poll::Ready(Some(Err(Error::SubscriptionFailed))); - } - Poll::Pending => { - return Poll::Pending; - } - } + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => { + return Poll::Ready(Some(Err(Error::SubscriptionLagged))); } - SubscriptionState::Draining => { - let Some(item) = self.items.pop_front() else { - self.state = SubscriptionState::Waiting; - continue; - }; - - return Poll::Ready(Some(Ok(item))); + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { + return Poll::Pending; } - SubscriptionState::Waiting => match self.tx_rx.try_recv() { - Ok(tx) => { - if tx.doc.obj.class != self.class { - continue; - } - - self.state = SubscriptionState::Initial; - } - Err(TryRecvError::Lagged(_)) => { - self.state = SubscriptionState::Initial; - continue; - } - Err(TryRecvError::Closed) => { - return Poll::Ready(None); - } - Err(TryRecvError::Empty) => { - return Poll::Pending; - } - }, } } } diff --git a/src/services/transactor/tx.rs b/src/services/transactor/tx.rs index ebc2e6e..425668c 100644 --- a/src/services/transactor/tx.rs +++ b/src/services/transactor/tx.rs @@ -15,7 +15,6 @@ use crate::services::core::ser::Data; use crate::services::core::{PersonId, Ref, Timestamp}; -use crate::services::event::{Class, Event}; use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::Value; @@ -78,7 +77,7 @@ pub struct TxCUD { pub collection: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Debug)] #[serde(rename_all = "camelCase")] pub struct TxCreateDoc { #[serde(flatten)] From b0076b1ba5c5da990c92f71aaac4a40d2368b8c3 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Sat, 19 Jul 2025 09:10:42 -0400 Subject: [PATCH 5/8] chore(clippy): suppress result_large_err --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index d89290e..ba9de80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,3 +56,6 @@ default = ["reqwest_middleware"] actix = ["dep:actix-web"] kafka = ["dep:rdkafka"] reqwest_middleware = ["dep:reqwest-retry", "dep:reqwest-ratelimit"] + +[lints.clippy] +result_large_err = "allow" \ No newline at end of file From cd1b6888826a8b52548b44c3e58f3a4f618cde81 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Sat, 19 Jul 2025 09:57:37 -0400 Subject: [PATCH 6/8] fix: errors from rebase --- src/services/transactor/mod.rs | 2 +- src/services/transactor/tx.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/services/transactor/mod.rs b/src/services/transactor/mod.rs index 4f3cb79..7119edb 100644 --- a/src/services/transactor/mod.rs +++ b/src/services/transactor/mod.rs @@ -13,12 +13,12 @@ // use crate::Result; -use crate::services::ForceScheme; use crate::services::core::WorkspaceUuid; use crate::services::transactor::backend::Backend; use crate::services::transactor::backend::http::{HttpBackend, HttpClient}; use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; use crate::services::transactor::methods::Method; +use crate::services::{ForceScheme, JsonClient}; use secrecy::{ExposeSecret, SecretString}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; diff --git a/src/services/transactor/tx.rs b/src/services/transactor/tx.rs index 425668c..e282d39 100644 --- a/src/services/transactor/tx.rs +++ b/src/services/transactor/tx.rs @@ -15,6 +15,7 @@ use crate::services::core::ser::Data; use crate::services::core::{PersonId, Ref, Timestamp}; +use crate::services::event::{Class, Event}; use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::Value; From 9fef0765b845334e189154fad4e4f6ae4a688b41 Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Sat, 19 Jul 2025 11:01:40 -0400 Subject: [PATCH 7/8] chore: fmt --- src/services/transactor/comm/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/services/transactor/comm/mod.rs b/src/services/transactor/comm/mod.rs index 7242a41..c98bebd 100644 --- a/src/services/transactor/comm/mod.rs +++ b/src/services/transactor/comm/mod.rs @@ -16,10 +16,10 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{self as json, Value}; -use crate::Result; use super::tx::{Doc, Obj, Tx, TxDomainEvent}; -use crate::services::core::Ref; +use crate::Result; use crate::services::JsonClient; +use crate::services::core::Ref; use crate::services::transactor::backend::http::HttpBackend; use crate::services::transactor::document::generate_object_id; From 6f182b3aad47fdf2aa2b09319f5e8aea0aafde7c Mon Sep 17 00:00:00 2001 From: Serial <69764315+Serial-ATA@users.noreply.github.com> Date: Sat, 19 Jul 2025 11:14:24 -0400 Subject: [PATCH 8/8] wip --- src/services/transactor/backend/ws.rs | 70 +++++++++++++++++++-------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs index 85e508e..1c0fec5 100644 --- a/src/services/transactor/backend/ws.rs +++ b/src/services/transactor/backend/ws.rs @@ -30,6 +30,7 @@ pub use wasmtimer::{std::Instant, tokio::sleep, tokio::timeout}; #[cfg(not(target_family = "wasm"))] use {std::time::Instant, tokio::time::sleep, tokio::time::timeout}; +const PING: &str = "ping"; const PONG: &str = "pong!"; enum Command { @@ -37,6 +38,9 @@ enum Command { payload: Value, reply_tx: oneshot::Sender, Status>>, }, + Ping { + reply_tx: oneshot::Sender>, + }, // TODO: Manual close #[allow(dead_code)] Close, @@ -50,6 +54,7 @@ async fn socket_task( hello_tx: oneshot::Sender>, tx_broadcast: broadcast::Sender, ) -> Result<()> { + let mut pending_ping = None; let mut pending = HashMap::, Status>>>::new(); let mut binary_mode = opts.binary; @@ -79,7 +84,18 @@ async fn socket_task( pending.insert(id.into(), reply_tx); write.send(encode_message(&payload, binary_mode)?).await?; - } + }, + Command::Ping { reply_tx } => { + let payload = Request { + id: None, + method: Method::Ping.camel().to_string(), + params: Vec::<()>::new(), + time: None, + }; + + write.send(encode_message(&payload, binary_mode)?).await?; + pending_ping = Some(reply_tx); + }, Command::Close => break, }, @@ -114,18 +130,33 @@ async fn socket_task( payload = resp; }, - Message::Ping(payload) => { - trace!(target: "ws", ?payload, "Received ping, replying..."); - write.send(encode_message(&Method::Ping.camel(), binary_mode)?).await?; - continue; + Message::Ping(request) => { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + }; + + payload = request; }, Message::Close { .. } => break, _ => continue, } - if response.result.as_ref().is_some_and(|v| v == "ping") { - trace!(target: "ws", ?payload, "Received ping, replying..."); - write.send(encode_message(&Method::Ping.camel(), binary_mode)?).await?; + if response.result.as_ref().is_some_and(|v| v == PING) { + trace!(target: "ws", "Received ping, replying..."); + if binary_mode { + write.send(Message::Binary(PONG.into())).await?; + } else { + write.send(Message::Text(PONG.into())).await?; + } + continue; + } + + if response.result.as_ref().is_some_and(|v| v == PONG) { + if let Some(pong_tx) = pending_ping.take() { + let _ = pong_tx.send(Ok(())); + } + continue; } @@ -165,6 +196,7 @@ async fn socket_task( continue; } + if let Some(result) = response.result { match serde_json::from_value::>(result) { Ok(tx_array) => { @@ -196,24 +228,28 @@ async fn ping_task(cmd_tx: UnboundedSender) -> Result<()> { let Some(ping_response_time) = last_ping_response.take() else { trace!(target: "ws", "Pinging server"); - let payload = Request { - id: None, - method: Method::Ping.camel().to_string(), - params: Vec::<()>::new(), - time: None, + let (tx, rx) = oneshot::channel(); + cmd_tx.send(Command::Ping { reply_tx: tx }).ok(); + + let Ok(res) = rx.await else { + error!("Ping channel closed unexpectedly, closing socket"); + break; }; - let _response: Value = send_and_wait(&cmd_tx, payload).await?; + res?; last_ping_response = Some(Instant::now()); continue; }; if ping_response_time.elapsed() > HANG_TIMEOUT { error!("No ping response from server, closing socket"); + break; } last_ping_response = None; } + + Ok(()) } #[derive(Copy, Clone, PartialEq, Eq, Debug)] @@ -359,14 +395,10 @@ impl Backend for WsBackend { method: Method, body: &Q, ) -> Result { - let Value::Object(body_json) = serde_json::to_value(body)? else { - return Err(Error::Other("Expected a JSON object")); - }; - let payload = Request { id: None, method: method.camel().to_string(), - params: body_json.values().collect(), + params: vec![serde_json::to_value(body)?], time: None, };