datafusion/core/src/dataframe/mod.rs (111 additions, 16 deletions)
@@ -49,7 +49,10 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_common::config::{CsvOptions, JsonOptions};
use datafusion_common::DFSchemaRef;
use datafusion_common::{
exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
@@ -63,15 +66,13 @@ use datafusion_expr::{
utils::COUNT_STAR_EXPANSION,
ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{Extension, InvariantLevel, UserDefinedLogicalNode};
use datafusion_functions::core::coalesce;
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};

use datafusion_sql::TableReference;

use std::hash::{Hash, Hasher};

/// Contains options that control how data is
/// written out from a DataFrame
pub struct DataFrameWriteOptions {
Expand Down Expand Up @@ -231,7 +232,67 @@ pub struct DataFrame {
// via anything other than a `project` call should set this to true.
projection_requires_validation: bool,
}
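
/// Marker node produced by [`DataFrame::cache`] when eager caching is
/// disabled via [`LocalCacheOption`]. A downstream planner (e.g. a
/// distributed engine such as Ballista) can detect it and substitute its own
/// materialization. A minimal sketch of such detection (not part of this PR):
///
/// ```ignore
/// if let LogicalPlan::Extension(ext) = plan {
///     if let Some(cache) = ext.node.as_any().downcast_ref::<CacheNode>() {
///         // Replace with an engine-specific cached scan of `cache.input`
///     }
/// }
/// ```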
#[derive(Debug)]
struct CacheNode {
input: Arc<LogicalPlan>,
}

impl UserDefinedLogicalNode for CacheNode {
fn name(&self) -> &str {
"CacheNode"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}

fn schema(&self) -> &DFSchemaRef {
self.input.schema()
}

fn expressions(&self) -> Vec<Expr> {
vec![]
}

fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "CacheNode")
}

fn as_any(&self) -> &(dyn Any + 'static) {
self
}

fn check_invariants(&self, _level: InvariantLevel) -> Result<(), DataFusionError> {
Ok(())
}

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        mut inputs: Vec<LogicalPlan>,
    ) -> Result<Arc<dyn UserDefinedLogicalNode>, DataFusionError> {
        if inputs.len() != 1 {
            return plan_err!("CacheNode must have exactly one input");
        }
        // Rebuild the node around the (possibly rewritten) input rather than
        // discarding it, so optimizer rewrites are preserved
        Ok(Arc::new(CacheNode {
            input: Arc::new(inputs.swap_remove(0)),
        }))
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // Hash the node type and its input so hashing stays consistent with
        // `dyn_eq` below
        let mut state = state;
        std::any::type_name::<Self>().hash(&mut state);
        self.input.hash(&mut state);
    }

    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
        // Equal only to another `CacheNode` over an equal input plan
        other
            .as_any()
            .downcast_ref::<Self>()
            .map(|o| self.input == o.input)
            .unwrap_or(false)
    }

    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<std::cmp::Ordering> {
        // Nodes that compare equal order as equal; otherwise unordered
        self.dyn_eq(other).then_some(std::cmp::Ordering::Equal)
    }
}

/// Session config extension controlling [`DataFrame::cache`]:
/// `LocalCacheOption(true)` (also the behavior when the extension is absent)
/// caches eagerly into a `MemTable`, while `LocalCacheOption(false)` defers
/// by wrapping the plan in a `CacheNode` extension node.
#[derive(Debug)]
pub struct LocalCacheOption(pub bool);

impl DataFrame {
/// Create a new `DataFrame ` based on an existing `LogicalPlan`
///
Expand Down Expand Up @@ -2205,28 +2266,62 @@ impl DataFrame {
})
}

/// Cache `DataFrame` as an in-memory table (eager materialization).
///
/// This eagerly executes the current plan, collects all partitions, and
/// registers a `MemTable` backed by the collected batches. This matches the
/// existing behavior and is suitable for single-process execution.
///
/// ```no_run
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx
/// .read_csv("tests/data/example.csv", CsvReadOptions::new())
/// .await?;
/// let df = df.cache().await?; // materialize into memory
/// # Ok(())
/// # }
/// ```
///
/// NOTE: This eager caching materializes the entire dataset in the local
/// process. It is **not** suitable for distributed environments (e.g.
/// Ballista). See issue #17297 for a follow-up design to make caching
/// configurable and/or lazy via a logical plan node.
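    ///
    /// To opt out of eager caching, register the [`LocalCacheOption`]
    /// extension on the session config. A sketch, reusing the example data
    /// from above:
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::*;
    /// # use datafusion::dataframe::LocalCacheOption;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let config = SessionConfig::new()
    ///     .with_extension(Arc::new(LocalCacheOption(false)));
    /// let ctx = SessionContext::new_with_config(config);
    /// let df = ctx
    ///     .read_csv("tests/data/example.csv", CsvReadOptions::new())
    ///     .await?;
    /// // With caching disabled, `cache` only wraps the plan in a `CacheNode`
    /// // extension; nothing is collected into memory yet.
    /// let df = df.cache().await?;
    /// # Ok(())
    /// # }
    /// ```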
pub async fn cache(self) -> Result<DataFrame> {
        // Read the `LocalCacheOption` session extension; default to eager caching
let enable_cache = self
.session_state
.config()
.get_extension::<LocalCacheOption>()
.map(|c| c.0)
.unwrap_or(true);

if enable_cache {
// Eager MemTable logic (current behavior)
let context = SessionContext::new_with_state((*self.session_state).clone());
let plan = self.clone().create_physical_plan().await?;
let schema = plan.schema();
let task_ctx = Arc::new(self.task_ctx());
let partitions = collect_partitioned(plan, task_ctx).await?;
let mem_table = MemTable::try_new(schema, partitions)?;
context.read_table(Arc::new(mem_table))
        } else {
            // Lazy cache: wrap the logical plan in a `CacheNode` extension so
            // a downstream planner can decide how and where to materialize it
            let plan = self.logical_plan().clone();
            let cache_plan = LogicalPlan::Extension(Extension {
                node: Arc::new(CacheNode {
                    input: Arc::new(plan),
                }),
            });

            Ok(DataFrame::new(*self.session_state, cache_plan))
        }
}

/// Apply an alias to the DataFrame.

datafusion/core/tests/parquet/mod.rs (6 additions, 4 deletions)

@@ -474,10 +474,12 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch {
Field::new("u32", DataType::UInt32, true),
Field::new("u64", DataType::UInt64, true),
]));

    let v8: Vec<u8> = (start..end).collect();
    let v16: Vec<u16> = (start as u16..end as u16).collect();
    let v32: Vec<u32> = (start as u32..end as u32).collect();
    let v64: Vec<u64> = (start as u64..end as u64).collect();

RecordBatch::try_new(
schema,
vec![