Skip to content

Commit b405380

Browse files
authored
Simplify AsyncScalarUdfImpl so it extends ScalarUdfImpl (#16523)
* Simplify AsyncScalarUdfImpl so it extends ScalarUdfImpl * Update one example * Update one example * prettier
1 parent 0e48627 commit b405380

File tree

6 files changed

+98
-73
lines changed

6 files changed

+98
-73
lines changed

datafusion-examples/examples/async_udf.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ use arrow::compute::kernels::cmp::eq;
2020
use arrow_schema::{DataType, Field, Schema};
2121
use async_trait::async_trait;
2222
use datafusion::common::error::Result;
23-
use datafusion::common::internal_err;
2423
use datafusion::common::types::{logical_int64, logical_string};
2524
use datafusion::common::utils::take_function_args;
25+
use datafusion::common::{internal_err, not_impl_err};
2626
use datafusion::config::ConfigOptions;
27-
use datafusion::logical_expr::async_udf::{
28-
AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl,
29-
};
27+
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
3028
use datafusion::logical_expr::{
31-
ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
29+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30+
TypeSignatureClass, Volatility,
3231
};
3332
use datafusion::logical_expr_common::signature::Coercion;
3433
use datafusion::physical_expr_common::datum::apply_cmp;
@@ -153,7 +152,7 @@ impl AsyncUpper {
153152
}
154153

155154
#[async_trait]
156-
impl AsyncScalarUDFImpl for AsyncUpper {
155+
impl ScalarUDFImpl for AsyncUpper {
157156
fn as_any(&self) -> &dyn Any {
158157
self
159158
}
@@ -170,13 +169,20 @@ impl AsyncScalarUDFImpl for AsyncUpper {
170169
Ok(DataType::Utf8)
171170
}
172171

172+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
173+
not_impl_err!("AsyncUpper can only be called from async contexts")
174+
}
175+
}
176+
177+
#[async_trait]
178+
impl AsyncScalarUDFImpl for AsyncUpper {
173179
fn ideal_batch_size(&self) -> Option<usize> {
174180
Some(10)
175181
}
176182

177183
async fn invoke_async_with_args(
178184
&self,
179-
args: AsyncScalarFunctionArgs,
185+
args: ScalarFunctionArgs,
180186
_option: &ConfigOptions,
181187
) -> Result<ArrayRef> {
182188
trace!("Invoking async_upper with args: {:?}", args);
@@ -226,7 +232,7 @@ impl AsyncEqual {
226232
}
227233

228234
#[async_trait]
229-
impl AsyncScalarUDFImpl for AsyncEqual {
235+
impl ScalarUDFImpl for AsyncEqual {
230236
fn as_any(&self) -> &dyn Any {
231237
self
232238
}
@@ -243,9 +249,16 @@ impl AsyncScalarUDFImpl for AsyncEqual {
243249
Ok(DataType::Boolean)
244250
}
245251

252+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
253+
not_impl_err!("AsyncEqual can only be called from async contexts")
254+
}
255+
}
256+
257+
#[async_trait]
258+
impl AsyncScalarUDFImpl for AsyncEqual {
246259
async fn invoke_async_with_args(
247260
&self,
248-
args: AsyncScalarFunctionArgs,
261+
args: ScalarFunctionArgs,
249262
_option: &ConfigOptions,
250263
) -> Result<ArrayRef> {
251264
let [arg1, arg2] = take_function_args(self.name(), &args.args)?;

datafusion/core/src/physical_planner.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -779,9 +779,11 @@ impl DefaultPhysicalPlanner {
779779
let runtime_expr =
780780
self.create_physical_expr(predicate, input_dfschema, session_state)?;
781781

782+
let input_schema = input.schema();
782783
let filter = match self.try_plan_async_exprs(
783-
input.schema().fields().len(),
784+
input_schema.fields().len(),
784785
PlannedExprResult::Expr(vec![runtime_expr]),
786+
input_schema.as_arrow(),
785787
)? {
786788
PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => {
787789
FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)?
@@ -2082,6 +2084,7 @@ impl DefaultPhysicalPlanner {
20822084
match self.try_plan_async_exprs(
20832085
num_input_columns,
20842086
PlannedExprResult::ExprWithName(physical_exprs),
2087+
input_physical_schema.as_ref(),
20852088
)? {
20862089
PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => Ok(
20872090
Arc::new(ProjectionExec::try_new(physical_exprs, input_exec)?),
@@ -2104,18 +2107,19 @@ impl DefaultPhysicalPlanner {
21042107
&self,
21052108
num_input_columns: usize,
21062109
physical_expr: PlannedExprResult,
2110+
schema: &Schema,
21072111
) -> Result<PlanAsyncExpr> {
21082112
let mut async_map = AsyncMapper::new(num_input_columns);
21092113
match &physical_expr {
21102114
PlannedExprResult::ExprWithName(exprs) => {
21112115
exprs
21122116
.iter()
2113-
.try_for_each(|(expr, _)| async_map.find_references(expr))?;
2117+
.try_for_each(|(expr, _)| async_map.find_references(expr, schema))?;
21142118
}
21152119
PlannedExprResult::Expr(exprs) => {
21162120
exprs
21172121
.iter()
2118-
.try_for_each(|expr| async_map.find_references(expr))?;
2122+
.try_for_each(|expr| async_map.find_references(expr, schema))?;
21192123
}
21202124
}
21212125

datafusion/expr/src/async_udf.rs

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
1919
use arrow::array::ArrayRef;
20-
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
20+
use arrow::datatypes::{DataType, FieldRef};
2121
use async_trait::async_trait;
2222
use datafusion_common::config::ConfigOptions;
2323
use datafusion_common::error::Result;
@@ -35,34 +35,7 @@ use std::sync::Arc;
3535
///
3636
/// The name is chosen to mirror ScalarUDFImpl
3737
#[async_trait]
38-
pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
39-
/// the function cast as any
40-
fn as_any(&self) -> &dyn Any;
41-
42-
/// The name of the function
43-
fn name(&self) -> &str;
44-
45-
/// The signature of the function
46-
fn signature(&self) -> &Signature;
47-
48-
/// The return type of the function
49-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>;
50-
51-
/// What type will be returned by this function, given the arguments?
52-
///
53-
/// By default, this function calls [`Self::return_type`] with the
54-
/// types of each argument.
55-
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
56-
let data_types = args
57-
.arg_fields
58-
.iter()
59-
.map(|f| f.data_type())
60-
.cloned()
61-
.collect::<Vec<_>>();
62-
let return_type = self.return_type(&data_types)?;
63-
Ok(Arc::new(Field::new(self.name(), return_type, true)))
64-
}
65-
38+
pub trait AsyncScalarUDFImpl: ScalarUDFImpl {
6639
/// The ideal batch size for this function.
6740
///
6841
/// This is used to determine what size of data to be evaluated at once.
@@ -74,7 +47,7 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
7447
/// Invoke the function asynchronously with the async arguments
7548
async fn invoke_async_with_args(
7649
&self,
77-
args: AsyncScalarFunctionArgs,
50+
args: ScalarFunctionArgs,
7851
option: &ConfigOptions,
7952
) -> Result<ArrayRef>;
8053
}
@@ -107,7 +80,7 @@ impl AsyncScalarUDF {
10780
/// Invoke the function asynchronously with the async arguments
10881
pub async fn invoke_async_with_args(
10982
&self,
110-
args: AsyncScalarFunctionArgs,
83+
args: ScalarFunctionArgs,
11184
option: &ConfigOptions,
11285
) -> Result<ArrayRef> {
11386
self.inner.invoke_async_with_args(args, option).await
@@ -145,10 +118,3 @@ impl Display for AsyncScalarUDF {
145118
write!(f, "AsyncScalarUDF: {}", self.inner.name())
146119
}
147120
}
148-
149-
#[derive(Debug)]
150-
pub struct AsyncScalarFunctionArgs {
151-
pub args: Vec<ColumnarValue>,
152-
pub number_rows: usize,
153-
pub schema: SchemaRef,
154-
}

datafusion/physical-expr/src/async_scalar_function.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
use crate::ScalarFunctionExpr;
1919
use arrow::array::{make_array, MutableArrayData, RecordBatch};
20-
use arrow::datatypes::{DataType, Field, Schema};
20+
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
2121
use datafusion_common::config::ConfigOptions;
2222
use datafusion_common::Result;
2323
use datafusion_common::{internal_err, not_impl_err};
24-
use datafusion_expr::async_udf::{AsyncScalarFunctionArgs, AsyncScalarUDF};
24+
use datafusion_expr::async_udf::AsyncScalarUDF;
25+
use datafusion_expr::ScalarFunctionArgs;
2526
use datafusion_expr_common::columnar_value::ColumnarValue;
2627
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
2728
use std::any::Any;
@@ -36,6 +37,8 @@ pub struct AsyncFuncExpr {
3637
pub name: String,
3738
/// The actual function (always `ScalarFunctionExpr`)
3839
pub func: Arc<dyn PhysicalExpr>,
40+
/// The field that this function will return
41+
return_field: FieldRef,
3942
}
4043

4144
impl Display for AsyncFuncExpr {
@@ -59,17 +62,23 @@ impl Hash for AsyncFuncExpr {
5962

6063
impl AsyncFuncExpr {
6164
/// create a new AsyncFuncExpr
62-
pub fn try_new(name: impl Into<String>, func: Arc<dyn PhysicalExpr>) -> Result<Self> {
65+
pub fn try_new(
66+
name: impl Into<String>,
67+
func: Arc<dyn PhysicalExpr>,
68+
schema: &Schema,
69+
) -> Result<Self> {
6370
let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
6471
return internal_err!(
6572
"unexpected function type, expected ScalarFunctionExpr, got: {:?}",
6673
func
6774
);
6875
};
6976

77+
let return_field = func.return_field(schema)?;
7078
Ok(Self {
7179
name: name.into(),
7280
func,
81+
return_field,
7382
})
7483
}
7584

@@ -128,6 +137,12 @@ impl AsyncFuncExpr {
128137
);
129138
};
130139

140+
let arg_fields = scalar_function_expr
141+
.args()
142+
.iter()
143+
.map(|e| e.return_field(batch.schema_ref()))
144+
.collect::<Result<Vec<_>>>()?;
145+
131146
let mut result_batches = vec![];
132147
if let Some(ideal_batch_size) = self.ideal_batch_size()? {
133148
let mut remainder = batch.clone();
@@ -148,10 +163,11 @@ impl AsyncFuncExpr {
148163
result_batches.push(
149164
async_udf
150165
.invoke_async_with_args(
151-
AsyncScalarFunctionArgs {
152-
args: args.to_vec(),
166+
ScalarFunctionArgs {
167+
args,
168+
arg_fields: arg_fields.clone(),
153169
number_rows: current_batch.num_rows(),
154-
schema: current_batch.schema(),
170+
return_field: Arc::clone(&self.return_field),
155171
},
156172
option,
157173
)
@@ -168,10 +184,11 @@ impl AsyncFuncExpr {
168184
result_batches.push(
169185
async_udf
170186
.invoke_async_with_args(
171-
AsyncScalarFunctionArgs {
187+
ScalarFunctionArgs {
172188
args: args.to_vec(),
189+
arg_fields,
173190
number_rows: batch.num_rows(),
174-
schema: batch.schema(),
191+
return_field: Arc::clone(&self.return_field),
175192
},
176193
option,
177194
)
@@ -223,6 +240,7 @@ impl PhysicalExpr for AsyncFuncExpr {
223240
Ok(Arc::new(AsyncFuncExpr {
224241
name: self.name.clone(),
225242
func: new_func,
243+
return_field: Arc::clone(&self.return_field),
226244
}))
227245
}
228246

datafusion/physical-plan/src/async_func.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ impl AsyncMapper {
245245
pub fn find_references(
246246
&mut self,
247247
physical_expr: &Arc<dyn PhysicalExpr>,
248+
schema: &Schema,
248249
) -> Result<()> {
249250
// recursively look for references to async functions
250251
physical_expr.apply(|expr| {
@@ -256,6 +257,7 @@ impl AsyncMapper {
256257
self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
257258
next_name,
258259
Arc::clone(expr),
260+
schema,
259261
)?));
260262
}
261263
}

0 commit comments

Comments
 (0)