Skip to content

Commit 9d8f420

Browse files
committed
apollo_gateway: extract stateless validator from blocking task
1 parent 150836a commit 9d8f420

File tree

4 files changed

+36
-60
lines changed

4 files changed

+36
-60
lines changed

crates/apollo_gateway/benches/utils.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use apollo_class_manager_types::transaction_converter::TransactionConverter;
44
use apollo_class_manager_types::EmptyClassManagerClient;
55
use apollo_gateway::gateway::Gateway;
66
use apollo_gateway::state_reader_test_utils::local_test_state_reader_factory;
7+
use apollo_gateway::stateless_transaction_validator::StatelessTransactionValidator;
78
use apollo_gateway_config::config::GatewayConfig;
89
use apollo_mempool_types::communication::MockMempoolClient;
910
use blockifier::context::ChainInfo;
@@ -82,19 +83,22 @@ impl BenchTestSetup {
8283

8384
let state_reader_factory = local_test_state_reader_factory(cairo_version, false);
8485
let mut mempool_client = MockMempoolClient::new();
85-
// TODO(noamsp): use MockTransactionConverter
8686
let class_manager_client = Arc::new(EmptyClassManagerClient);
8787
let transaction_converter = TransactionConverter::new(
8888
class_manager_client.clone(),
8989
config.gateway_config.chain_info.chain_id.clone(),
9090
);
91+
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
92+
config: config.gateway_config.stateless_tx_validator_config.clone(),
93+
});
9194
mempool_client.expect_add_tx().returning(|_| Ok(()));
9295

9396
let gateway_business_logic = Gateway::new(
9497
config.gateway_config,
9598
Arc::new(state_reader_factory),
9699
Arc::new(mempool_client),
97100
Arc::new(transaction_converter),
101+
stateless_tx_validator,
98102
);
99103

100104
Self { gateway: gateway_business_logic, txs }

crates/apollo_gateway/src/gateway.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub mod gateway_test;
5858
#[derive(Clone)]
5959
pub struct Gateway {
6060
pub config: Arc<GatewayConfig>,
61-
pub stateless_tx_validator: Arc<StatelessTransactionValidator>,
61+
pub stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
6262
pub stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
6363
pub state_reader_factory: Arc<dyn StateReaderFactory>,
6464
pub mempool_client: SharedMempoolClient,
@@ -71,12 +71,11 @@ impl Gateway {
7171
state_reader_factory: Arc<dyn StateReaderFactory>,
7272
mempool_client: SharedMempoolClient,
7373
transaction_converter: Arc<dyn TransactionConverterTrait>,
74+
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
7475
) -> Self {
7576
Self {
7677
config: Arc::new(config.clone()),
77-
stateless_tx_validator: Arc::new(StatelessTransactionValidator {
78-
config: config.stateless_tx_validator_config.clone(),
79-
}),
78+
stateless_tx_validator,
8079
stateful_tx_validator_factory: Arc::new(StatefulTransactionValidatorFactory {
8180
config: config.stateful_tx_validator_config.clone(),
8281
chain_info: config.chain_info.clone(),
@@ -125,6 +124,9 @@ impl Gateway {
125124
}
126125
}
127126

127+
// Perform stateless validations.
128+
self.stateless_tx_validator.validate(&tx)?;
129+
128130
let tx_signature = tx.signature().clone();
129131
let internal_tx = self
130132
.transaction_converter
@@ -149,7 +151,6 @@ impl Gateway {
149151

150152
let blocking_task = ProcessTxBlockingTask::new(
151153
self,
152-
tx.clone(),
153154
internal_tx,
154155
executable_tx,
155156
tokio::runtime::Handle::current(),
@@ -232,11 +233,9 @@ impl Gateway {
232233
/// CPU-intensive transaction processing, spawned in a blocking thread to avoid blocking other tasks
233234
/// from running.
234235
struct ProcessTxBlockingTask {
235-
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
236236
stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
237237
state_reader_factory: Arc<dyn StateReaderFactory>,
238238
mempool_client: SharedMempoolClient,
239-
tx: RpcTransaction,
240239
internal_tx: InternalRpcTransaction,
241240
executable_tx: AccountTransaction,
242241
runtime: tokio::runtime::Handle,
@@ -245,27 +244,21 @@ struct ProcessTxBlockingTask {
245244
impl ProcessTxBlockingTask {
246245
pub fn new(
247246
gateway: &Gateway,
248-
tx: RpcTransaction,
249247
internal_tx: InternalRpcTransaction,
250248
executable_tx: AccountTransaction,
251249
runtime: tokio::runtime::Handle,
252250
) -> Self {
253251
Self {
254-
stateless_tx_validator: gateway.stateless_tx_validator.clone(),
255252
stateful_tx_validator_factory: gateway.stateful_tx_validator_factory.clone(),
256253
state_reader_factory: gateway.state_reader_factory.clone(),
257254
mempool_client: gateway.mempool_client.clone(),
258-
tx,
259255
internal_tx,
260256
executable_tx,
261257
runtime,
262258
}
263259
}
264260

265261
fn process_tx(self) -> GatewayResult<AddTransactionArgs> {
266-
// Perform stateless validations.
267-
self.stateless_tx_validator.validate(&self.tx)?;
268-
269262
let mut stateful_transaction_validator = self
270263
.stateful_tx_validator_factory
271264
.instantiate_validator(self.state_reader_factory.as_ref())?;
@@ -296,8 +289,17 @@ pub fn create_gateway(
296289
class_manager_client,
297290
config.chain_info.chain_id.clone(),
298291
));
292+
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
293+
config: config.stateless_tx_validator_config.clone(),
294+
});
299295

300-
Gateway::new(config, state_reader_factory, mempool_client, transaction_converter)
296+
Gateway::new(
297+
config,
298+
state_reader_factory,
299+
mempool_client,
300+
transaction_converter,
301+
stateless_tx_validator,
302+
)
301303
}
302304

303305
#[async_trait]

crates/apollo_gateway/src/gateway_test.rs

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@ fn mock_dependencies() -> MockDependencies {
130130
local_test_state_reader_factory(CairoVersion::Cairo1(RunnableCairo1::Casm), true);
131131
let mock_mempool_client = MockMempoolClient::new();
132132
let mock_transaction_converter = MockTransactionConverterTrait::new();
133+
let mock_stateless_transaction_validator = mock_stateless_transaction_validator();
133134
MockDependencies {
134135
config,
135136
state_reader_factory,
136137
mock_mempool_client,
137138
mock_transaction_converter,
139+
mock_stateless_transaction_validator,
138140
}
139141
}
140142

@@ -143,6 +145,7 @@ struct MockDependencies {
143145
state_reader_factory: TestStateReaderFactory,
144146
mock_mempool_client: MockMempoolClient,
145147
mock_transaction_converter: MockTransactionConverterTrait,
148+
mock_stateless_transaction_validator: MockStatelessTransactionValidatorTrait,
146149
}
147150

148151
impl MockDependencies {
@@ -153,6 +156,7 @@ impl MockDependencies {
153156
Arc::new(self.state_reader_factory),
154157
Arc::new(self.mock_mempool_client),
155158
Arc::new(self.mock_transaction_converter),
159+
Arc::new(self.mock_stateless_transaction_validator),
156160
)
157161
}
158162

@@ -313,28 +317,13 @@ async fn run_add_tx_and_extract_metrics(
313317
AddTxResults { result, metric_handle_for_queries, metrics }
314318
}
315319

316-
#[derive(Default)]
317-
pub struct ProcessTxOverrides {
318-
pub mock_stateful_transaction_validator_factory:
319-
Option<MockStatefulTransactionValidatorFactoryTrait>,
320-
pub mock_stateless_transaction_validator: Option<MockStatelessTransactionValidatorTrait>,
321-
}
322-
323-
fn process_tx_task(overrides: ProcessTxOverrides) -> ProcessTxBlockingTask {
324-
let mock_validator_factory = overrides
325-
.mock_stateful_transaction_validator_factory
326-
.unwrap_or_else(mock_stateful_transaction_validator_factory);
327-
328-
let mock_stateless_transaction_validator = overrides
329-
.mock_stateless_transaction_validator
330-
.unwrap_or_else(mock_stateless_transaction_validator);
331-
320+
fn process_tx_task(
321+
stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
322+
) -> ProcessTxBlockingTask {
332323
ProcessTxBlockingTask {
333-
stateless_tx_validator: Arc::new(mock_stateless_transaction_validator),
334-
stateful_tx_validator_factory: Arc::new(mock_validator_factory),
324+
stateful_tx_validator_factory: Arc::new(stateful_transaction_validator_factory),
335325
state_reader_factory: Arc::new(MockStateReaderFactory::new()),
336326
mempool_client: Arc::new(MockMempoolClient::new()),
337-
tx: invoke_args().get_rpc_tx(),
338327
internal_tx: invoke_args().get_internal_tx(),
339328
executable_tx: executable_invoke_tx(invoke_args()),
340329
runtime: tokio::runtime::Handle::current(),
@@ -581,13 +570,7 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
581570
.expect_instantiate_validator()
582571
.return_once(|_| Ok(Box::new(mock_stateful_transaction_validator)));
583572

584-
let overrides = ProcessTxOverrides {
585-
mock_stateful_transaction_validator_factory: Some(
586-
mock_stateful_transaction_validator_factory,
587-
),
588-
..Default::default()
589-
};
590-
let process_tx_task = process_tx_task(overrides);
573+
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);
591574

592575
let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();
593576

@@ -597,27 +580,20 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
597580

598581
#[rstest]
599582
#[tokio::test]
600-
async fn process_tx_returns_error_for_one_stateless_error_variant() {
583+
async fn stateless_transaction_validator_error(mut mock_dependencies: MockDependencies) {
601584
let arbitrary_validation_error = Err(StatelessTransactionValidatorError::SignatureTooLong {
602585
signature_length: 5001,
603586
max_signature_length: 4000,
604587
});
605588
let error_code =
606589
StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.SIGNATURE_TOO_LONG".into());
607-
608590
let mut mock_stateless_transaction_validator = MockStatelessTransactionValidatorTrait::new();
609591
mock_stateless_transaction_validator
610592
.expect_validate()
611593
.return_once(|_| arbitrary_validation_error);
612-
613-
let overrides = ProcessTxOverrides {
614-
mock_stateless_transaction_validator: Some(mock_stateless_transaction_validator),
615-
..Default::default()
616-
};
617-
let task = process_tx_task(overrides);
618-
619-
let result: Result<AddTransactionArgs, StarknetError> =
620-
tokio::task::spawn_blocking(move || task.process_tx()).await.unwrap();
594+
mock_dependencies.mock_stateless_transaction_validator = mock_stateless_transaction_validator;
595+
let gateway = mock_dependencies.gateway();
596+
let result = gateway.add_tx(invoke_args().get_rpc_tx(), None).await;
621597

622598
assert!(result.is_err());
623599
assert_eq!(result.unwrap_err().code, error_code);
@@ -637,13 +613,7 @@ async fn process_tx_returns_error_when_instantiating_validator_fails(
637613
.expect_instantiate_validator()
638614
.return_once(|_| Err(expected_error));
639615

640-
let overrides = ProcessTxOverrides {
641-
mock_stateful_transaction_validator_factory: Some(
642-
mock_stateful_transaction_validator_factory,
643-
),
644-
..Default::default()
645-
};
646-
let process_tx_task = process_tx_task(overrides);
616+
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);
647617

648618
let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();
649619

crates/apollo_gateway/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub mod state_reader;
1010
#[cfg(any(feature = "testing", test))]
1111
pub mod state_reader_test_utils;
1212
mod stateful_transaction_validator;
13-
mod stateless_transaction_validator;
13+
pub mod stateless_transaction_validator;
1414
mod sync_state_reader;
1515
#[cfg(test)]
1616
mod sync_state_reader_test;

0 commit comments

Comments
 (0)