Skip to content

Commit 0346b5e

Browse files
committed
simple example
1 parent 8b83572 commit 0346b5e

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray};
19+
use arrow::compute::kernels::cmp::eq;
20+
use arrow_schema::{DataType, Field, Schema};
21+
use async_trait::async_trait;
22+
use datafusion::common::error::Result;
23+
use datafusion::common::internal_err;
24+
use datafusion::common::types::{logical_int64, logical_string};
25+
use datafusion::common::utils::take_function_args;
26+
use datafusion::config::ConfigOptions;
27+
use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
28+
use datafusion::logical_expr::async_udf::{
29+
AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl,
30+
};
31+
use datafusion::logical_expr::{
32+
ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
33+
};
34+
use datafusion::logical_expr_common::signature::Coercion;
35+
use datafusion::physical_expr_common::datum::apply_cmp;
36+
use datafusion::prelude::SessionContext;
37+
use log::trace;
38+
use std::any::Any;
39+
use std::sync::Arc;
40+
41+
#[tokio::main]
42+
async fn main() -> Result<()> {
43+
let mut state = SessionStateBuilder::new().build();
44+
45+
let async_upper = AsyncUpper::new();
46+
let udf = AsyncScalarUDF::new(Arc::new(async_upper));
47+
state.register_udf(udf.into_scalar_udf())?;
48+
let async_equal = AsyncEqual::new();
49+
let udf = AsyncScalarUDF::new(Arc::new(async_equal));
50+
state.register_udf(udf.into_scalar_udf())?;
51+
let ctx = SessionContext::new_with_state(state);
52+
ctx.register_batch("animal", animal()?)?;
53+
54+
// use Async UDF in the projection
55+
// +---------------+----------------------------------------------------------------------------------------+
56+
// | plan_type | plan |
57+
// +---------------+----------------------------------------------------------------------------------------+
58+
// | logical_plan | Projection: async_equal(a.id, Int64(1)) |
59+
// | | SubqueryAlias: a |
60+
// | | TableScan: animal projection=[id] |
61+
// | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] |
62+
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
63+
// | | CoalesceBatchesExec: target_batch_size=8192 |
64+
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
65+
// | | |
66+
// +---------------+----------------------------------------------------------------------------------------+
67+
ctx.sql("explain select async_equal(a.id, 1) from animal a")
68+
.await?
69+
.show()
70+
.await?;
71+
72+
// +----------------------------+
73+
// | async_equal(a.id,Int64(1)) |
74+
// +----------------------------+
75+
// | true |
76+
// | false |
77+
// | false |
78+
// | false |
79+
// | false |
80+
// +----------------------------+
81+
ctx.sql("select async_equal(a.id, 1) from animal a")
82+
.await?
83+
.show()
84+
.await?;
85+
86+
// +---------------+--------------------------------------------------------------------------------------------+
87+
// | plan_type | plan |
88+
// +---------------+--------------------------------------------------------------------------------------------+
89+
// | logical_plan | SubqueryAlias: a |
90+
// | | Filter: async_equal(animal.id, Int64(1)) |
91+
// | | TableScan: animal projection=[id, name] |
92+
// | physical_plan | CoalesceBatchesExec: target_batch_size=8192 |
93+
// | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |
94+
// | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |
95+
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
96+
// | | CoalesceBatchesExec: target_batch_size=8192 |
97+
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
98+
// | | |
99+
// +---------------+--------------------------------------------------------------------------------------------+
100+
ctx.sql("explain select * from animal a where async_equal(a.id, 1)")
101+
.await?
102+
.show()
103+
.await?;
104+
105+
// +----+------+
106+
// | id | name |
107+
// +----+------+
108+
// | 1 | cat |
109+
// +----+------+
110+
ctx.sql("select * from animal a where async_equal(a.id, 1)")
111+
.await?
112+
.show()
113+
.await?;
114+
115+
Ok(())
116+
}
117+
118+
fn animal() -> Result<RecordBatch> {
119+
let schema = Arc::new(Schema::new(vec![
120+
Field::new("id", DataType::Int64, false),
121+
Field::new("name", DataType::Utf8, false),
122+
]));
123+
124+
let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
125+
let name_array = Arc::new(StringArray::from(vec![
126+
"cat", "dog", "fish", "bird", "snake",
127+
]));
128+
129+
Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
130+
}
131+
132+
#[derive(Debug)]
133+
pub struct AsyncUpper {
134+
signature: Signature,
135+
}
136+
137+
impl Default for AsyncUpper {
138+
fn default() -> Self {
139+
Self::new()
140+
}
141+
}
142+
143+
impl AsyncUpper {
144+
pub fn new() -> Self {
145+
Self {
146+
signature: Signature::new(
147+
TypeSignature::Coercible(vec![Coercion::Exact {
148+
desired_type: TypeSignatureClass::Native(logical_string()),
149+
}]),
150+
Volatility::Volatile,
151+
),
152+
}
153+
}
154+
}
155+
156+
#[async_trait]
157+
impl AsyncScalarUDFImpl for AsyncUpper {
158+
fn as_any(&self) -> &dyn Any {
159+
self
160+
}
161+
162+
fn name(&self) -> &str {
163+
"async_upper"
164+
}
165+
166+
fn signature(&self) -> &Signature {
167+
&self.signature
168+
}
169+
170+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
171+
Ok(DataType::Utf8)
172+
}
173+
174+
fn ideal_batch_size(&self) -> Option<usize> {
175+
Some(10)
176+
}
177+
178+
async fn invoke_async_with_args(
179+
&self,
180+
args: AsyncScalarFunctionArgs,
181+
_option: &ConfigOptions,
182+
) -> Result<ArrayRef> {
183+
trace!("Invoking async_upper with args: {:?}", args);
184+
let value = &args.args[0];
185+
let result = match value {
186+
ColumnarValue::Array(array) => {
187+
let string_array = array.as_string::<i32>();
188+
let iter = ArrayIter::new(string_array);
189+
let result = iter
190+
.map(|string| string.map(|s| s.to_uppercase()))
191+
.collect::<StringArray>();
192+
Arc::new(result) as ArrayRef
193+
}
194+
_ => return internal_err!("Expected a string argument, got {:?}", value),
195+
};
196+
Ok(result)
197+
}
198+
}
199+
200+
#[derive(Debug)]
201+
struct AsyncEqual {
202+
signature: Signature,
203+
}
204+
205+
impl Default for AsyncEqual {
206+
fn default() -> Self {
207+
Self::new()
208+
}
209+
}
210+
211+
impl AsyncEqual {
212+
pub fn new() -> Self {
213+
Self {
214+
signature: Signature::new(
215+
TypeSignature::Coercible(vec![
216+
Coercion::Exact {
217+
desired_type: TypeSignatureClass::Native(logical_int64()),
218+
},
219+
Coercion::Exact {
220+
desired_type: TypeSignatureClass::Native(logical_int64()),
221+
},
222+
]),
223+
Volatility::Volatile,
224+
),
225+
}
226+
}
227+
}
228+
229+
#[async_trait]
230+
impl AsyncScalarUDFImpl for AsyncEqual {
231+
fn as_any(&self) -> &dyn Any {
232+
self
233+
}
234+
235+
fn name(&self) -> &str {
236+
"async_equal"
237+
}
238+
239+
fn signature(&self) -> &Signature {
240+
&self.signature
241+
}
242+
243+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
244+
Ok(DataType::Boolean)
245+
}
246+
247+
async fn invoke_async_with_args(
248+
&self,
249+
args: AsyncScalarFunctionArgs,
250+
_option: &ConfigOptions,
251+
) -> Result<ArrayRef> {
252+
let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
253+
apply_cmp(&arg1, &arg2, eq)?.to_array(args.number_rows)
254+
}
255+
}

0 commit comments

Comments
 (0)