Skip to content

Commit 45634d9

Browse files
authored
feat: add websocket backend for TransactorClient (#3)
1 parent 982d2b7 commit 45634d9

File tree

16 files changed

+1121
-100
lines changed

16 files changed

+1121
-100
lines changed

Cargo.toml

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@ reqwest = { version = "0.12.15", default-features = false, features = [
1010
"json",
1111
"rustls-tls",
1212
] }
13-
reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] }
14-
reqwest-retry = { version = "0.7.0" }
15-
reqwest-ratelimit = "0.4.1"
1613
governor = { version = "0.10.0", features = ["std"] }
14+
reqwest-websocket = { version = "0.5.0", features = ["json"] }
1715
serde = "1.0.219"
1816
serde_json = "1.0.140"
1917
thiserror = "2.0.12"
@@ -27,6 +25,9 @@ config = "0.15.11"
2725
secrecy = { version = "0.10.3", features = ["serde"] }
2826
serde_with = "3.12.0"
2927
rand = "0.9.1"
28+
futures = "0.3.31"
29+
tokio_with_wasm = { version = "0.8.6", features = ["rt", "sync", "macros"] }
30+
tokio-stream = { version = "0.1.17", features = ["sync"] }
3031

3132
actix-web = { version = "4.10.2", optional = true, features = ["rustls"] }
3233
rdkafka = { version = "0.38.0", optional = true, features = [
@@ -38,6 +39,14 @@ num-traits = "0.2.19"
3839
itoa = "1.0.15"
3940
ryu = "1.0.20"
4041

42+
# Middleware
43+
reqwest-middleware = { version = "0.4.2", features = ["json", "rustls-tls"] }
44+
reqwest-retry = { version = "0.7.0", optional = true }
45+
reqwest-ratelimit = { version = "0.4.1", optional = true }
46+
47+
[target.'cfg(target_family = "wasm")'.dependencies]
48+
wasmtimer = { version = "0.4.1" }
49+
4150
[dev-dependencies]
4251
anyhow = "1.0.98"
4352
tokio = { version = "1", features = ["full"] }
@@ -46,4 +55,7 @@ tokio = { version = "1", features = ["full"] }
4655
default = ["reqwest_middleware"]
4756
actix = ["dep:actix-web"]
4857
kafka = ["dep:rdkafka"]
49-
reqwest_middleware = []
58+
reqwest_middleware = ["dep:reqwest-retry", "dep:reqwest-ratelimit"]
59+
60+
[lints.clippy]
61+
result_large_err = "allow"

src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,21 @@ pub enum Error {
3232
#[error(transparent)]
3333
Reqwest(#[from] reqwest::Error),
3434

35+
#[error(transparent)]
36+
Ws(#[from] reqwest_websocket::Error),
37+
3538
#[error(transparent)]
3639
ReqwestMiddleware(#[from] reqwest_middleware::Error),
3740

3841
#[cfg(feature = "kafka")]
3942
#[error(transparent)]
4043
Kafka(#[from] rdkafka::error::KafkaError),
4144

45+
#[error("Subscription task panicked")]
46+
SubscriptionFailed,
47+
#[error("Subscription task lagged and was forcibly disconnected")]
48+
SubscriptionLagged,
49+
4250
#[error(transparent)]
4351
Url(#[from] url::ParseError),
4452

src/services/event.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
use serde_json::Value;
2+
3+
pub trait Class {
4+
const CLASS: &'static str;
5+
}
6+
7+
pub trait Event: Class {
8+
fn matches(value: &Value) -> bool {
9+
value.get("_class").and_then(|v| v.as_str()) == Some(Self::CLASS)
10+
}
11+
}

src/services/mod.rs

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
pub mod account;
1717
pub mod core;
18+
pub mod event;
1819
pub mod jwt;
1920
pub mod kvs;
21+
mod rpc;
2022
pub mod transactor;
2123

2224
pub use reqwest_middleware::{ClientWithMiddleware as HttpClient, RequestBuilder};
@@ -27,20 +29,28 @@ use std::time::Duration;
2729
use reqwest::{self, Response, Url};
2830
use reqwest::{StatusCode, header::HeaderValue};
2931
use reqwest_middleware::ClientBuilder;
30-
use reqwest_retry::{
31-
RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure,
32-
policies::ExponentialBackoff,
33-
};
3432
use secrecy::SecretString;
3533
use serde::{Deserialize, Serialize, de::DeserializeOwned};
3634
use serde_json::{self as json, Value};
3735
use tracing::*;
3836

3937
use crate::services::core::{AccountUuid, WorkspaceUuid};
38+
use crate::services::transactor::backend::http::HttpBackend;
39+
use crate::services::transactor::backend::ws::{WsBackend, WsBackendOpts};
40+
use crate::{Error, Result, config::Config};
41+
use account::AccountClient;
42+
use jwt::Claims;
43+
use kvs::KvsClient;
44+
use transactor::TransactorClient;
45+
4046
#[cfg(feature = "kafka")]
4147
use crate::services::transactor::kafka;
42-
use crate::{Error, Result, config::Config};
43-
use {account::AccountClient, jwt::Claims, kvs::KvsClient, transactor::TransactorClient};
48+
49+
#[cfg(feature = "reqwest_middleware")]
50+
use reqwest_retry::{
51+
RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure,
52+
policies::ExponentialBackoff,
53+
};
4454

4555
pub trait RequestBuilderExt {
4656
fn send_ext(self) -> impl Future<Output = Result<Response>>;
@@ -54,11 +64,12 @@ pub trait BasePathProvider {
5464
fn provide_base_path(&self) -> &Url;
5565
}
5666

57-
pub trait ForceHttpScheme {
67+
pub trait ForceScheme {
5868
fn force_http_scheme(self) -> Url;
69+
fn force_ws_scheme(self) -> Url;
5970
}
6071

61-
impl ForceHttpScheme for Url {
72+
impl ForceScheme for Url {
6273
fn force_http_scheme(mut self) -> Url {
6374
match self.scheme() {
6475
"ws" => {
@@ -74,6 +85,24 @@ impl ForceHttpScheme for Url {
7485

7586
self
7687
}
88+
89+
fn force_ws_scheme(mut self) -> Url {
90+
match self.scheme() {
91+
"http" => {
92+
self.set_scheme("ws").unwrap();
93+
}
94+
95+
"https" => {
96+
self.set_scheme("wss").unwrap();
97+
}
98+
99+
"ws" | "wss" => {}
100+
101+
_ => panic!(),
102+
};
103+
104+
self
105+
}
77106
}
78107

79108
impl RequestBuilderExt for RequestBuilder {
@@ -116,13 +145,13 @@ fn from_value<T: DeserializeOwned>(value: Value) -> Result<T> {
116145
pub trait JsonClient {
117146
fn get<U: TokenProvider, R: DeserializeOwned>(
118147
&self,
119-
user: U,
148+
user: &U,
120149
url: Url,
121150
) -> impl Future<Output = Result<R>>;
122151

123152
fn post<U: TokenProvider, Q: Serialize, R: DeserializeOwned>(
124153
&self,
125-
user: U,
154+
user: &U,
126155
url: Url,
127156
body: &Q,
128157
) -> impl Future<Output = Result<R>>;
@@ -134,7 +163,7 @@ impl JsonClient for HttpClient {
134163
skip(self, user, url),
135164
fields(%url, method = "get", type = "json")
136165
)]
137-
async fn get<U: TokenProvider, R: DeserializeOwned>(&self, user: U, url: Url) -> Result<R> {
166+
async fn get<U: TokenProvider, R: DeserializeOwned>(&self, user: &U, url: Url) -> Result<R> {
138167
trace!("request");
139168

140169
let mut request = self.get(url.clone());
@@ -148,7 +177,7 @@ impl JsonClient for HttpClient {
148177

149178
async fn post<U: TokenProvider, Q: Serialize, R: DeserializeOwned>(
150179
&self,
151-
user: U,
180+
user: &U,
152181
url: Url,
153182
body: &Q,
154183
) -> Result<R> {
@@ -170,7 +199,7 @@ impl JsonClient for HttpClient {
170199
}
171200
}
172201

173-
#[derive(Deserialize, Debug, Clone, strum::Display)]
202+
#[derive(Serialize, Deserialize, Debug, Clone, strum::Display)]
174203
#[serde(rename_all = "UPPERCASE")]
175204
pub enum Severity {
176205
Ok,
@@ -179,11 +208,11 @@ pub enum Severity {
179208
Error,
180209
}
181210

182-
#[derive(Deserialize, Debug, Clone, thiserror::Error)]
211+
#[derive(Serialize, Deserialize, Debug, Clone, thiserror::Error)]
183212
pub struct Status {
184213
pub severity: Severity,
185214
pub code: String,
186-
pub params: HashMap<String, String>,
215+
pub params: HashMap<String, Value>,
187216
}
188217

189218
impl std::fmt::Display for Status {
@@ -431,7 +460,11 @@ impl ServiceFactory {
431460
)
432461
}
433462

434-
pub fn new_transactor_client(&self, base: Url, claims: &Claims) -> Result<TransactorClient> {
463+
pub fn new_transactor_client(
464+
&self,
465+
base: Url,
466+
claims: &Claims,
467+
) -> Result<TransactorClient<HttpBackend>> {
435468
TransactorClient::new(
436469
self.transactor_http.clone(),
437470
base,
@@ -445,15 +478,45 @@ impl ServiceFactory {
445478
)
446479
}
447480

481+
pub async fn new_transactor_client_ws(
482+
&self,
483+
base: Url,
484+
claims: &Claims,
485+
opts: WsBackendOpts,
486+
) -> Result<TransactorClient<WsBackend>> {
487+
TransactorClient::new_ws(
488+
base,
489+
claims.workspace()?,
490+
claims.encode(
491+
self.config
492+
.token_secret
493+
.as_ref()
494+
.ok_or(Error::Other("NoSecret"))?,
495+
)?,
496+
opts,
497+
)
498+
.await
499+
}
500+
448501
pub fn new_transactor_client_from_token(
449502
&self,
450503
base: Url,
451504
workspace: WorkspaceUuid,
452505
token: impl Into<SecretString>,
453-
) -> Result<TransactorClient> {
506+
) -> Result<TransactorClient<HttpBackend>> {
454507
TransactorClient::new(self.transactor_http.clone(), base, workspace, token)
455508
}
456509

510+
pub async fn new_transactor_client_ws_from_token(
511+
&self,
512+
base: Url,
513+
workspace: WorkspaceUuid,
514+
token: impl Into<SecretString>,
515+
opts: WsBackendOpts,
516+
) -> Result<TransactorClient<WsBackend>> {
517+
TransactorClient::new_ws(base, workspace, token, opts).await
518+
}
519+
457520
#[cfg(feature = "kafka")]
458521
pub fn new_kafka_publisher(&self, topic: &str) -> Result<kafka::KafkaProducer> {
459522
kafka::KafkaProducer::new(&self.config, topic)

src/services/rpc/mod.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
pub mod util;
2+
3+
use crate::services::Status;
4+
use crate::services::core::Account;
5+
use serde::{Deserialize, Serialize};
6+
7+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
8+
#[serde(untagged, rename_all = "camelCase")]
9+
pub enum ReqId {
10+
Str(String),
11+
Num(i32),
12+
}
13+
14+
impl From<String> for ReqId {
15+
fn from(s: String) -> Self {
16+
ReqId::Str(s)
17+
}
18+
}
19+
20+
impl From<i32> for ReqId {
21+
fn from(i: i32) -> Self {
22+
ReqId::Num(i)
23+
}
24+
}
25+
26+
#[derive(Serialize, Deserialize, Debug, Clone)]
27+
#[serde(rename_all = "camelCase")]
28+
pub struct RateLimitInfo {
29+
pub remaining: u32,
30+
pub limit: u32,
31+
pub current: u32,
32+
pub reset: f64,
33+
#[serde(skip_serializing_if = "Option::is_none")]
34+
pub retry_after: Option<u32>,
35+
}
36+
37+
#[derive(Serialize, Deserialize, Debug, Clone)]
38+
#[serde(rename_all = "camelCase")]
39+
pub struct Chunk {
40+
pub index: u32,
41+
pub r#final: bool,
42+
}
43+
44+
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
45+
#[serde(rename_all = "camelCase")]
46+
pub struct Response<R> {
47+
#[serde(skip_serializing_if = "Option::is_none")]
48+
pub result: Option<R>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub id: Option<ReqId>,
51+
#[serde(skip_serializing_if = "Option::is_none")]
52+
pub error: Option<Status>,
53+
#[serde(skip_serializing_if = "Option::is_none")]
54+
pub terminate: Option<bool>,
55+
#[serde(skip_serializing_if = "Option::is_none")]
56+
pub rate_limit: Option<RateLimitInfo>,
57+
#[serde(skip_serializing_if = "Option::is_none")]
58+
pub chunk: Option<Chunk>,
59+
#[serde(skip_serializing_if = "Option::is_none")]
60+
pub time: Option<f64>,
61+
#[serde(skip_serializing_if = "Option::is_none")]
62+
pub bfst: Option<f64>,
63+
#[serde(skip_serializing_if = "Option::is_none")]
64+
pub queue: Option<u32>,
65+
}
66+
67+
#[derive(Serialize, Deserialize, Debug, Clone)]
68+
#[serde(rename_all = "camelCase")]
69+
pub struct Request<P> {
70+
#[serde(skip_serializing_if = "Option::is_none")]
71+
pub id: Option<ReqId>,
72+
pub method: String,
73+
pub params: Vec<P>,
74+
#[serde(skip_serializing_if = "Option::is_none")]
75+
pub time: Option<f64>,
76+
}
77+
78+
#[derive(Serialize, Deserialize, Debug, Clone)]
79+
#[serde(rename_all = "camelCase")]
80+
pub struct HelloRequest {
81+
#[serde(flatten)]
82+
pub request: Request<()>,
83+
#[serde(skip_serializing_if = "Option::is_none")]
84+
pub binary: Option<bool>,
85+
#[serde(skip_serializing_if = "Option::is_none")]
86+
pub compression: Option<bool>,
87+
}
88+
89+
#[derive(Serialize, Deserialize, Debug, Clone)]
90+
#[serde(rename_all = "camelCase")]
91+
pub struct HelloResponse {
92+
#[serde(flatten)]
93+
pub response: Response<String>,
94+
pub binary: bool,
95+
#[serde(skip_serializing_if = "Option::is_none")]
96+
pub reconnect: Option<bool>,
97+
pub server_version: String,
98+
#[serde(skip_serializing_if = "Option::is_none")]
99+
pub last_tx: Option<String>,
100+
#[serde(skip_serializing_if = "Option::is_none")]
101+
pub last_hash: Option<String>,
102+
pub account: Account,
103+
#[serde(skip_serializing_if = "Option::is_none")]
104+
pub use_compression: Option<bool>,
105+
}

0 commit comments

Comments
 (0)