Skip to content

Commit db2e2ba

Browse files
committed
fix: handle rate limiting
Signed-off-by: Rahul Baradol <[email protected]>
1 parent 0735f38 commit db2e2ba

File tree

2 files changed

+157
-2
lines changed

2 files changed

+157
-2
lines changed

crates/ofrep/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ edition = "2024"
66
[dev-dependencies]
77
wiremock = "0.6.3"
88
test-log = { version = "0.2", features = ["trace"] }
9-
tokio = { version = "1.45", features = ["full"] }
9+
serial_test = "3.2.0"
1010

1111
[dependencies]
1212
async-trait = "0.1.88"
@@ -16,8 +16,10 @@ reqwest = { version = "0.12", default-features = false, features = [
1616
"stream",
1717
"rustls-tls",
1818
] }
19-
serde = "1.0.219"
2019
serde_json = "1.0.140"
2120
tracing = "0.1.41"
2221
thiserror = "2.0"
2322
anyhow = "1.0.98"
23+
chrono = "0.4"
24+
once_cell = "1.18"
25+
tokio = { version = "1.45", features = ["full"] }

crates/ofrep/src/resolver.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
use async_trait::async_trait;
2+
use chrono::{DateTime, Duration, Utc};
3+
use once_cell::sync::Lazy;
24
use open_feature::provider::{FeatureProvider, ProviderMetadata, ResolutionDetails};
35
use open_feature::{
46
EvaluationContext, EvaluationContextFieldValue, EvaluationError, EvaluationErrorCode,
57
EvaluationResult, StructValue, Value,
68
};
79
use reqwest::Client;
810
use reqwest::StatusCode;
11+
use reqwest::header::RETRY_AFTER;
912
use std::any;
13+
use tokio::sync::Mutex;
1014
use tracing::{debug, error, instrument};
1115

1216
use crate::OfrepOptions;
1317

18+
static CURRENT_RETRY_AFTER: Lazy<Mutex<DateTime<Utc>>> = Lazy::new(|| Mutex::new(Utc::now()));
19+
1420
#[derive(Debug)]
1521
pub struct Resolver {
1622
base_url: String,
@@ -31,13 +37,51 @@ impl Resolver {
3137
}
3238
}
3339

40+
async fn parse_retry_after(retry_after: &str) -> DateTime<Utc> {
41+
let now = Utc::now();
42+
43+
if retry_after.trim().is_empty() {
44+
return now;
45+
}
46+
47+
if let Ok(seconds) = retry_after.trim().parse::<i64>() {
48+
return now + Duration::seconds(seconds);
49+
}
50+
51+
if let Ok(parsed_date) = retry_after.trim().parse::<DateTime<Utc>>() {
52+
return parsed_date.with_timezone(&Utc);
53+
}
54+
55+
debug!("Failed to parse Retry-After header : {}", retry_after);
56+
now
57+
}
58+
59+
async fn update_retry_after(new_retry_after: DateTime<Utc>) {
60+
let mut retry_after = CURRENT_RETRY_AFTER.lock().await;
61+
*retry_after = new_retry_after;
62+
}
63+
64+
async fn is_rate_limit_exceeded() -> bool {
65+
let retry_after = CURRENT_RETRY_AFTER.lock().await;
66+
Utc::now() < *retry_after
67+
}
68+
3469
#[instrument(skip(self, evaluation_context), fields(flag_key = %flag_key))]
3570
async fn resolve_value<T: std::fmt::Debug>(
3671
&self,
3772
flag_key: &str,
3873
evaluation_context: &EvaluationContext,
3974
convertor: fn(serde_json::Value) -> Option<T>,
4075
) -> EvaluationResult<ResolutionDetails<T>> {
76+
if Resolver::is_rate_limit_exceeded().await {
77+
return Err(EvaluationError {
78+
code: EvaluationErrorCode::General("Rate limit exceeded".to_string()),
79+
message: Some(
80+
"Rate limit exceeded. Please wait before making another request.".to_string(),
81+
),
82+
});
83+
}
84+
4185
debug!("Resolving {} flag", std::any::type_name::<T>());
4286
let payload = serde_json::json!({
4387
"context": context_to_json(evaluation_context)
@@ -87,6 +131,28 @@ impl Resolver {
87131
message: Some(format!("Flag: {flag_key} not found")),
88132
});
89133
}
134+
StatusCode::TOO_MANY_REQUESTS => {
135+
let header_retry_after: Option<&str> = response
136+
.headers()
137+
.get(RETRY_AFTER)
138+
.and_then(|value| value.to_str().ok());
139+
140+
if let Some(header_retry_after) = header_retry_after {
141+
let new_retry_after: DateTime<Utc> =
142+
Resolver::parse_retry_after(header_retry_after).await;
143+
Resolver::update_retry_after(new_retry_after).await;
144+
} else {
145+
debug!("Couldn't parse the retry-after header.");
146+
let mut retry_after = CURRENT_RETRY_AFTER.lock().await;
147+
*retry_after = Utc::now();
148+
}
149+
150+
let retry_after = CURRENT_RETRY_AFTER.lock().await;
151+
return Err(EvaluationError {
152+
code: EvaluationErrorCode::General("Rate limit exceeded".to_string()),
153+
message: Some(format!("Rate limit exceeded. Retry after {}", *retry_after)),
154+
});
155+
}
90156
_ => {
91157
let result = response.json::<serde_json::Value>().await.map_err(|e| {
92158
error!(error = %e, "Failed to parse {} response", any::type_name::<T>());
@@ -242,9 +308,15 @@ mod tests {
242308
use super::*;
243309
use serde_json::json;
244310
use test_log::test;
311+
use tokio::time::{Duration, sleep};
245312
use wiremock::matchers::{method, path};
246313
use wiremock::{Mock, MockServer, ResponseTemplate};
247314

315+
async fn reset_states() {
316+
let mut retry_after = CURRENT_RETRY_AFTER.lock().await;
317+
*retry_after = Utc::now();
318+
}
319+
248320
async fn setup_mock_server() -> (MockServer, Resolver) {
249321
let mock_server = MockServer::start().await;
250322
let options = OfrepOptions {
@@ -256,7 +328,9 @@ mod tests {
256328
}
257329

258330
#[test(tokio::test)]
331+
#[serial_test::serial]
259332
async fn test_resolve_bool_value() {
333+
reset_states().await;
260334
let (mock_server, resolver) = setup_mock_server().await;
261335

262336
Mock::given(method("POST"))
@@ -281,7 +355,9 @@ mod tests {
281355
}
282356

283357
#[test(tokio::test)]
358+
#[serial_test::serial]
284359
async fn test_resolve_string_value() {
360+
reset_states().await;
285361
let (mock_server, resolver) = setup_mock_server().await;
286362

287363
Mock::given(method("POST"))
@@ -306,7 +382,9 @@ mod tests {
306382
}
307383

308384
#[test(tokio::test)]
385+
#[serial_test::serial]
309386
async fn test_resolve_float_value() {
387+
reset_states().await;
310388
let (mock_server, resolver) = setup_mock_server().await;
311389

312390
Mock::given(method("POST"))
@@ -331,7 +409,9 @@ mod tests {
331409
}
332410

333411
#[test(tokio::test)]
412+
#[serial_test::serial]
334413
async fn test_resolve_int_value() {
414+
reset_states().await;
335415
let (mock_server, resolver) = setup_mock_server().await;
336416

337417
Mock::given(method("POST"))
@@ -356,7 +436,9 @@ mod tests {
356436
}
357437

358438
#[test(tokio::test)]
439+
#[serial_test::serial]
359440
async fn test_resolve_struct_value() {
441+
reset_states().await;
360442
let (mock_server, resolver) = setup_mock_server().await;
361443

362444
Mock::given(method("POST"))
@@ -401,7 +483,9 @@ mod tests {
401483
}
402484

403485
#[test(tokio::test)]
486+
#[serial_test::serial]
404487
async fn test_error_400() {
488+
reset_states().await;
405489
let (mock_server, resolver) = setup_mock_server().await;
406490

407491
Mock::given(method("POST"))
@@ -446,7 +530,9 @@ mod tests {
446530
}
447531

448532
#[test(tokio::test)]
533+
#[serial_test::serial]
449534
async fn test_error_401() {
535+
reset_states().await;
450536
let (mock_server, resolver) = setup_mock_server().await;
451537

452538
Mock::given(method("POST"))
@@ -492,7 +578,9 @@ mod tests {
492578
}
493579

494580
#[test(tokio::test)]
581+
#[serial_test::serial]
495582
async fn test_error_403() {
583+
reset_states().await;
496584
let (mock_server, resolver) = setup_mock_server().await;
497585

498586
Mock::given(method("POST"))
@@ -538,7 +626,9 @@ mod tests {
538626
}
539627

540628
#[test(tokio::test)]
629+
#[serial_test::serial]
541630
async fn test_error_404() {
631+
reset_states().await;
542632
let (mock_server, resolver) = setup_mock_server().await;
543633

544634
Mock::given(method("POST"))
@@ -596,4 +686,67 @@ mod tests {
596686
"Flag: test-flag not found"
597687
);
598688
}
689+
690+
#[test(tokio::test)]
691+
#[serial_test::serial]
692+
async fn test_error_429() {
693+
reset_states().await;
694+
let (mock_server, resolver) = setup_mock_server().await;
695+
696+
Mock::given(method("POST"))
697+
.and(path("/ofrep/v1/evaluate/flags/test-flag"))
698+
.respond_with(
699+
ResponseTemplate::new(429)
700+
.insert_header("Retry-After", "3")
701+
.set_body_json(json!({})),
702+
)
703+
.mount(&mock_server)
704+
.await;
705+
706+
let context = EvaluationContext::default();
707+
708+
let result_bool = resolver.resolve_bool_value("test-flag", &context).await;
709+
let result_bool_2 = resolver.resolve_bool_value("test-flag", &context).await;
710+
711+
assert!(result_bool.is_err());
712+
let result_bool_error = result_bool.unwrap_err();
713+
assert_eq!(
714+
result_bool_error.code,
715+
EvaluationErrorCode::General("Rate limit exceeded".to_string())
716+
);
717+
assert!(
718+
result_bool_error
719+
.message
720+
.unwrap()
721+
.starts_with("Rate limit exceeded. Retry after")
722+
);
723+
724+
assert!(result_bool_2.is_err());
725+
let result_bool_error_2 = result_bool_2.unwrap_err();
726+
assert_eq!(
727+
result_bool_error_2.code,
728+
EvaluationErrorCode::General("Rate limit exceeded".to_string())
729+
);
730+
assert_eq!(
731+
result_bool_error_2.message.unwrap(),
732+
"Rate limit exceeded. Please wait before making another request."
733+
);
734+
735+
sleep(Duration::from_secs(3)).await;
736+
737+
let result_bool_3 = resolver.resolve_bool_value("test-flag", &context).await;
738+
assert!(result_bool_3.is_err());
739+
740+
let result_bool_error_3 = result_bool_3.unwrap_err();
741+
assert_eq!(
742+
result_bool_error_3.code,
743+
EvaluationErrorCode::General("Rate limit exceeded".to_string())
744+
);
745+
assert!(
746+
result_bool_error_3
747+
.message
748+
.unwrap()
749+
.starts_with("Rate limit exceeded. Retry after")
750+
);
751+
}
599752
}

0 commit comments

Comments
 (0)