17
17
18
18
use crate :: ScalarFunctionExpr ;
19
19
use arrow:: array:: { make_array, MutableArrayData , RecordBatch } ;
20
- use arrow:: datatypes:: { DataType , Field , Schema } ;
20
+ use arrow:: datatypes:: { DataType , Field , FieldRef , Schema } ;
21
21
use datafusion_common:: config:: ConfigOptions ;
22
22
use datafusion_common:: Result ;
23
23
use datafusion_common:: { internal_err, not_impl_err} ;
24
- use datafusion_expr:: async_udf:: { AsyncScalarFunctionArgs , AsyncScalarUDF } ;
24
+ use datafusion_expr:: async_udf:: AsyncScalarUDF ;
25
+ use datafusion_expr:: ScalarFunctionArgs ;
25
26
use datafusion_expr_common:: columnar_value:: ColumnarValue ;
26
27
use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
27
28
use std:: any:: Any ;
@@ -36,6 +37,8 @@ pub struct AsyncFuncExpr {
36
37
pub name : String ,
37
38
/// The actual function (always `ScalarFunctionExpr`)
38
39
pub func : Arc < dyn PhysicalExpr > ,
40
+ /// The field that this function will return
41
+ return_field : FieldRef ,
39
42
}
40
43
41
44
impl Display for AsyncFuncExpr {
@@ -59,17 +62,23 @@ impl Hash for AsyncFuncExpr {
59
62
60
63
impl AsyncFuncExpr {
61
64
/// create a new AsyncFuncExpr
62
- pub fn try_new ( name : impl Into < String > , func : Arc < dyn PhysicalExpr > ) -> Result < Self > {
65
+ pub fn try_new (
66
+ name : impl Into < String > ,
67
+ func : Arc < dyn PhysicalExpr > ,
68
+ schema : & Schema ,
69
+ ) -> Result < Self > {
63
70
let Some ( _) = func. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) else {
64
71
return internal_err ! (
65
72
"unexpected function type, expected ScalarFunctionExpr, got: {:?}" ,
66
73
func
67
74
) ;
68
75
} ;
69
76
77
+ let return_field = func. return_field ( schema) ?;
70
78
Ok ( Self {
71
79
name : name. into ( ) ,
72
80
func,
81
+ return_field,
73
82
} )
74
83
}
75
84
@@ -128,6 +137,12 @@ impl AsyncFuncExpr {
128
137
) ;
129
138
} ;
130
139
140
+ let arg_fields = scalar_function_expr
141
+ . args ( )
142
+ . iter ( )
143
+ . map ( |e| e. return_field ( batch. schema_ref ( ) ) )
144
+ . collect :: < Result < Vec < _ > > > ( ) ?;
145
+
131
146
let mut result_batches = vec ! [ ] ;
132
147
if let Some ( ideal_batch_size) = self . ideal_batch_size ( ) ? {
133
148
let mut remainder = batch. clone ( ) ;
@@ -148,10 +163,11 @@ impl AsyncFuncExpr {
148
163
result_batches. push (
149
164
async_udf
150
165
. invoke_async_with_args (
151
- AsyncScalarFunctionArgs {
152
- args : args. to_vec ( ) ,
166
+ ScalarFunctionArgs {
167
+ args,
168
+ arg_fields : arg_fields. clone ( ) ,
153
169
number_rows : current_batch. num_rows ( ) ,
154
- schema : current_batch . schema ( ) ,
170
+ return_field : Arc :: clone ( & self . return_field ) ,
155
171
} ,
156
172
option,
157
173
)
@@ -168,10 +184,11 @@ impl AsyncFuncExpr {
168
184
result_batches. push (
169
185
async_udf
170
186
. invoke_async_with_args (
171
- AsyncScalarFunctionArgs {
187
+ ScalarFunctionArgs {
172
188
args : args. to_vec ( ) ,
189
+ arg_fields,
173
190
number_rows : batch. num_rows ( ) ,
174
- schema : batch . schema ( ) ,
191
+ return_field : Arc :: clone ( & self . return_field ) ,
175
192
} ,
176
193
option,
177
194
)
@@ -223,6 +240,7 @@ impl PhysicalExpr for AsyncFuncExpr {
223
240
Ok ( Arc :: new ( AsyncFuncExpr {
224
241
name : self . name . clone ( ) ,
225
242
func : new_func,
243
+ return_field : Arc :: clone ( & self . return_field ) ,
226
244
} ) )
227
245
}
228
246
0 commit comments