Skip to content

Commit 5458f70

Browse files
committed
apollo_gateway: extract stateless validator from blocking task
1 parent 63c1a13 commit 5458f70

File tree

4 files changed

+36
-60
lines changed

4 files changed

+36
-60
lines changed

crates/apollo_gateway/benches/utils.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use apollo_class_manager_types::transaction_converter::TransactionConverter;
44
use apollo_class_manager_types::EmptyClassManagerClient;
55
use apollo_gateway::gateway::Gateway;
66
use apollo_gateway::state_reader_test_utils::local_test_state_reader_factory;
7+
use apollo_gateway::stateless_transaction_validator::StatelessTransactionValidator;
78
use apollo_gateway_config::config::GatewayConfig;
89
use apollo_mempool_types::communication::MockMempoolClient;
910
use blockifier::context::ChainInfo;
@@ -82,19 +83,22 @@ impl BenchTestSetup {
8283

8384
let state_reader_factory = local_test_state_reader_factory(cairo_version, false);
8485
let mut mempool_client = MockMempoolClient::new();
85-
// TODO(noamsp): use MockTransactionConverter
8686
let class_manager_client = Arc::new(EmptyClassManagerClient);
8787
let transaction_converter = TransactionConverter::new(
8888
class_manager_client.clone(),
8989
config.gateway_config.chain_info.chain_id.clone(),
9090
);
91+
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
92+
config: config.gateway_config.stateless_tx_validator_config.clone(),
93+
});
9194
mempool_client.expect_add_tx().returning(|_| Ok(()));
9295

9396
let gateway_business_logic = Gateway::new(
9497
config.gateway_config,
9598
Arc::new(state_reader_factory),
9699
Arc::new(mempool_client),
97100
Arc::new(transaction_converter),
101+
stateless_tx_validator,
98102
);
99103

100104
Self { gateway: gateway_business_logic, txs }

crates/apollo_gateway/src/gateway.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub mod gateway_test;
5858
#[derive(Clone)]
5959
pub struct Gateway {
6060
pub config: Arc<GatewayConfig>,
61-
pub stateless_tx_validator: Arc<StatelessTransactionValidator>,
61+
pub stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
6262
pub stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
6363
pub state_reader_factory: Arc<dyn StateReaderFactory>,
6464
pub mempool_client: SharedMempoolClient,
@@ -71,12 +71,11 @@ impl Gateway {
7171
state_reader_factory: Arc<dyn StateReaderFactory>,
7272
mempool_client: SharedMempoolClient,
7373
transaction_converter: Arc<dyn TransactionConverterTrait>,
74+
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
7475
) -> Self {
7576
Self {
7677
config: Arc::new(config.clone()),
77-
stateless_tx_validator: Arc::new(StatelessTransactionValidator {
78-
config: config.stateless_tx_validator_config.clone(),
79-
}),
78+
stateless_tx_validator,
8079
stateful_tx_validator_factory: Arc::new(StatefulTransactionValidatorFactory {
8180
config: config.stateful_tx_validator_config.clone(),
8281
chain_info: config.chain_info.clone(),
@@ -125,6 +124,9 @@ impl Gateway {
125124
}
126125
}
127126

127+
// Perform stateless validations.
128+
self.stateless_tx_validator.validate(&tx)?;
129+
128130
let tx_signature = tx.signature().clone();
129131
let internal_tx = self
130132
.transaction_converter
@@ -149,7 +151,6 @@ impl Gateway {
149151

150152
let blocking_task = ProcessTxBlockingTask::new(
151153
self,
152-
tx.clone(),
153154
internal_tx,
154155
executable_tx,
155156
tokio::runtime::Handle::current(),
@@ -232,11 +233,9 @@ impl Gateway {
232233
/// CPU-intensive transaction processing, spawned in a blocking thread to avoid blocking other tasks
233234
/// from running.
234235
struct ProcessTxBlockingTask {
235-
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
236236
stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
237237
state_reader_factory: Arc<dyn StateReaderFactory>,
238238
mempool_client: SharedMempoolClient,
239-
tx: RpcTransaction,
240239
internal_tx: InternalRpcTransaction,
241240
executable_tx: AccountTransaction,
242241
runtime: tokio::runtime::Handle,
@@ -245,27 +244,21 @@ struct ProcessTxBlockingTask {
245244
impl ProcessTxBlockingTask {
246245
pub fn new(
247246
gateway: &Gateway,
248-
tx: RpcTransaction,
249247
internal_tx: InternalRpcTransaction,
250248
executable_tx: AccountTransaction,
251249
runtime: tokio::runtime::Handle,
252250
) -> Self {
253251
Self {
254-
stateless_tx_validator: gateway.stateless_tx_validator.clone(),
255252
stateful_tx_validator_factory: gateway.stateful_tx_validator_factory.clone(),
256253
state_reader_factory: gateway.state_reader_factory.clone(),
257254
mempool_client: gateway.mempool_client.clone(),
258-
tx,
259255
internal_tx,
260256
executable_tx,
261257
runtime,
262258
}
263259
}
264260

265261
fn process_tx(self) -> GatewayResult<AddTransactionArgs> {
266-
// Perform stateless validations.
267-
self.stateless_tx_validator.validate(&self.tx)?;
268-
269262
let mut stateful_transaction_validator = self
270263
.stateful_tx_validator_factory
271264
.instantiate_validator(self.state_reader_factory.as_ref())?;
@@ -296,8 +289,17 @@ pub fn create_gateway(
296289
class_manager_client,
297290
config.chain_info.chain_id.clone(),
298291
));
292+
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
293+
config: config.stateless_tx_validator_config.clone(),
294+
});
299295

300-
Gateway::new(config, state_reader_factory, mempool_client, transaction_converter)
296+
Gateway::new(
297+
config,
298+
state_reader_factory,
299+
mempool_client,
300+
transaction_converter,
301+
stateless_tx_validator,
302+
)
301303
}
302304

303305
#[async_trait]

crates/apollo_gateway/src/gateway_test.rs

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@ fn mock_dependencies() -> MockDependencies {
130130
local_test_state_reader_factory(CairoVersion::Cairo1(RunnableCairo1::Casm), true);
131131
let mock_mempool_client = MockMempoolClient::new();
132132
let mock_transaction_converter = MockTransactionConverterTrait::new();
133+
let mock_stateless_transaction_validator = mock_stateless_transaction_validator();
133134
MockDependencies {
134135
config,
135136
state_reader_factory,
136137
mock_mempool_client,
137138
mock_transaction_converter,
139+
mock_stateless_transaction_validator,
138140
}
139141
}
140142

@@ -143,6 +145,7 @@ struct MockDependencies {
143145
state_reader_factory: TestStateReaderFactory,
144146
mock_mempool_client: MockMempoolClient,
145147
mock_transaction_converter: MockTransactionConverterTrait,
148+
mock_stateless_transaction_validator: MockStatelessTransactionValidatorTrait,
146149
}
147150

148151
impl MockDependencies {
@@ -153,6 +156,7 @@ impl MockDependencies {
153156
Arc::new(self.state_reader_factory),
154157
Arc::new(self.mock_mempool_client),
155158
Arc::new(self.mock_transaction_converter),
159+
Arc::new(self.mock_stateless_transaction_validator),
156160
)
157161
}
158162

@@ -309,28 +313,13 @@ async fn run_add_tx_and_extract_metrics(
309313
AddTxResults { result, metric_handle_for_queries, metrics }
310314
}
311315

312-
#[derive(Default)]
313-
pub struct ProcessTxOverrides {
314-
pub mock_stateful_transaction_validator_factory:
315-
Option<MockStatefulTransactionValidatorFactoryTrait>,
316-
pub mock_stateless_transaction_validator: Option<MockStatelessTransactionValidatorTrait>,
317-
}
318-
319-
fn process_tx_task(overrides: ProcessTxOverrides) -> ProcessTxBlockingTask {
320-
let mock_validator_factory = overrides
321-
.mock_stateful_transaction_validator_factory
322-
.unwrap_or_else(mock_stateful_transaction_validator_factory);
323-
324-
let mock_stateless_transaction_validator = overrides
325-
.mock_stateless_transaction_validator
326-
.unwrap_or_else(mock_stateless_transaction_validator);
327-
316+
fn process_tx_task(
317+
stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
318+
) -> ProcessTxBlockingTask {
328319
ProcessTxBlockingTask {
329-
stateless_tx_validator: Arc::new(mock_stateless_transaction_validator),
330-
stateful_tx_validator_factory: Arc::new(mock_validator_factory),
320+
stateful_tx_validator_factory: Arc::new(stateful_transaction_validator_factory),
331321
state_reader_factory: Arc::new(MockStateReaderFactory::new()),
332322
mempool_client: Arc::new(MockMempoolClient::new()),
333-
tx: invoke_args().get_rpc_tx(),
334323
internal_tx: invoke_args().get_internal_tx(),
335324
executable_tx: executable_invoke_tx(invoke_args()),
336325
runtime: tokio::runtime::Handle::current(),
@@ -577,13 +566,7 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
577566
.expect_instantiate_validator()
578567
.return_once(|_| Ok(Box::new(mock_stateful_transaction_validator)));
579568

580-
let overrides = ProcessTxOverrides {
581-
mock_stateful_transaction_validator_factory: Some(
582-
mock_stateful_transaction_validator_factory,
583-
),
584-
..Default::default()
585-
};
586-
let process_tx_task = process_tx_task(overrides);
569+
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);
587570

588571
let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();
589572

@@ -593,27 +576,20 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
593576

594577
#[rstest]
595578
#[tokio::test]
596-
async fn process_tx_returns_error_for_one_stateless_error_variant() {
579+
async fn stateless_transaction_validator_error(mut mock_dependencies: MockDependencies) {
597580
let arbitrary_validation_error = Err(StatelessTransactionValidatorError::SignatureTooLong {
598581
signature_length: 5001,
599582
max_signature_length: 4000,
600583
});
601584
let error_code =
602585
StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.SIGNATURE_TOO_LONG".into());
603-
604586
let mut mock_stateless_transaction_validator = MockStatelessTransactionValidatorTrait::new();
605587
mock_stateless_transaction_validator
606588
.expect_validate()
607589
.return_once(|_| arbitrary_validation_error);
608-
609-
let overrides = ProcessTxOverrides {
610-
mock_stateless_transaction_validator: Some(mock_stateless_transaction_validator),
611-
..Default::default()
612-
};
613-
let task = process_tx_task(overrides);
614-
615-
let result: Result<AddTransactionArgs, StarknetError> =
616-
tokio::task::spawn_blocking(move || task.process_tx()).await.unwrap();
590+
mock_dependencies.mock_stateless_transaction_validator = mock_stateless_transaction_validator;
591+
let gateway = mock_dependencies.gateway();
592+
let result = gateway.add_tx(invoke_args().get_rpc_tx(), None).await;
617593

618594
assert!(result.is_err());
619595
assert_eq!(result.unwrap_err().code, error_code);
@@ -633,13 +609,7 @@ async fn process_tx_returns_error_when_instantiating_validator_fails(
633609
.expect_instantiate_validator()
634610
.return_once(|_| Err(expected_error));
635611

636-
let overrides = ProcessTxOverrides {
637-
mock_stateful_transaction_validator_factory: Some(
638-
mock_stateful_transaction_validator_factory,
639-
),
640-
..Default::default()
641-
};
642-
let process_tx_task = process_tx_task(overrides);
612+
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);
643613

644614
let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();
645615

crates/apollo_gateway/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub mod state_reader;
1010
#[cfg(any(feature = "testing", test))]
1111
pub mod state_reader_test_utils;
1212
mod stateful_transaction_validator;
13-
mod stateless_transaction_validator;
13+
pub mod stateless_transaction_validator;
1414
mod sync_state_reader;
1515
#[cfg(test)]
1616
mod sync_state_reader_test;

0 commit comments

Comments
 (0)