diff --git a/Cargo.toml b/Cargo.toml index db50f63..ba9de80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,8 @@ reqwest = { version = "0.12.15", default-features = false, features = [ "json", "rustls-tls", ] } -reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] } -reqwest-retry = { version = "0.7.0" } -reqwest-ratelimit = "0.4.1" governor = { version = "0.10.0", features = ["std"] } +reqwest-websocket = { version = "0.5.0", features = ["json"] } serde = "1.0.219" serde_json = "1.0.140" thiserror = "2.0.12" @@ -27,6 +25,9 @@ config = "0.15.11" secrecy = { version = "0.10.3", features = ["serde"] } serde_with = "3.12.0" rand = "0.9.1" +futures = "0.3.31" +tokio_with_wasm = { version = "0.8.6", features = ["rt", "sync", "macros"] } +tokio-stream = { version = "0.1.17", features = ["sync"] } actix-web = { version = "4.10.2", optional = true, features = ["rustls"] } rdkafka = { version = "0.38.0", optional = true, features = [ @@ -38,6 +39,14 @@ num-traits = "0.2.19" itoa = "1.0.15" ryu = "1.0.20" +# Middleware +reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] } +reqwest-retry = { version = "0.7.0", optional = true } +reqwest-ratelimit = { version = "0.4.1", optional = true } + +[target.'cfg(target_family = "wasm")'.dependencies] +wasmtimer = { version = "0.4.1" } + [dev-dependencies] anyhow = "1.0.98" tokio = { version = "1", features = ["full"] } @@ -46,4 +55,7 @@ tokio = { version = "1", features = ["full"] } default = ["reqwest_middleware"] actix = ["dep:actix-web"] kafka = ["dep:rdkafka"] -reqwest_middleware = [] +reqwest_middleware = ["dep:reqwest-retry", "dep:reqwest-ratelimit"] + +[lints.clippy] +result_large_err = "allow" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 660a98b..f3c79d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,9 @@ pub enum Error { #[error(transparent)] Reqwest(#[from] reqwest::Error), + #[error(transparent)] + Ws(#[from] reqwest_websocket::Error), + #[error(transparent)] ReqwestMiddleware(#[from] reqwest_middleware::Error), @@ -39,6 +42,11 @@ pub enum Error { #[error(transparent)] Kafka(#[from] rdkafka::error::KafkaError), + #[error("Subscription task panicked")] + SubscriptionFailed, + #[error("Subscription task lagged and was forcibly disconnected")] + SubscriptionLagged, + #[error(transparent)] Url(#[from] url::ParseError), diff --git a/src/services/event.rs b/src/services/event.rs new file mode 100644 index 0000000..b89c072 --- /dev/null +++ b/src/services/event.rs @@ -0,0 +1,11 @@ +use serde_json::Value; + +pub trait Class { + const CLASS: &'static str; +} + +pub trait Event: Class { + fn matches(value: &Value) -> bool { + value.get("_class").and_then(|v| v.as_str()) == Some(Self::CLASS) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index 648b63f..65a2f88 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -15,8 +15,10 @@ pub mod account; pub mod core; +pub mod event; pub mod jwt; pub mod kvs; +mod rpc; pub mod transactor; pub use reqwest_middleware::{ClientWithMiddleware as HttpClient, RequestBuilder}; @@ -27,20 +29,28 @@ use std::time::Duration; use reqwest::{self, Response, Url}; use reqwest::{StatusCode, header::HeaderValue}; use reqwest_middleware::ClientBuilder; -use reqwest_retry::{ - RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure, - policies::ExponentialBackoff, -}; use secrecy::SecretString; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{self as json, Value}; use tracing::*; use crate::services::core::{AccountUuid, WorkspaceUuid}; +use crate::services::transactor::backend::http::HttpBackend; +use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; +use crate::{Error, Result, config::Config}; +use account::AccountClient; +use jwt::Claims; +use kvs::KvsClient; +use transactor::TransactorClient; + #[cfg(feature = "kafka")] use crate::services::transactor::kafka; -use crate::{Error, Result, config::Config}; -use {account::AccountClient, jwt::Claims, kvs::KvsClient, transactor::TransactorClient}; + +#[cfg(feature = "reqwest_middleware")] +use reqwest_retry::{ + RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure, + policies::ExponentialBackoff, +}; pub trait RequestBuilderExt { fn send_ext(self) -> impl Future>; @@ -54,11 +64,12 @@ pub trait BasePathProvider { fn provide_base_path(&self) -> &Url; } -pub trait ForceHttpScheme { +pub trait ForceScheme { fn force_http_scheme(self) -> Url; + fn force_ws_scheme(self) -> Url; } -impl ForceHttpScheme for Url { +impl ForceScheme for Url { fn force_http_scheme(mut self) -> Url { match self.scheme() { "ws" => { @@ -74,6 +85,24 @@ impl ForceHttpScheme for Url { self } + + fn force_ws_scheme(mut self) -> Url { + match self.scheme() { + "http" => { + self.set_scheme("ws").unwrap(); + } + + "https" => { + self.set_scheme("wss").unwrap(); + } + + "ws" | "wss" => {} + + _ => panic!(), + }; + + self + } } impl RequestBuilderExt for RequestBuilder { @@ -116,13 +145,13 @@ fn from_value(value: Value) -> Result { pub trait JsonClient { fn get( &self, - user: U, + user: &U, url: Url, ) -> impl Future>; fn post( &self, - user: U, + user: &U, url: Url, body: &Q, ) -> impl Future>; @@ -134,7 +163,7 @@ impl JsonClient for HttpClient { skip(self, user, url), fields(%url, method = "get", type = "json") )] - async fn get(&self, user: U, url: Url) -> Result { + async fn get(&self, user: &U, url: Url) -> Result { trace!("request"); let mut request = self.get(url.clone()); @@ -148,7 +177,7 @@ impl JsonClient for HttpClient { async fn post( &self, - user: U, + user: &U, url: Url, body: &Q, ) -> Result { @@ -170,7 +199,7 @@ impl JsonClient for HttpClient { } } -#[derive(Deserialize, Debug, Clone, strum::Display)] +#[derive(Serialize, Deserialize, Debug, Clone, strum::Display)] #[serde(rename_all = "UPPERCASE")] pub enum Severity { Ok, @@ -179,11 +208,11 @@ pub enum Severity { Error, } -#[derive(Deserialize, Debug, Clone, thiserror::Error)] +#[derive(Serialize, Deserialize, Debug, Clone, thiserror::Error)] pub struct Status { pub severity: Severity, pub code: String, - pub params: HashMap, + pub params: HashMap, } impl std::fmt::Display for Status { @@ -431,7 +460,11 @@ impl ServiceFactory { ) } - pub fn new_transactor_client(&self, base: Url, claims: &Claims) -> Result { + pub fn new_transactor_client( + &self, + base: Url, + claims: &Claims, + ) -> Result> { TransactorClient::new( self.transactor_http.clone(), base, @@ -445,15 +478,45 @@ impl ServiceFactory { ) } + pub async fn new_transactor_client_ws( + &self, + base: Url, + claims: &Claims, + opts: WsBackendOpts, + ) -> Result> { + TransactorClient::new_ws( + base, + claims.workspace()?, + claims.encode( + self.config + .token_secret + .as_ref() + .ok_or(Error::Other("NoSecret"))?, + )?, + opts, + ) + .await + } + pub fn new_transactor_client_from_token( &self, base: Url, workspace: WorkspaceUuid, token: impl Into, - ) -> Result { + ) -> Result> { TransactorClient::new(self.transactor_http.clone(), base, workspace, token) } + pub async fn new_transactor_client_ws_from_token( + &self, + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + opts: WsBackendOpts, + ) -> Result> { + TransactorClient::new_ws(base, workspace, token, opts).await + } + #[cfg(feature = "kafka")] pub fn new_kafka_publisher(&self, topic: &str) -> Result { kafka::KafkaProducer::new(&self.config, topic) diff --git a/src/services/rpc/mod.rs b/src/services/rpc/mod.rs new file mode 100644 index 0000000..d03b8a4 --- /dev/null +++ b/src/services/rpc/mod.rs @@ -0,0 +1,105 @@ +pub mod util; + +use crate::services::Status; +use crate::services::core::Account; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[serde(untagged, rename_all = "camelCase")] +pub enum ReqId { + Str(String), + Num(i32), +} + +impl From for ReqId { + fn from(s: String) -> Self { + ReqId::Str(s) + } +} + +impl From for ReqId { + fn from(i: i32) -> Self { + ReqId::Num(i) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct RateLimitInfo { + pub remaining: u32, + pub limit: u32, + pub current: u32, + pub reset: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub retry_after: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Chunk { + pub index: u32, + pub r#final: bool, +} + +#[derive(Serialize, Deserialize, Default, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Response { + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub terminate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bfst: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub queue: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Request

{ + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub method: String, + pub params: Vec

, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct HelloRequest { + #[serde(flatten)] + pub request: Request<()>, + #[serde(skip_serializing_if = "Option::is_none")] + pub binary: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub compression: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct HelloResponse { + #[serde(flatten)] + pub response: Response, + pub binary: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub reconnect: Option, + pub server_version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub last_tx: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub last_hash: Option, + pub account: Account, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_compression: Option, +} diff --git a/src/services/rpc/util.rs b/src/services/rpc/util.rs new file mode 100644 index 0000000..3cdd852 --- /dev/null +++ b/src/services/rpc/util.rs @@ -0,0 +1,42 @@ +use crate::services::Status; +use crate::services::rpc::{Chunk, RateLimitInfo, ReqId, Response}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Default, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct OkResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub terminate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bfst: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub queue: Option, +} + +impl Response { + pub fn into_result(self) -> Result, Status> { + match self.error { + Some(e) => Err(e), + None => Ok(OkResponse { + result: self.result, + id: self.id, + terminate: self.terminate, + rate_limit: self.rate_limit, + chunk: self.chunk, + time: self.time, + bfst: self.bfst, + queue: self.queue, + }), + } + } +} diff --git a/src/services/transactor/backend/http.rs b/src/services/transactor/backend/http.rs new file mode 100644 index 0000000..95d0fa1 --- /dev/null +++ b/src/services/transactor/backend/http.rs @@ -0,0 +1,118 @@ +use crate::Result; +use crate::services::core::WorkspaceUuid; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::services::{JsonClient, TokenProvider}; +use reqwest_middleware::ClientWithMiddleware; +use secrecy::{ExposeSecret, SecretString}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::sync::Arc; +use url::Url; + +pub type HttpClient = ClientWithMiddleware; + +struct HttpBackendInner { + workspace: WorkspaceUuid, + base: Url, + client: HttpClient, + token: SecretString, +} + +#[derive(Clone)] +pub struct HttpBackend { + inner: Arc, +} + +impl HttpBackend { + pub fn new( + client: HttpClient, + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + ) -> Self { + Self { + inner: Arc::new(HttpBackendInner { + workspace, + base, + client, + token: token.into(), + }), + } + } + + pub(crate) async fn post_path( + &self, + path: &str, + body: &Q, + ) -> Result { + let url = self.base().join(path)?; + ::post(&self.inner.client, self, url, body).await + } +} + +impl JsonClient for HttpBackend { + fn get( + &self, + user: &U, + url: Url, + ) -> impl Future> { + JsonClient::get(&self.inner.client, user, url) + } + + fn post( + &self, + user: &U, + url: Url, + body: &Q, + ) -> impl Future> { + JsonClient::post(&self.inner.client, user, url, body) + } +} + +impl TokenProvider for HttpBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + +impl TokenProvider for &'_ HttpBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + +impl super::Backend for HttpBackend { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result { + let mut url = self.base().join(&format!("/api/v1/{}", method.kebab()))?; + let mut qp = url.query_pairs_mut(); + for (name, value) in params { + qp.append_pair(&name, &value.to_string()); + } + drop(qp); + + ::get(&self.inner.client, self, url).await + } + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + self.post_path(&format!("/api/v1/{}", method.kebab()), body) + .await + } + + fn base(&self) -> &Url { + &self.inner.base + } + + fn workspace(&self) -> WorkspaceUuid { + self.inner.workspace + } +} diff --git a/src/services/transactor/backend/mod.rs b/src/services/transactor/backend/mod.rs new file mode 100644 index 0000000..19ae68d --- /dev/null +++ b/src/services/transactor/backend/mod.rs @@ -0,0 +1,30 @@ +use crate::Result; +use crate::services::TokenProvider; +use crate::services::core::WorkspaceUuid; +use crate::services::transactor::methods::Method; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use url::Url; + +pub mod http; +pub mod ws; + +#[allow(async_fn_in_trait)] +pub trait Backend: Clone + TokenProvider + 'static { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result; + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result; + + fn base(&self) -> &Url; + + fn workspace(&self) -> WorkspaceUuid; +} diff --git a/src/services/transactor/backend/ws.rs b/src/services/transactor/backend/ws.rs new file mode 100644 index 0000000..1c0fec5 --- /dev/null +++ b/src/services/transactor/backend/ws.rs @@ -0,0 +1,437 @@ +use crate::services::core::WorkspaceUuid; +use crate::services::rpc::util::OkResponse; +use crate::services::rpc::{HelloRequest, HelloResponse, ReqId, Request, Response}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::services::{Status, TokenProvider}; +use crate::{Error, Result}; +use bytes::Bytes; +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use reqwest::Client; +use reqwest_websocket::{Message, RequestBuilderExt, WebSocket}; +use secrecy::{ExposeSecret, SecretString}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::time::Duration; +use tokio::sync::mpsc::{self, UnboundedSender}; +use tokio::sync::{broadcast, oneshot}; +use tokio::task::JoinHandle; +use tokio_with_wasm::alias as tokio; +use tracing::{error, trace, warn}; +use url::Url; +#[cfg(target_family = "wasm")] +pub use wasmtimer::{std::Instant, tokio::sleep, tokio::timeout}; +#[cfg(not(target_family = "wasm"))] +use {std::time::Instant, tokio::time::sleep, tokio::time::timeout}; + +const PING: &str = "ping"; +const PONG: &str = "pong!"; + +enum Command { + Call { + payload: Value, + reply_tx: oneshot::Sender, Status>>, + }, + Ping { + reply_tx: oneshot::Sender>, + }, + // TODO: Manual close + #[allow(dead_code)] + Close, +} + +async fn socket_task( + mut write: SplitSink, + mut read: SplitStream, + mut cmd_rx: mpsc::UnboundedReceiver, + opts: WsBackendOpts, + hello_tx: oneshot::Sender>, + tx_broadcast: broadcast::Sender, +) -> Result<()> { + let mut pending_ping = None; + let mut pending = + HashMap::, Status>>>::new(); + let mut binary_mode = opts.binary; + let mut use_compression = opts.compression; + let next_id = AtomicI32::new(1); + + let hello = HelloRequest { + request: Request { + id: Some(ReqId::Num(-1)), + method: Method::Hello.camel().to_string(), + params: Vec::new(), + time: None, + }, + binary: Some(binary_mode), + compression: Some(use_compression), + }; + trace!(target: "ws", ?hello, "sending HELLO"); + write.send(encode_message(&hello, binary_mode)?).await?; + + let mut hello_tx = Some(hello_tx); + loop { + tokio::select! { + Some(cmd) = cmd_rx.recv() => match cmd { + Command::Call { mut payload, reply_tx } => { + let id = next_id.fetch_add(1, Ordering::Relaxed); + payload["id"] = Value::Number(id.into()); + + pending.insert(id.into(), reply_tx); + write.send(encode_message(&payload, binary_mode)?).await?; + }, + Command::Ping { reply_tx } => { + let payload = Request { + id: None, + method: Method::Ping.camel().to_string(), + params: Vec::<()>::new(), + time: None, + }; + + write.send(encode_message(&payload, binary_mode)?).await?; + pending_ping = Some(reply_tx); + }, + Command::Close => break, + }, + + Some(message) = read.next() => { + trace!(target: "ws", ?message, "Got message"); + + let response: Response; + let payload: Bytes; + match message? { + Message::Text(resp) => { + // Ping responses don't follow the same structure + if resp == PONG { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + } + } else { + response = serde_json::from_str(&resp)?; + } + + payload = resp.into(); + }, + Message::Binary(resp) => { + if resp == PONG.as_bytes() { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + } + } else { + response = serde_json::from_slice(&resp)?; + } + + payload = resp; + }, + Message::Ping(request) => { + response = Response { + result: Some(Value::String(PONG.to_string())), + ..Default::default() + }; + + payload = request; + }, + Message::Close { .. } => break, + _ => continue, + } + + if response.result.as_ref().is_some_and(|v| v == PING) { + trace!(target: "ws", "Received ping, replying..."); + if binary_mode { + write.send(Message::Binary(PONG.into())).await?; + } else { + write.send(Message::Text(PONG.into())).await?; + } + continue; + } + + if response.result.as_ref().is_some_and(|v| v == PONG) { + if let Some(pong_tx) = pending_ping.take() { + let _ = pong_tx.send(Ok(())); + } + + continue; + } + + if matches!(response.id, Some(ReqId::Num(-1))) { + if response.result.is_none() && response.error.is_some() { + let result = response.into_result(); + error!(target: "ws", ?result); + continue; + } + + if response.result.is_some_and(|result| result == "hello") { + // Just ignore any extra HELLOs + let Some(hello_tx) = hello_tx.take() else { + continue; + }; + + let hello = serde_json::from_slice::(&payload)?; + binary_mode = hello.binary; + + // TODO: compression support + #[allow(unused_assignments)] + { + use_compression = hello.use_compression.unwrap_or(false); + } + + let _ = hello_tx.send(Ok(())); + continue; + } + + continue; + } + + trace!(target: "ws", ?response, "Full response"); + if let Some(id) = &response.id + && let Some(tx) = pending.remove(id) { + let _ = tx.send(response.into_result()).ok(); + continue; + } + + + if let Some(result) = response.result { + match serde_json::from_value::>(result) { + Ok(tx_array) => { + for tx in tx_array { + let _ = tx_broadcast.send(tx); + } + } + Err(e) => { + warn!(target: "ws", "Failed to deserialize transaction array: {}", e); + } + } + } + } + } + } + + Ok(()) +} + +async fn ping_task(cmd_tx: UnboundedSender) -> Result<()> { + const PING_TIMEOUT: Duration = Duration::from_secs(10); + const HANG_TIMEOUT: Duration = Duration::from_secs(60 * 5); + + let mut last_ping_response = None; + + loop { + sleep(PING_TIMEOUT).await; + + let Some(ping_response_time) = last_ping_response.take() else { + trace!(target: "ws", "Pinging server"); + + let (tx, rx) = oneshot::channel(); + cmd_tx.send(Command::Ping { reply_tx: tx }).ok(); + + let Ok(res) = rx.await else { + error!("Ping channel closed unexpectedly, closing socket"); + break; + }; + + res?; + last_ping_response = Some(Instant::now()); + continue; + }; + + if ping_response_time.elapsed() > HANG_TIMEOUT { + error!("No ping response from server, closing socket"); + break; + } + + last_ping_response = None; + } + + Ok(()) +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct WsBackendOpts { + pub binary: bool, + pub compression: bool, + /// How long to wait for the server's HELLO response before timing out + pub hello_timeout: Duration, +} + +impl Default for WsBackendOpts { + fn default() -> Self { + Self { + binary: false, + compression: false, + hello_timeout: Duration::from_secs(10), + } + } +} + +struct WsBackendInner { + workspace: WorkspaceUuid, + token: SecretString, + + cmd_tx: UnboundedSender, + base: Url, + tx_broadcast: broadcast::Sender, + _handle: JoinHandle<()>, +} + +#[derive(Clone)] +pub struct WsBackend { + inner: Arc, +} + +impl WsBackend { + pub(in crate::services::transactor) async fn connect( + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + opts: WsBackendOpts, + ) -> Result { + let token = token.into(); + + let url = base.join(token.expose_secret())?; + let resp = Client::default() + .get(url) + .bearer_auth(token.expose_secret()) + .upgrade() + .send() + .await?; + let ws = resp.into_websocket().await?; + + let (write, read) = ws.split(); + let (hello_tx, hello_rx) = oneshot::channel(); + + let (tx_broadcast, _) = broadcast::channel::(128); + + let tx_broadcast_clone = tx_broadcast.clone(); + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::(); + let socket_handle = async move { + if let Err(e) = + socket_task(write, read, cmd_rx, opts, hello_tx, tx_broadcast_clone).await + { + warn!(target:"ws", ?e, "socket task crashed"); + } + }; + + let cmd_tx2 = cmd_tx.clone(); + let ping_handle = async move { + if let Err(e) = ping_task(cmd_tx2).await { + warn!(target:"ws", ?e, "ping task ended"); + } + }; + + let handle = tokio::task::spawn(async move { + tokio::select! { + _ = socket_handle => {}, + _ = ping_handle => {}, + } + }); + + match timeout(opts.hello_timeout, hello_rx).await { + Ok(Ok(Ok(()))) => {} + Ok(Ok(Err(e))) => return Err(e), + Err(_) => return Err(Error::Other("timed out waiting for HELLO")), + _ => return Err(Error::Other("HELLO channel closed unexpectedly")), + } + + Ok(Self { + inner: Arc::new(WsBackendInner { + workspace, + base, + cmd_tx, + tx_broadcast, + _handle: handle, + token, + }), + }) + } + + pub(in crate::services::transactor) fn tx_stream( + &self, + ) -> tokio_stream::wrappers::BroadcastStream { + self.inner.tx_broadcast.subscribe().into() + } +} + +fn encode_message(value: &Q, binary_mode: bool) -> Result { + if binary_mode { + Ok(Message::Binary(serde_json::to_vec(value)?.into())) + } else { + Ok(Message::Text(serde_json::to_string(value)?)) + } +} + +impl TokenProvider for WsBackend { + fn provide_token(&self) -> Option<&str> { + Some(self.inner.token.expose_secret()) + } +} + +impl Backend for WsBackend { + async fn get( + &self, + method: Method, + params: impl IntoIterator, + ) -> Result { + let param_values = params.into_iter().map(|(_k, v)| v).collect::>(); + + let payload = Request { + id: None, + method: method.camel().to_string(), + params: param_values, + time: None, + }; + + send_and_wait(&self.inner.cmd_tx, payload).await + } + + async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + let payload = Request { + id: None, + method: method.camel().to_string(), + params: vec![serde_json::to_value(body)?], + time: None, + }; + + send_and_wait(&self.inner.cmd_tx, payload).await + } + + fn base(&self) -> &Url { + &self.inner.base + } + + fn workspace(&self) -> WorkspaceUuid { + self.inner.workspace + } +} + +async fn send_and_wait( + cmd_tx: &UnboundedSender, + payload: Request, +) -> Result { + let payload = serde_json::to_value(&payload)?; + trace!(target: "ws", %payload, "Sending message"); + + let (reply_tx, reply_rx) = oneshot::channel(); + cmd_tx.send(Command::Call { payload, reply_tx }).ok(); + + let Ok(reply) = reply_rx.await else { + return Err(Error::Other("connection closed before reply")); + }; + + let reply = reply?; + let Some(result) = reply.result else { + return Err(Error::Other("server didn't return a result")); + }; + + serde_json::from_value(result).map_err(|e| e.into()) +} diff --git a/src/services/transactor/comm/mod.rs b/src/services/transactor/comm/mod.rs index 66f94a9..c98bebd 100644 --- a/src/services/transactor/comm/mod.rs +++ b/src/services/transactor/comm/mod.rs @@ -17,12 +17,11 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{self as json, Value}; use super::tx::{Doc, Obj, Tx, TxDomainEvent}; +use crate::Result; +use crate::services::JsonClient; use crate::services::core::Ref; +use crate::services::transactor::backend::http::HttpBackend; use crate::services::transactor::document::generate_object_id; -use crate::{ - Result, - services::{HttpClient, JsonClient}, -}; mod message; pub use message::*; @@ -132,11 +131,11 @@ pub trait EventClient { } } -impl EventClient for super::TransactorClient { +impl EventClient for super::TransactorClient { async fn request_raw(&self, envelope: &T) -> Result { - let path = format!("/api/v1/event/{}", self.workspace); - let url = self.base.join(&path)?; + let path = format!("/api/v1/event/{}", self.workspace()); + let url = self.base().join(&path)?; - ::post(&self.http, self, url, envelope).await + ::post(self.backend(), &self, url, envelope).await } } diff --git a/src/services/transactor/document.rs b/src/services/transactor/document.rs index 11589f9..f83049b 100644 --- a/src/services/transactor/document.rs +++ b/src/services/transactor/document.rs @@ -26,11 +26,10 @@ use super::{ }; use crate::services::core::ser::Data; -use crate::services::core::{Account, PersonId, Ref, Timestamp}; -use crate::{ - Error, Result, - services::{HttpClient, JsonClient}, -}; +use crate::services::core::{Account, FindResult, PersonId, Ref, Timestamp}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::{Error, Result}; static COUNT: AtomicUsize = AtomicUsize::new(0); static RANDOM: LazyLock = LazyLock::new(|| { @@ -202,14 +201,6 @@ impl FindOptionsBuilder { } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FindResult { - pub data_type: String, - pub total: i64, - pub value: Vec, -} - pub trait DocumentClient { fn get_account(&self) -> impl Future>; @@ -228,12 +219,9 @@ pub trait DocumentClient { ) -> impl Future>>; } -impl DocumentClient for super::TransactorClient { +impl DocumentClient for super::TransactorClient { async fn get_account(&self) -> Result { - let path = format!("/api/v1/account/{}", self.workspace); - let url = self.base.join(&path)?; - - ::get(&self.http, self, url).await + self.get(Method::Account, []).await } async fn find_all( @@ -242,9 +230,6 @@ impl DocumentClient for super::TransactorClient { query: Q, options: &FindOptions, ) -> Result> { - let path = format!("/api/v1/find-all/{}", self.workspace); - let mut url = self.base.join(&path)?; - let query = json::to_value(query)?; if !query.is_object() { @@ -253,13 +238,16 @@ impl DocumentClient for super::TransactorClient { let query = query.as_object().unwrap(); - url.query_pairs_mut() - .append_pair("class", class) - .append_pair("query", &json::to_string(&query)?) - .append_pair("options", &json::to_string(&options)?); - - let mut result: FindResult = - ::get(&self.http, self, url).await?; + let mut result: FindResult = self + .get( + Method::FindAll, + [ + (String::from("class"), class.into()), + (String::from("query"), json::to_value(query)?), + (String::from("options"), json::to_value(options)?), + ], + ) + .await?; // TODO? /* api-client/src/rest.ts @@ -294,7 +282,6 @@ impl DocumentClient for super::TransactorClient { } let result = FindResult { - data_type: result.data_type, total: result.total, value: { let mut value = Vec::new(); @@ -305,6 +292,20 @@ impl DocumentClient for super::TransactorClient { value }, + lookup_map: match result.lookup_map { + Some(lookup_map) => { + let new_map = lookup_map + .into_iter() + .map(|(k, v)| match json::from_value(v) { + Ok(val) => Ok((k, val)), + Err(e) => Err(e.into()), + }) + .collect::>()?; + + Some(new_map) + } + None => None, + }, }; Ok(result) diff --git a/src/services/transactor/methods.rs b/src/services/transactor/methods.rs new file mode 100644 index 0000000..fb7d2fc --- /dev/null +++ b/src/services/transactor/methods.rs @@ -0,0 +1,30 @@ +macro_rules! api_methods { + ($($Variant:ident: $kebab:literal, $camel:literal),+ $(,)?) => { + #[derive(Debug, Clone, Copy)] + pub enum Method { $($Variant),+ } + + impl Method { + pub const fn kebab(self) -> &'static str { + match self { + $( Self::$Variant => $kebab ),+ + } + } + + pub const fn camel(self) -> &'static str { + match self { + $( Self::$Variant => $camel ),+ + } + } + } + }; +} + +api_methods!( + Account: "account", "account", + FindAll: "find-all", "findAll", + EnsurePerson: "ensure-person", "ensurePerson", + Tx: "tx", "tx", + Event: "event", "event", + Ping: "ping", "ping", + Hello: "hello", "hello", +); diff --git a/src/services/transactor/mod.rs b/src/services/transactor/mod.rs index 2774e53..7119edb 100644 --- a/src/services/transactor/mod.rs +++ b/src/services/transactor/mod.rs @@ -12,23 +12,27 @@ // limitations under the License. // +use crate::Result; +use crate::services::core::WorkspaceUuid; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::backend::http::{HttpBackend, HttpClient}; +use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts}; +use crate::services::transactor::methods::Method; +use crate::services::{ForceScheme, JsonClient}; use secrecy::{ExposeSecret, SecretString}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; -use tracing::*; +use subscription::SubscribedQuery; use url::Url; +pub mod backend; pub mod comm; pub mod document; +pub mod methods; pub mod person; +mod subscription; pub mod tx; -use crate::services::core::WorkspaceUuid; -use crate::{ - Result, - services::{ForceHttpScheme, HttpClient, JsonClient}, -}; - pub trait Transaction { fn to_value(self) -> crate::Result; } @@ -36,7 +40,6 @@ pub trait Transaction { pub trait TransactionValue { fn matches(&self, class: Option<&str>, domain: Option<&str>) -> bool; } - impl TransactionValue for Value { fn matches(&self, class: Option<&str>, domain: Option<&str>) -> bool { let class = class.is_none() || self["_class"].as_str() == class; @@ -47,28 +50,55 @@ impl TransactionValue for Value { } #[derive(Clone)] -pub struct TransactorClient { - pub workspace: WorkspaceUuid, - pub base: Url, - token: SecretString, - http: HttpClient, +pub struct TransactorClient { + backend: B, } -impl PartialEq for TransactorClient { +impl PartialEq for TransactorClient { fn eq(&self, other: &Self) -> bool { - self.workspace == other.workspace - && self.token.expose_secret() == other.token.expose_secret() - && self.base == other.base + self.backend.workspace() == other.backend.workspace() + && self.backend.provide_token() == other.backend.provide_token() + && self.base() == other.base() } } -impl super::TokenProvider for &TransactorClient { +impl super::TokenProvider for &TransactorClient { fn provide_token(&self) -> Option<&str> { - Some(self.token.expose_secret()) + self.backend.provide_token() + } +} + +impl TransactorClient { + pub fn base(&self) -> &Url { + self.backend.base() + } + + pub fn workspace(&self) -> WorkspaceUuid { + self.backend.workspace() + } + + pub async fn get( + &self, + method: Method, + params: impl IntoIterator + Send, + ) -> Result { + self.backend.get(method, params).await + } + + pub async fn post( + &self, + method: Method, + body: &Q, + ) -> Result { + self.backend.post(method, body).await + } + + pub(in crate::services::transactor) fn backend(&self) -> &B { + &self.backend } } -impl TransactorClient { +impl TransactorClient { pub fn new( http: HttpClient, base: Url, @@ -77,18 +107,17 @@ impl TransactorClient { ) -> Result { let base = base.force_http_scheme(); Ok(Self { - workspace, - http, - base, - token: token.into(), + backend: HttpBackend::new(http, base, workspace, token), }) } +} +impl TransactorClient { pub async fn tx_raw(&self, tx: T) -> Result { - let path = format!("/api/v1/tx/{}", self.workspace); - let url = self.base.join(&path)?; + let path = format!("/api/v1/tx/{}", self.workspace()); + let url = self.base().join(&path)?; - ::post(&self.http, self, url, &tx).await + ::post(self.backend(), &self, url, &tx).await } pub async fn tx(&self, tx: T) -> Result { @@ -96,6 +125,37 @@ impl TransactorClient { } } +impl TransactorClient { + pub async fn new_ws( + base: Url, + workspace: WorkspaceUuid, + token: impl Into, + opts: WsBackendOpts, + ) -> Result { + let base = base.force_ws_scheme(); + let token = token.into(); + let backend = WsBackend::connect(base, workspace, token.expose_secret(), opts).await?; + + Ok(Self { backend }) + } + + pub async fn subscribe( + &self, + ) -> SubscribedQuery { + SubscribedQuery::new(self.clone()) + } +} + +impl TransactorClient { + pub async fn tx_raw(&self, tx: T) -> Result { + self.backend.post(Method::Tx, &tx).await + } + + pub async fn tx(&self, tx: T) -> Result { + self.tx_raw(tx.to_value()?).await + } +} + #[cfg(feature = "kafka")] pub mod kafka { use super::*; @@ -108,6 +168,7 @@ pub mod kafka { }; use serde_json::{self as json, Value}; use std::time::Duration; + use tracing::{debug, warn}; use uuid::Uuid; pub struct KafkaProducer { diff --git a/src/services/transactor/person.rs b/src/services/transactor/person.rs index 788a373..c64abee 100644 --- a/src/services/transactor/person.rs +++ b/src/services/transactor/person.rs @@ -16,10 +16,9 @@ use serde::{Deserialize, Serialize}; use crate::services::core::{PersonId, PersonUuid}; -use crate::{ - Result, - services::{HttpClient, JsonClient, core::SocialIdType}, -}; +use crate::services::transactor::backend::Backend; +use crate::services::transactor::methods::Method; +use crate::{Result, services::core::SocialIdType}; #[derive(Serialize, Debug, derive_builder::Builder)] #[serde(rename_all = "camelCase")] @@ -49,11 +48,8 @@ pub trait EnsurePerson { ) -> impl Future>; } -impl EnsurePerson for super::TransactorClient { +impl EnsurePerson for super::TransactorClient { async fn ensure_person(&self, request: &EnsurePersonRequest) -> Result { - let path = format!("/api/v1/ensure-person/{}", self.workspace); - let url = self.base.join(&path)?; - - ::post(&self.http, self, url, request).await + self.post(Method::EnsurePerson, request).await } } diff --git a/src/services/transactor/subscription.rs b/src/services/transactor/subscription.rs new file mode 100644 index 0000000..34e0f78 --- /dev/null +++ b/src/services/transactor/subscription.rs @@ -0,0 +1,54 @@ +use crate::services::event::Event; +use crate::services::transactor::TransactorClient; +use crate::services::transactor::backend::ws::WsBackend; +use crate::{Error, Result}; +use futures::{Stream, TryStreamExt}; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; + +pub struct SubscribedQuery { + tx_rx: BroadcastStream, + _phantom: PhantomData, +} + +impl SubscribedQuery { + pub fn new(client: TransactorClient) -> Self { + let tx_rx = client.backend().tx_stream(); + + Self { + tx_rx, + _phantom: PhantomData, + } + } +} + +impl Stream for SubscribedQuery { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.tx_rx.try_poll_next_unpin(cx) { + Poll::Ready(Some(Ok(value))) => { + if T::matches(&value) { + let event = serde_json::from_value(value).map_err(Error::from); + return Poll::Ready(Some(event)); + } + + continue; + } + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => { + return Poll::Ready(Some(Err(Error::SubscriptionLagged))); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} diff --git a/src/services/transactor/tx.rs b/src/services/transactor/tx.rs index 2eb24e9..e282d39 100644 --- a/src/services/transactor/tx.rs +++ b/src/services/transactor/tx.rs @@ -15,15 +15,19 @@ use crate::services::core::ser::Data; use crate::services::core::{PersonId, Ref, Timestamp}; -use serde::{Deserialize, Serialize}; +use crate::services::event::{Class, Event}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::Value; +use std::fmt::Debug; -#[derive(Deserialize, Serialize, Debug, Clone)] +#[derive(Deserialize, Serialize, Debug, Clone, Default)] pub struct Obj { #[serde(rename = "_class")] pub class: Ref, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct Doc { #[serde(flatten)] @@ -48,7 +52,7 @@ pub struct Doc { pub created_on: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct Tx { #[serde(flatten)] @@ -56,7 +60,7 @@ pub struct Tx { pub object_space: Ref, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct TxCUD { #[serde(flatten)] @@ -82,6 +86,37 @@ pub struct TxCreateDoc { pub attributes: Data, } +impl Class for TxCreateDoc { + const CLASS: &'static str = crate::services::core::class::TxCreateDoc; +} + +impl Event for TxCreateDoc { + fn matches(value: &Value) -> bool { + if value.get("_class").and_then(|v| v.as_str()) != Some(Self::CLASS) { + return false; + } + value.get("objectClass").and_then(|v| v.as_str()) == Some(T::CLASS) + } +} + +impl<'de, T> Deserialize<'de> for TxCreateDoc +where + T: Serialize + DeserializeOwned, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + let txcud = serde_json::from_value(value.clone()).map_err(serde::de::Error::custom)?; + + let attributes = serde_json::from_value(value).map_err(serde::de::Error::custom)?; + + Ok(TxCreateDoc { txcud, attributes }) + } +} + #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "camelCase")] pub struct TxRemoveDoc { @@ -89,13 +124,32 @@ pub struct TxRemoveDoc { pub txcud: TxCUD, } +impl Class for TxRemoveDoc { + const CLASS: &'static str = crate::services::core::class::TxRemoveDoc; +} + +impl Event for TxRemoveDoc {} + pub type OperationDomain = String; #[derive(Serialize, Deserialize, Debug)] -pub struct TxDomainEvent { +pub struct TxDomainEvent { #[serde(flatten)] pub tx: Tx, pub domain: OperationDomain, pub event: T, } + +impl Class for TxDomainEvent { + const CLASS: &'static str = crate::services::core::class::TxDomainEvent; +} + +impl Event for TxDomainEvent { + fn matches(value: &Value) -> bool { + if value.get("_class").and_then(|v| v.as_str()) != Some(Self::CLASS) { + return false; + } + value.get("objectClass").and_then(|v| v.as_str()) == Some(T::CLASS) + } +}