diff --git a/crates/apollo_gateway/benches/utils.rs b/crates/apollo_gateway/benches/utils.rs index 385c41c17f7..27a333f3096 100644 --- a/crates/apollo_gateway/benches/utils.rs +++ b/crates/apollo_gateway/benches/utils.rs @@ -4,6 +4,7 @@ use apollo_class_manager_types::transaction_converter::TransactionConverter; use apollo_class_manager_types::EmptyClassManagerClient; use apollo_gateway::gateway::Gateway; use apollo_gateway::state_reader_test_utils::local_test_state_reader_factory; +use apollo_gateway::stateless_transaction_validator::StatelessTransactionValidator; use apollo_gateway_config::config::GatewayConfig; use apollo_mempool_types::communication::MockMempoolClient; use blockifier::context::ChainInfo; @@ -82,12 +83,14 @@ impl BenchTestSetup { let state_reader_factory = local_test_state_reader_factory(cairo_version, false); let mut mempool_client = MockMempoolClient::new(); - // TODO(noamsp): use MockTransactionConverter let class_manager_client = Arc::new(EmptyClassManagerClient); let transaction_converter = TransactionConverter::new( class_manager_client.clone(), config.gateway_config.chain_info.chain_id.clone(), ); + let stateless_tx_validator = Arc::new(StatelessTransactionValidator { + config: config.gateway_config.stateless_tx_validator_config.clone(), + }); mempool_client.expect_add_tx().returning(|_| Ok(())); let gateway_business_logic = Gateway::new( @@ -95,6 +98,7 @@ impl BenchTestSetup { Arc::new(state_reader_factory), Arc::new(mempool_client), Arc::new(transaction_converter), + stateless_tx_validator, ); Self { gateway: gateway_business_logic, txs } diff --git a/crates/apollo_gateway/src/gateway.rs b/crates/apollo_gateway/src/gateway.rs index 2401ad68061..caf510680c7 100644 --- a/crates/apollo_gateway/src/gateway.rs +++ b/crates/apollo_gateway/src/gateway.rs @@ -58,7 +58,7 @@ pub mod gateway_test; #[derive(Clone)] pub struct Gateway { pub config: Arc, - pub stateless_tx_validator: Arc, + pub stateless_tx_validator: Arc, pub stateful_tx_validator_factory: Arc, pub state_reader_factory: Arc, pub mempool_client: SharedMempoolClient, @@ -71,12 +71,11 @@ impl Gateway { state_reader_factory: Arc, mempool_client: SharedMempoolClient, transaction_converter: Arc, + stateless_tx_validator: Arc, ) -> Self { Self { config: Arc::new(config.clone()), - stateless_tx_validator: Arc::new(StatelessTransactionValidator { - config: config.stateless_tx_validator_config.clone(), - }), + stateless_tx_validator, stateful_tx_validator_factory: Arc::new(StatefulTransactionValidatorFactory { config: config.stateful_tx_validator_config.clone(), chain_info: config.chain_info.clone(), @@ -125,6 +124,9 @@ impl Gateway { } } + // Perform stateless validations. + self.stateless_tx_validator.validate(tx)?; + let tx_signature = tx.signature().clone(); let internal_tx = self .transaction_converter @@ -149,7 +151,6 @@ impl Gateway { let blocking_task = ProcessTxBlockingTask::new( self, - tx.clone(), internal_tx, executable_tx, tokio::runtime::Handle::current(), @@ -232,11 +233,9 @@ impl Gateway { /// CPU-intensive transaction processing, spawned in a blocking thread to avoid blocking other tasks /// from running. struct ProcessTxBlockingTask { - stateless_tx_validator: Arc, stateful_tx_validator_factory: Arc, state_reader_factory: Arc, mempool_client: SharedMempoolClient, - tx: RpcTransaction, internal_tx: InternalRpcTransaction, executable_tx: AccountTransaction, runtime: tokio::runtime::Handle, @@ -245,17 +244,14 @@ struct ProcessTxBlockingTask { impl ProcessTxBlockingTask { pub fn new( gateway: &Gateway, - tx: RpcTransaction, internal_tx: InternalRpcTransaction, executable_tx: AccountTransaction, runtime: tokio::runtime::Handle, ) -> Self { Self { - stateless_tx_validator: gateway.stateless_tx_validator.clone(), stateful_tx_validator_factory: gateway.stateful_tx_validator_factory.clone(), state_reader_factory: gateway.state_reader_factory.clone(), mempool_client: gateway.mempool_client.clone(), - tx, internal_tx, executable_tx, runtime, @@ -263,9 +259,6 @@ impl ProcessTxBlockingTask { } fn process_tx(self) -> GatewayResult { - // Perform stateless validations. - self.stateless_tx_validator.validate(&self.tx)?; - let mut stateful_transaction_validator = self .stateful_tx_validator_factory .instantiate_validator(self.state_reader_factory.as_ref())?; @@ -296,8 +289,17 @@ pub fn create_gateway( class_manager_client, config.chain_info.chain_id.clone(), )); + let stateless_tx_validator = Arc::new(StatelessTransactionValidator { + config: config.stateless_tx_validator_config.clone(), + }); - Gateway::new(config, state_reader_factory, mempool_client, transaction_converter) + Gateway::new( + config, + state_reader_factory, + mempool_client, + transaction_converter, + stateless_tx_validator, + ) } #[async_trait] diff --git a/crates/apollo_gateway/src/gateway_test.rs b/crates/apollo_gateway/src/gateway_test.rs index afc60893c7c..4301badcc1e 100644 --- a/crates/apollo_gateway/src/gateway_test.rs +++ b/crates/apollo_gateway/src/gateway_test.rs @@ -130,11 +130,13 @@ fn mock_dependencies() -> MockDependencies { local_test_state_reader_factory(CairoVersion::Cairo1(RunnableCairo1::Casm), true); let mock_mempool_client = MockMempoolClient::new(); let mock_transaction_converter = MockTransactionConverterTrait::new(); + let mock_stateless_transaction_validator = mock_stateless_transaction_validator(); MockDependencies { config, state_reader_factory, mock_mempool_client, mock_transaction_converter, + mock_stateless_transaction_validator, } } @@ -143,6 +145,7 @@ struct MockDependencies { state_reader_factory: TestStateReaderFactory, mock_mempool_client: MockMempoolClient, mock_transaction_converter: MockTransactionConverterTrait, + mock_stateless_transaction_validator: MockStatelessTransactionValidatorTrait, } impl MockDependencies { @@ -153,6 +156,7 @@ impl MockDependencies { Arc::new(self.state_reader_factory), Arc::new(self.mock_mempool_client), Arc::new(self.mock_transaction_converter), + Arc::new(self.mock_stateless_transaction_validator), ) } @@ -309,28 +313,13 @@ async fn run_add_tx_and_extract_metrics( AddTxResults { result, metric_handle_for_queries, metrics } } -#[derive(Default)] -pub struct ProcessTxOverrides { - pub mock_stateful_transaction_validator_factory: - Option, - pub mock_stateless_transaction_validator: Option, -} - -fn process_tx_task(overrides: ProcessTxOverrides) -> ProcessTxBlockingTask { - let mock_validator_factory = overrides - .mock_stateful_transaction_validator_factory - .unwrap_or_else(mock_stateful_transaction_validator_factory); - - let mock_stateless_transaction_validator = overrides - .mock_stateless_transaction_validator - .unwrap_or_else(mock_stateless_transaction_validator); - +fn process_tx_task( + stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait, +) -> ProcessTxBlockingTask { ProcessTxBlockingTask { - stateless_tx_validator: Arc::new(mock_stateless_transaction_validator), - stateful_tx_validator_factory: Arc::new(mock_validator_factory), + stateful_tx_validator_factory: Arc::new(stateful_transaction_validator_factory), state_reader_factory: Arc::new(MockStateReaderFactory::new()), mempool_client: Arc::new(MockMempoolClient::new()), - tx: invoke_args().get_rpc_tx(), internal_tx: invoke_args().get_internal_tx(), executable_tx: executable_invoke_tx(invoke_args()), runtime: tokio::runtime::Handle::current(), @@ -577,13 +566,7 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f .expect_instantiate_validator() .return_once(|_| Ok(Box::new(mock_stateful_transaction_validator))); - let overrides = ProcessTxOverrides { - mock_stateful_transaction_validator_factory: Some( - mock_stateful_transaction_validator_factory, - ), - ..Default::default() - }; - let process_tx_task = process_tx_task(overrides); + let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory); let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap(); @@ -593,27 +576,20 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f #[rstest] #[tokio::test] -async fn process_tx_returns_error_for_one_stateless_error_variant() { +async fn stateless_transaction_validator_error(mut mock_dependencies: MockDependencies) { let arbitrary_validation_error = Err(StatelessTransactionValidatorError::SignatureTooLong { signature_length: 5001, max_signature_length: 4000, }); let error_code = StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.SIGNATURE_TOO_LONG".into()); - let mut mock_stateless_transaction_validator = MockStatelessTransactionValidatorTrait::new(); mock_stateless_transaction_validator .expect_validate() .return_once(|_| arbitrary_validation_error); - - let overrides = ProcessTxOverrides { - mock_stateless_transaction_validator: Some(mock_stateless_transaction_validator), - ..Default::default() - }; - let task = process_tx_task(overrides); - - let result: Result = - tokio::task::spawn_blocking(move || task.process_tx()).await.unwrap(); + mock_dependencies.mock_stateless_transaction_validator = mock_stateless_transaction_validator; + let gateway = mock_dependencies.gateway(); + let result = gateway.add_tx(invoke_args().get_rpc_tx(), None).await; assert!(result.is_err()); assert_eq!(result.unwrap_err().code, error_code); @@ -633,13 +609,7 @@ async fn process_tx_returns_error_when_instantiating_validator_fails( .expect_instantiate_validator() .return_once(|_| Err(expected_error)); - let overrides = ProcessTxOverrides { - mock_stateful_transaction_validator_factory: Some( - mock_stateful_transaction_validator_factory, - ), - ..Default::default() - }; - let process_tx_task = process_tx_task(overrides); + let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory); let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap(); diff --git a/crates/apollo_gateway/src/lib.rs b/crates/apollo_gateway/src/lib.rs index e565ce7e8ad..9b9122eac2f 100644 --- a/crates/apollo_gateway/src/lib.rs +++ b/crates/apollo_gateway/src/lib.rs @@ -10,7 +10,7 @@ pub mod state_reader; #[cfg(any(feature = "testing", test))] pub mod state_reader_test_utils; mod stateful_transaction_validator; -mod stateless_transaction_validator; +pub mod stateless_transaction_validator; mod sync_state_reader; #[cfg(test)] mod sync_state_reader_test;