Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions datafusion-examples/examples/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ use arrow::compute::kernels::cmp::eq;
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use datafusion::common::error::Result;
use datafusion::common::internal_err;
use datafusion::common::types::{logical_int64, logical_string};
use datafusion::common::utils::take_function_args;
use datafusion::common::{internal_err, not_impl_err};
use datafusion::config::ConfigOptions;
use datafusion::logical_expr::async_udf::{
AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl,
};
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion::logical_expr_common::signature::Coercion;
use datafusion::physical_expr_common::datum::apply_cmp;
Expand Down Expand Up @@ -153,7 +152,7 @@ impl AsyncUpper {
}

#[async_trait]
impl AsyncScalarUDFImpl for AsyncUpper {
impl ScalarUDFImpl for AsyncUpper {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -170,13 +169,20 @@ impl AsyncScalarUDFImpl for AsyncUpper {
Ok(DataType::Utf8)
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
not_impl_err!("AsyncUpper can only be called from async contexts")
}
}

#[async_trait]
impl AsyncScalarUDFImpl for AsyncUpper {
fn ideal_batch_size(&self) -> Option<usize> {
Some(10)
}

async fn invoke_async_with_args(
&self,
args: AsyncScalarFunctionArgs,
args: ScalarFunctionArgs,
_option: &ConfigOptions,
) -> Result<ArrayRef> {
trace!("Invoking async_upper with args: {:?}", args);
Expand Down Expand Up @@ -226,7 +232,7 @@ impl AsyncEqual {
}

#[async_trait]
impl AsyncScalarUDFImpl for AsyncEqual {
impl ScalarUDFImpl for AsyncEqual {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -243,9 +249,16 @@ impl AsyncScalarUDFImpl for AsyncEqual {
Ok(DataType::Boolean)
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
not_impl_err!("AsyncEqual can only be called from async contexts")
}
}

#[async_trait]
impl AsyncScalarUDFImpl for AsyncEqual {
async fn invoke_async_with_args(
&self,
args: AsyncScalarFunctionArgs,
args: ScalarFunctionArgs,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is interesting here that `invoke_async_with_args` has a copy of the `config_options` 🤔 -- I think that is something that @Omega359 has tried to get into normal scalar functions for a while

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me. We allow custom config, so giving the UDF access to the config options makes it more flexible.

_option: &ConfigOptions,
) -> Result<ArrayRef> {
let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
Expand Down
10 changes: 7 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,9 +779,11 @@ impl DefaultPhysicalPlanner {
let runtime_expr =
self.create_physical_expr(predicate, input_dfschema, session_state)?;

let input_schema = input.schema();
let filter = match self.try_plan_async_exprs(
input.schema().fields().len(),
input_schema.fields().len(),
PlannedExprResult::Expr(vec![runtime_expr]),
input_schema.as_arrow(),
)? {
PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => {
FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)?
Expand Down Expand Up @@ -2082,6 +2084,7 @@ impl DefaultPhysicalPlanner {
match self.try_plan_async_exprs(
num_input_columns,
PlannedExprResult::ExprWithName(physical_exprs),
input_physical_schema.as_ref(),
)? {
PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => Ok(
Arc::new(ProjectionExec::try_new(physical_exprs, input_exec)?),
Expand All @@ -2104,18 +2107,19 @@ impl DefaultPhysicalPlanner {
&self,
num_input_columns: usize,
physical_expr: PlannedExprResult,
schema: &Schema,
) -> Result<PlanAsyncExpr> {
let mut async_map = AsyncMapper::new(num_input_columns);
match &physical_expr {
PlannedExprResult::ExprWithName(exprs) => {
exprs
.iter()
.try_for_each(|(expr, _)| async_map.find_references(expr))?;
.try_for_each(|(expr, _)| async_map.find_references(expr, schema))?;
}
PlannedExprResult::Expr(exprs) => {
exprs
.iter()
.try_for_each(|expr| async_map.find_references(expr))?;
.try_for_each(|expr| async_map.find_references(expr, schema))?;
}
}

Expand Down
42 changes: 4 additions & 38 deletions datafusion/expr/src/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
use arrow::datatypes::{DataType, FieldRef};
use async_trait::async_trait;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
Expand All @@ -35,34 +35,7 @@ use std::sync::Arc;
///
/// The name is chosen to mirror ScalarUDFImpl
#[async_trait]
pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
/// the function cast as any
fn as_any(&self) -> &dyn Any;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of this PR is to remove this duplication from AsyncScalarUDFImpl and use ScalarUDFImpl instead


/// The name of the function
fn name(&self) -> &str;

/// The signature of the function
fn signature(&self) -> &Signature;

/// The return type of the function
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>;

/// What type will be returned by this function, given the arguments?
///
/// By default, this function calls [`Self::return_type`] with the
/// types of each argument.
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let data_types = args
.arg_fields
.iter()
.map(|f| f.data_type())
.cloned()
.collect::<Vec<_>>();
let return_type = self.return_type(&data_types)?;
Ok(Arc::new(Field::new(self.name(), return_type, true)))
}

pub trait AsyncScalarUDFImpl: ScalarUDFImpl {
/// The ideal batch size for this function.
///
/// This is used to determine what size of data to be evaluated at once.
Expand All @@ -74,7 +47,7 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
/// Invoke the function asynchronously with the async arguments
async fn invoke_async_with_args(
&self,
args: AsyncScalarFunctionArgs,
args: ScalarFunctionArgs,
option: &ConfigOptions,
) -> Result<ArrayRef>;
}
Expand Down Expand Up @@ -107,7 +80,7 @@ impl AsyncScalarUDF {
/// Invoke the function asynchronously with the async arguments
pub async fn invoke_async_with_args(
&self,
args: AsyncScalarFunctionArgs,
args: ScalarFunctionArgs,
option: &ConfigOptions,
) -> Result<ArrayRef> {
self.inner.invoke_async_with_args(args, option).await
Expand Down Expand Up @@ -145,10 +118,3 @@ impl Display for AsyncScalarUDF {
write!(f, "AsyncScalarUDF: {}", self.inner.name())
}
}

#[derive(Debug)]
pub struct AsyncScalarFunctionArgs {
pub args: Vec<ColumnarValue>,
pub number_rows: usize,
pub schema: SchemaRef,
}
34 changes: 26 additions & 8 deletions datafusion/physical-expr/src/async_scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

use crate::ScalarFunctionExpr;
use arrow::array::{make_array, MutableArrayData, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::async_udf::{AsyncScalarFunctionArgs, AsyncScalarUDF};
use datafusion_expr::async_udf::AsyncScalarUDF;
use datafusion_expr::ScalarFunctionArgs;
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
Expand All @@ -36,6 +37,8 @@ pub struct AsyncFuncExpr {
pub name: String,
/// The actual function (always `ScalarFunctionExpr`)
pub func: Arc<dyn PhysicalExpr>,
/// The field that this function will return
return_field: FieldRef,
}

impl Display for AsyncFuncExpr {
Expand All @@ -59,17 +62,23 @@ impl Hash for AsyncFuncExpr {

impl AsyncFuncExpr {
/// create a new AsyncFuncExpr
pub fn try_new(name: impl Into<String>, func: Arc<dyn PhysicalExpr>) -> Result<Self> {
pub fn try_new(
name: impl Into<String>,
func: Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Result<Self> {
let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
return internal_err!(
"unexpected function type, expected ScalarFunctionExpr, got: {:?}",
func
);
};

let return_field = func.return_field(schema)?;
Ok(Self {
name: name.into(),
func,
return_field,
})
}

Expand Down Expand Up @@ -128,6 +137,12 @@ impl AsyncFuncExpr {
);
};

let arg_fields = scalar_function_expr
.args()
.iter()
.map(|e| e.return_field(batch.schema_ref()))
.collect::<Result<Vec<_>>>()?;

let mut result_batches = vec![];
if let Some(ideal_batch_size) = self.ideal_batch_size()? {
let mut remainder = batch.clone();
Expand All @@ -148,10 +163,11 @@ impl AsyncFuncExpr {
result_batches.push(
async_udf
.invoke_async_with_args(
AsyncScalarFunctionArgs {
args: args.to_vec(),
ScalarFunctionArgs {
args,
arg_fields: arg_fields.clone(),
number_rows: current_batch.num_rows(),
schema: current_batch.schema(),
return_field: Arc::clone(&self.return_field),
},
option,
)
Expand All @@ -168,10 +184,11 @@ impl AsyncFuncExpr {
result_batches.push(
async_udf
.invoke_async_with_args(
AsyncScalarFunctionArgs {
ScalarFunctionArgs {
args: args.to_vec(),
arg_fields,
number_rows: batch.num_rows(),
schema: batch.schema(),
return_field: Arc::clone(&self.return_field),
},
option,
)
Expand Down Expand Up @@ -223,6 +240,7 @@ impl PhysicalExpr for AsyncFuncExpr {
Ok(Arc::new(AsyncFuncExpr {
name: self.name.clone(),
func: new_func,
return_field: Arc::clone(&self.return_field),
}))
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/async_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ impl AsyncMapper {
pub fn find_references(
&mut self,
physical_expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Result<()> {
// recursively look for references to async functions
physical_expr.apply(|expr| {
Expand All @@ -256,6 +257,7 @@ impl AsyncMapper {
self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
next_name,
Arc::clone(expr),
schema,
)?));
}
}
Expand Down
Loading