From 52d1a1e795bdf93991fe0da03cc5feea18cd1720 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 28 Aug 2025 16:17:01 +1200 Subject: [PATCH] Make EncryptionFactory async --- Cargo.lock | 1 + .../examples/parquet_encrypted_with_kms.rs | 6 +- datafusion/core/tests/parquet/encryption.rs | 19 +++-- .../datasource-parquet/src/file_format.rs | 34 +++++---- datafusion/datasource-parquet/src/opener.rs | 69 +++++++++++++++---- datafusion/execution/Cargo.toml | 1 + .../execution/src/parquet_encryption.rs | 6 +- 7 files changed, 96 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 95fca3c32235..2f4811a16f6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2209,6 +2209,7 @@ name = "datafusion-execution" version = "49.0.2" dependencies = [ "arrow", + "async-trait", "chrono", "dashmap", "datafusion-common", diff --git a/datafusion-examples/examples/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/parquet_encrypted_with_kms.rs index d30608ce7a1c..19b0e8d0b199 100644 --- a/datafusion-examples/examples/parquet_encrypted_with_kms.rs +++ b/datafusion-examples/examples/parquet_encrypted_with_kms.rs @@ -17,6 +17,7 @@ use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow_schema::SchemaRef; +use async_trait::async_trait; use base64::Engine; use datafusion::common::extensions_options; use datafusion::config::{EncryptionFactoryOptions, TableParquetOptions}; @@ -211,6 +212,7 @@ struct TestEncryptionFactory {} /// `EncryptionFactory` is a DataFusion trait for types that generate /// file encryption and decryption properties. +#[async_trait] impl EncryptionFactory for TestEncryptionFactory { /// Generate file encryption properties to use when writing a Parquet file. /// The `schema` is provided so that it may be used to dynamically configure @@ -219,7 +221,7 @@ impl EncryptionFactory for TestEncryptionFactory { /// but other implementations may want to use this to compute an /// AAD prefix for the file, or to allow use of external key material /// (where key metadata is stored in a JSON file alongside Parquet files). - fn get_file_encryption_properties( + async fn get_file_encryption_properties( &self, options: &EncryptionFactoryOptions, schema: &SchemaRef, @@ -262,7 +264,7 @@ impl EncryptionFactory for TestEncryptionFactory { /// Generate file decryption properties to use when reading a Parquet file. /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` /// that can determine the keys using the encryption metadata. - fn get_file_decryption_properties( + async fn get_file_decryption_properties( &self, _options: &EncryptionFactoryOptions, _file_path: &Path, diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs index c32310752cc5..034e99b4408f 100644 --- a/datafusion/core/tests/parquet/encryption.rs +++ b/datafusion/core/tests/parquet/encryption.rs @@ -20,6 +20,7 @@ use arrow::array::{ArrayRef, Int32Array, StringArray}; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, SchemaRef}; +use async_trait::async_trait; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::listing::ListingOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; @@ -175,7 +176,9 @@ async fn round_trip_parquet_with_encryption_factory() { // Crypto factory should have generated one key per partition file assert_eq!(encryption_factory.encryption_keys.lock().unwrap().len(), 3); - verify_table_encrypted(tmpdir.path(), &encryption_factory).unwrap(); + verify_table_encrypted(tmpdir.path(), &encryption_factory) + .await + .unwrap(); // Registering table without decryption properties should fail let table_path = format!("file://{}/", tmpdir.path().to_str().unwrap()); @@ -255,7 +258,7 @@ async fn round_trip_parquet_with_encryption_factory() { assert_batches_sorted_eq!(expected, &table); } -fn verify_table_encrypted( +async fn verify_table_encrypted( table_path: &Path, encryption_factory: &Arc, ) -> datafusion_common::Result<()> { @@ -267,7 +270,7 @@ fn verify_table_encrypted( if path.is_dir() { directories.push(path); } else { - verify_file_encrypted(&path, encryption_factory)?; + verify_file_encrypted(&path, encryption_factory).await?; files_visited += 1; } } @@ -276,7 +279,7 @@ fn verify_table_encrypted( Ok(()) } -fn verify_file_encrypted( +async fn verify_file_encrypted( file_path: &Path, encryption_factory: &Arc, ) -> datafusion_common::Result<()> { @@ -296,7 +299,8 @@ fn verify_file_encrypted( let object_path = object_store::path::Path::from(file_path_str); let decryption_properties = encryption_factory - .get_file_decryption_properties(&options, &object_path)? + .get_file_decryption_properties(&options, &object_path) + .await? .unwrap(); let reader_options = @@ -325,8 +329,9 @@ struct MockEncryptionFactory { pub counter: AtomicU8, } +#[async_trait] impl EncryptionFactory for MockEncryptionFactory { - fn get_file_encryption_properties( + async fn get_file_encryption_properties( &self, config: &EncryptionFactoryOptions, _schema: &SchemaRef, @@ -344,7 +349,7 @@ impl EncryptionFactory for MockEncryptionFactory { Ok(Some(encryption_properties)) } - fn get_file_decryption_properties( + async fn get_file_decryption_properties( &self, config: &EncryptionFactoryOptions, file_path: &object_store::path::Path, diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index a2621d385458..1fcc1721017c 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -302,7 +302,7 @@ fn clear_metadata( } #[cfg(feature = "parquet_encryption")] -fn get_file_decryption_properties( +async fn get_file_decryption_properties( state: &dyn Session, options: &TableParquetOptions, file_path: &Path, @@ -314,10 +314,12 @@ fn get_file_decryption_properties( Some(factory_id) => { let factory = state.runtime_env().parquet_encryption_factory(factory_id)?; - factory.get_file_decryption_properties( - &options.crypto.factory_options, - file_path, - )? + factory + .get_file_decryption_properties( + &options.crypto.factory_options, + file_path, + ) + .await? } None => None, }, @@ -326,7 +328,7 @@ fn get_file_decryption_properties( } #[cfg(not(feature = "parquet_encryption"))] -fn get_file_decryption_properties( +async fn get_file_decryption_properties( _state: &dyn Session, _options: &TableParquetOptions, _file_path: &Path, @@ -379,7 +381,8 @@ impl FileFormat for ParquetFormat { state, &self.options, &object.location, - )?; + ) + .await?; let result = DFParquetMetadata::new(store.as_ref(), object) .with_metadata_size_hint(self.metadata_size_hint()) .with_decryption_properties(file_decryption_properties.as_ref()) @@ -437,7 +440,8 @@ impl FileFormat for ParquetFormat { object: &ObjectMeta, ) -> Result { let file_decryption_properties = - get_file_decryption_properties(state, &self.options, &object.location)?; + get_file_decryption_properties(state, &self.options, &object.location) + .await?; let file_metadata_cache = state.runtime_env().cache_manager.get_file_metadata_cache(); DFParquetMetadata::new(store, object) @@ -1119,7 +1123,7 @@ impl ParquetSink { /// Create writer properties based upon configuration settings, /// including partitioning and the inclusion of arrow schema metadata. - fn create_writer_props( + async fn create_writer_props( &self, runtime: &Arc, path: &Path, @@ -1147,7 +1151,8 @@ impl ParquetSink { &parquet_opts, schema, path, - )?; + ) + .await?; Ok(builder.build()) } @@ -1188,7 +1193,7 @@ impl ParquetSink { } #[cfg(feature = "parquet_encryption")] -fn set_writer_encryption_properties( +async fn set_writer_encryption_properties( builder: WriterPropertiesBuilder, runtime: &Arc, parquet_opts: &TableParquetOptions, @@ -1208,7 +1213,8 @@ fn set_writer_encryption_properties( &parquet_opts.crypto.factory_options, schema, path, - )?; + ) + .await?; if let Some(file_encryption_properties) = file_encryption_properties { return Ok( builder.with_file_encryption_properties(file_encryption_properties) @@ -1219,7 +1225,7 @@ fn set_writer_encryption_properties( } #[cfg(not(feature = "parquet_encryption"))] -fn set_writer_encryption_properties( +async fn set_writer_encryption_properties( builder: WriterPropertiesBuilder, _runtime: &Arc, _parquet_opts: &TableParquetOptions, @@ -1269,7 +1275,7 @@ impl FileSink for ParquetSink { }; while let Some((path, mut rx)) = file_stream_rx.recv().await { - let parquet_props = self.create_writer_props(&runtime, &path)?; + let parquet_props = self.create_writer_props(&runtime, &path).await?; if !allow_single_file_parallelism { let mut writer = self .create_async_arrow_writer( diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index c96d73242e8e..c078c2ef44c0 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -112,7 +112,7 @@ impl FileOpener for ParquetOpener { fn open(&self, file_meta: FileMeta, file: PartitionedFile) -> Result { let file_range = file_meta.range.clone(); let extensions = file_meta.extensions.clone(); - let file_location = file_meta.location(); + let file_location = file_meta.location().clone(); let file_name = file_location.to_string(); let file_metrics = ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); @@ -152,16 +152,18 @@ impl FileOpener for ParquetOpener { let mut predicate_file_schema = Arc::clone(&self.logical_file_schema); let mut enable_page_index = self.enable_page_index; - let file_decryption_properties = - self.get_file_decryption_properties(file_location)?; - - // For now, page index does not work with encrypted files. See: - // https://github.com/apache/arrow-rs/issues/7629 - if file_decryption_properties.is_some() { - enable_page_index = false; - } + let encryption_context = self.get_encryption_context(); Ok(Box::pin(async move { + let file_decryption_properties = encryption_context + .get_file_decryption_properties(&file_location) + .await?; + // For now, page index does not work with encrypted files. See: + // https://github.com/apache/arrow-rs/issues/7629 + if file_decryption_properties.is_some() { + enable_page_index = false; + } + // Prune this file using the file level statistics and partition values. // Since dynamic filters may have been updated since planning it is possible that we are able // to prune files now that we couldn't prune at planning time. @@ -508,9 +510,30 @@ where } } +#[derive(Default)] +struct EncryptionContext { + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: Option>, + #[cfg(feature = "parquet_encryption")] + encryption_factory: Option<(Arc, EncryptionFactoryOptions)>, +} + #[cfg(feature = "parquet_encryption")] -impl ParquetOpener { - fn get_file_decryption_properties( +impl EncryptionContext { + fn new( + file_decryption_properties: Option>, + encryption_factory: Option<( + Arc, + EncryptionFactoryOptions, + )>, + ) -> Self { + Self { + file_decryption_properties, + encryption_factory, + } + } + + async fn get_file_decryption_properties( &self, file_location: &object_store::path::Path, ) -> Result>> { @@ -520,7 +543,8 @@ impl ParquetOpener { } None => match &self.encryption_factory { Some((encryption_factory, encryption_config)) => Ok(encryption_factory - .get_file_decryption_properties(encryption_config, file_location)? + .get_file_decryption_properties(encryption_config, file_location) + .await? .map(Arc::new)), None => Ok(None), }, @@ -529,12 +553,27 @@ impl ParquetOpener { } #[cfg(not(feature = "parquet_encryption"))] -impl ParquetOpener { - fn get_file_decryption_properties( +impl EncryptionContext { + async fn get_file_decryption_properties( &self, _file_location: &object_store::path::Path, ) -> Result>> { - Ok(self.file_decryption_properties.clone()) + Ok(None) + } +} + +impl ParquetOpener { + #[cfg(feature = "parquet_encryption")] + fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::new( + self.file_decryption_properties.clone(), + self.encryption_factory.clone(), + ) + } + + #[cfg(not(feature = "parquet_encryption"))] + fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::default() } } diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index f6d02615e39a..afe9039f8bae 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -44,6 +44,7 @@ parquet_encryption = [ [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } diff --git a/datafusion/execution/src/parquet_encryption.rs b/datafusion/execution/src/parquet_encryption.rs index 13a18390d02a..c06764a0eb55 100644 --- a/datafusion/execution/src/parquet_encryption.rs +++ b/datafusion/execution/src/parquet_encryption.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::SchemaRef; +use async_trait::async_trait; use dashmap::DashMap; use datafusion_common::config::EncryptionFactoryOptions; use datafusion_common::error::Result; @@ -32,9 +33,10 @@ use std::sync::Arc; /// For example usage, see the [`parquet_encrypted_with_kms` example]. /// /// [`parquet_encrypted_with_kms` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/parquet_encrypted_with_kms.rs +#[async_trait] pub trait EncryptionFactory: Send + Sync + std::fmt::Debug + 'static { /// Generate file encryption properties to use when writing a Parquet file. - fn get_file_encryption_properties( + async fn get_file_encryption_properties( &self, config: &EncryptionFactoryOptions, schema: &SchemaRef, @@ -42,7 +44,7 @@ pub trait EncryptionFactory: Send + Sync + std::fmt::Debug + 'static { ) -> Result>; /// Generate file decryption properties to use when reading a Parquet file. - fn get_file_decryption_properties( + async fn get_file_decryption_properties( &self, config: &EncryptionFactoryOptions, file_path: &Path,