Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/apollo_gateway/benches/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use apollo_class_manager_types::transaction_converter::TransactionConverter;
use apollo_class_manager_types::EmptyClassManagerClient;
use apollo_gateway::gateway::Gateway;
use apollo_gateway::state_reader_test_utils::local_test_state_reader_factory;
use apollo_gateway::stateless_transaction_validator::StatelessTransactionValidator;
use apollo_gateway_config::config::GatewayConfig;
use apollo_mempool_types::communication::MockMempoolClient;
use blockifier::context::ChainInfo;
Expand Down Expand Up @@ -82,19 +83,22 @@ impl BenchTestSetup {

let state_reader_factory = local_test_state_reader_factory(cairo_version, false);
let mut mempool_client = MockMempoolClient::new();
// TODO(noamsp): use MockTransactionConverter
let class_manager_client = Arc::new(EmptyClassManagerClient);
let transaction_converter = TransactionConverter::new(
class_manager_client.clone(),
config.gateway_config.chain_info.chain_id.clone(),
);
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
config: config.gateway_config.stateless_tx_validator_config.clone(),
});
mempool_client.expect_add_tx().returning(|_| Ok(()));

let gateway_business_logic = Gateway::new(
config.gateway_config,
Arc::new(state_reader_factory),
Arc::new(mempool_client),
Arc::new(transaction_converter),
stateless_tx_validator,
);

Self { gateway: gateway_business_logic, txs }
Expand Down
30 changes: 16 additions & 14 deletions crates/apollo_gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub mod gateway_test;
#[derive(Clone)]
pub struct Gateway {
pub config: Arc<GatewayConfig>,
pub stateless_tx_validator: Arc<StatelessTransactionValidator>,
pub stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
pub stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
pub state_reader_factory: Arc<dyn StateReaderFactory>,
pub mempool_client: SharedMempoolClient,
Expand All @@ -71,12 +71,11 @@ impl Gateway {
state_reader_factory: Arc<dyn StateReaderFactory>,
mempool_client: SharedMempoolClient,
transaction_converter: Arc<dyn TransactionConverterTrait>,
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
) -> Self {
Self {
config: Arc::new(config.clone()),
stateless_tx_validator: Arc::new(StatelessTransactionValidator {
config: config.stateless_tx_validator_config.clone(),
}),
stateless_tx_validator,
stateful_tx_validator_factory: Arc::new(StatefulTransactionValidatorFactory {
config: config.stateful_tx_validator_config.clone(),
chain_info: config.chain_info.clone(),
Expand Down Expand Up @@ -125,6 +124,9 @@ impl Gateway {
}
}

// Perform stateless validations.
self.stateless_tx_validator.validate(tx)?;

let tx_signature = tx.signature().clone();
let internal_tx = self
.transaction_converter
Expand All @@ -149,7 +151,6 @@ impl Gateway {

let blocking_task = ProcessTxBlockingTask::new(
self,
tx.clone(),
internal_tx,
executable_tx,
tokio::runtime::Handle::current(),
Expand Down Expand Up @@ -232,11 +233,9 @@ impl Gateway {
/// CPU-intensive transaction processing, spawned in a blocking thread to avoid blocking other tasks
/// from running.
struct ProcessTxBlockingTask {
stateless_tx_validator: Arc<dyn StatelessTransactionValidatorTrait>,
stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
state_reader_factory: Arc<dyn StateReaderFactory>,
mempool_client: SharedMempoolClient,
tx: RpcTransaction,
internal_tx: InternalRpcTransaction,
executable_tx: AccountTransaction,
runtime: tokio::runtime::Handle,
Expand All @@ -245,27 +244,21 @@ struct ProcessTxBlockingTask {
impl ProcessTxBlockingTask {
pub fn new(
gateway: &Gateway,
tx: RpcTransaction,
internal_tx: InternalRpcTransaction,
executable_tx: AccountTransaction,
runtime: tokio::runtime::Handle,
) -> Self {
Self {
stateless_tx_validator: gateway.stateless_tx_validator.clone(),
stateful_tx_validator_factory: gateway.stateful_tx_validator_factory.clone(),
state_reader_factory: gateway.state_reader_factory.clone(),
mempool_client: gateway.mempool_client.clone(),
tx,
internal_tx,
executable_tx,
runtime,
}
}

fn process_tx(self) -> GatewayResult<AddTransactionArgs> {
// Perform stateless validations.
self.stateless_tx_validator.validate(&self.tx)?;

let mut stateful_transaction_validator = self
.stateful_tx_validator_factory
.instantiate_validator(self.state_reader_factory.as_ref())?;
Expand Down Expand Up @@ -296,8 +289,17 @@ pub fn create_gateway(
class_manager_client,
config.chain_info.chain_id.clone(),
));
let stateless_tx_validator = Arc::new(StatelessTransactionValidator {
config: config.stateless_tx_validator_config.clone(),
});

Gateway::new(config, state_reader_factory, mempool_client, transaction_converter)
Gateway::new(
config,
state_reader_factory,
mempool_client,
transaction_converter,
stateless_tx_validator,
)
}

#[async_trait]
Expand Down
58 changes: 14 additions & 44 deletions crates/apollo_gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ fn mock_dependencies() -> MockDependencies {
local_test_state_reader_factory(CairoVersion::Cairo1(RunnableCairo1::Casm), true);
let mock_mempool_client = MockMempoolClient::new();
let mock_transaction_converter = MockTransactionConverterTrait::new();
let mock_stateless_transaction_validator = mock_stateless_transaction_validator();
MockDependencies {
config,
state_reader_factory,
mock_mempool_client,
mock_transaction_converter,
mock_stateless_transaction_validator,
}
}

Expand All @@ -143,6 +145,7 @@ struct MockDependencies {
state_reader_factory: TestStateReaderFactory,
mock_mempool_client: MockMempoolClient,
mock_transaction_converter: MockTransactionConverterTrait,
mock_stateless_transaction_validator: MockStatelessTransactionValidatorTrait,
}

impl MockDependencies {
Expand All @@ -153,6 +156,7 @@ impl MockDependencies {
Arc::new(self.state_reader_factory),
Arc::new(self.mock_mempool_client),
Arc::new(self.mock_transaction_converter),
Arc::new(self.mock_stateless_transaction_validator),
)
}

Expand Down Expand Up @@ -309,28 +313,13 @@ async fn run_add_tx_and_extract_metrics(
AddTxResults { result, metric_handle_for_queries, metrics }
}

#[derive(Default)]
pub struct ProcessTxOverrides {
pub mock_stateful_transaction_validator_factory:
Option<MockStatefulTransactionValidatorFactoryTrait>,
pub mock_stateless_transaction_validator: Option<MockStatelessTransactionValidatorTrait>,
}

fn process_tx_task(overrides: ProcessTxOverrides) -> ProcessTxBlockingTask {
let mock_validator_factory = overrides
.mock_stateful_transaction_validator_factory
.unwrap_or_else(mock_stateful_transaction_validator_factory);

let mock_stateless_transaction_validator = overrides
.mock_stateless_transaction_validator
.unwrap_or_else(mock_stateless_transaction_validator);

fn process_tx_task(
stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
) -> ProcessTxBlockingTask {
ProcessTxBlockingTask {
stateless_tx_validator: Arc::new(mock_stateless_transaction_validator),
stateful_tx_validator_factory: Arc::new(mock_validator_factory),
stateful_tx_validator_factory: Arc::new(stateful_transaction_validator_factory),
state_reader_factory: Arc::new(MockStateReaderFactory::new()),
mempool_client: Arc::new(MockMempoolClient::new()),
tx: invoke_args().get_rpc_tx(),
internal_tx: invoke_args().get_internal_tx(),
executable_tx: executable_invoke_tx(invoke_args()),
runtime: tokio::runtime::Handle::current(),
Expand Down Expand Up @@ -577,13 +566,7 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
.expect_instantiate_validator()
.return_once(|_| Ok(Box::new(mock_stateful_transaction_validator)));

let overrides = ProcessTxOverrides {
mock_stateful_transaction_validator_factory: Some(
mock_stateful_transaction_validator_factory,
),
..Default::default()
};
let process_tx_task = process_tx_task(overrides);
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);

let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();

Expand All @@ -593,27 +576,20 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f

#[rstest]
#[tokio::test]
async fn process_tx_returns_error_for_one_stateless_error_variant() {
async fn stateless_transaction_validator_error(mut mock_dependencies: MockDependencies) {
let arbitrary_validation_error = Err(StatelessTransactionValidatorError::SignatureTooLong {
signature_length: 5001,
max_signature_length: 4000,
});
let error_code =
StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.SIGNATURE_TOO_LONG".into());

let mut mock_stateless_transaction_validator = MockStatelessTransactionValidatorTrait::new();
mock_stateless_transaction_validator
.expect_validate()
.return_once(|_| arbitrary_validation_error);

let overrides = ProcessTxOverrides {
mock_stateless_transaction_validator: Some(mock_stateless_transaction_validator),
..Default::default()
};
let task = process_tx_task(overrides);

let result: Result<AddTransactionArgs, StarknetError> =
tokio::task::spawn_blocking(move || task.process_tx()).await.unwrap();
mock_dependencies.mock_stateless_transaction_validator = mock_stateless_transaction_validator;
let gateway = mock_dependencies.gateway();
let result = gateway.add_tx(invoke_args().get_rpc_tx(), None).await;

assert!(result.is_err());
assert_eq!(result.unwrap_err().code, error_code);
Expand All @@ -633,13 +609,7 @@ async fn process_tx_returns_error_when_instantiating_validator_fails(
.expect_instantiate_validator()
.return_once(|_| Err(expected_error));

let overrides = ProcessTxOverrides {
mock_stateful_transaction_validator_factory: Some(
mock_stateful_transaction_validator_factory,
),
..Default::default()
};
let process_tx_task = process_tx_task(overrides);
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);

let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion crates/apollo_gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod state_reader;
#[cfg(any(feature = "testing", test))]
pub mod state_reader_test_utils;
mod stateful_transaction_validator;
mod stateless_transaction_validator;
pub mod stateless_transaction_validator;
mod sync_state_reader;
#[cfg(test)]
mod sync_state_reader_test;
Expand Down