Skip to content

Commit a5ab298

Browse files
committed
Fix scalar raster broadcast in RasterExecutor
Fixes RasterExecutor::execute_raster_void() to broadcast scalar raster inputs across num_iterations when other arguments are arrays, and tightens coordinate UDF tests to cover scalar raster + array coordinates.
1 parent e51f5a5 commit a5ab298

File tree

5 files changed

+241
-13
lines changed

5 files changed

+241
-13
lines changed

rust/sedona-raster-functions/src/executor.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
5959
/// 4. Calling the provided function with each raster
6060
pub fn execute_raster_void<F>(&self, mut func: F) -> Result<()>
6161
where
62-
F: FnMut(usize, Option<RasterRefImpl<'_>>) -> Result<()>,
62+
F: FnMut(usize, Option<&RasterRefImpl<'_>>) -> Result<()>,
6363
{
6464
if self.arg_types[0] != RASTER {
6565
return sedona_internal_err!("First argument must be a raster type");
@@ -87,22 +87,30 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
8787
continue;
8888
}
8989
let raster = raster_array.get(i)?;
90-
func(i, Some(raster))?;
90+
func(i, Some(&raster))?;
9191
}
9292

9393
Ok(())
9494
}
9595
ColumnarValue::Scalar(scalar_value) => match scalar_value {
9696
ScalarValue::Struct(arc_struct) => {
9797
let raster_array = RasterStructArray::new(arc_struct.as_ref());
98-
if raster_array.is_null(0) {
99-
func(0, None)
98+
let raster_opt = if raster_array.is_null(0) {
99+
None
100100
} else {
101-
let raster = raster_array.get(0)?;
102-
func(0, Some(raster))
101+
Some(raster_array.get(0)?)
102+
};
103+
for i in 0..self.num_iterations {
104+
func(i, raster_opt.as_ref())?;
103105
}
106+
Ok(())
107+
}
108+
ScalarValue::Null => {
109+
for i in 0..self.num_iterations {
110+
func(i, None)?;
111+
}
112+
Ok(())
104113
}
105-
ScalarValue::Null => func(0, None),
106114
_ => Err(DataFusionError::Internal(
107115
"Expected Struct scalar for raster".to_string(),
108116
)),

rust/sedona-raster-functions/src/rs_envelope.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl SedonaScalarKernel for RsEnvelope {
8282
executor.execute_raster_void(|_i, raster_opt| {
8383
match raster_opt {
8484
Some(raster) => {
85-
create_envelope_wkb(&raster, &mut builder)?;
85+
create_envelope_wkb(raster, &mut builder)?;
8686
builder.append_value([]);
8787
}
8888
None => builder.append_null(),

rust/sedona-raster-functions/src/rs_geotransform.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ impl SedonaScalarKernel for RsGeoTransform {
251251
let metadata = raster.metadata();
252252
match self.param {
253253
GeoTransformParam::Rotation => {
254-
let rotation = rotation(&raster);
254+
let rotation = rotation(raster);
255255
builder.append_value(rotation);
256256
}
257257
GeoTransformParam::ScaleX => builder.append_value(metadata.scale_x()),

rust/sedona-raster-functions/src/rs_rastercoordinate.rs

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ impl SedonaScalarKernel for RsCoordinateMapper {
155155

156156
match (raster_opt, x_opt, y_opt) {
157157
(Some(raster), Some(x), Some(y)) => {
158-
let (raster_x, raster_y) = to_raster_coordinate(&raster, x, y)?;
158+
let (raster_x, raster_y) = to_raster_coordinate(raster, x, y)?;
159159
match self.coord {
160160
Coord::X => builder.append_value(raster_x),
161161
Coord::Y => builder.append_value(raster_y),
@@ -216,7 +216,7 @@ impl SedonaScalarKernel for RsCoordinatePoint {
216216

217217
match (raster_opt, x_opt, y_opt) {
218218
(Some(raster), Some(world_x), Some(world_y)) => {
219-
let (raster_x, raster_y) = to_raster_coordinate(&raster, world_x, world_y)?;
219+
let (raster_x, raster_y) = to_raster_coordinate(raster, world_x, world_y)?;
220220
item[5..13].copy_from_slice(&(raster_x as f64).to_le_bytes());
221221
item[13..21].copy_from_slice(&(raster_y as f64).to_le_bytes());
222222
builder.append_value(item);
@@ -233,9 +233,13 @@ impl SedonaScalarKernel for RsCoordinatePoint {
233233
#[cfg(test)]
234234
mod tests {
235235
use super::*;
236+
use arrow_array::Array;
237+
use arrow_array::StructArray;
236238
use arrow_schema::DataType;
239+
use datafusion_common::ScalarValue;
237240
use datafusion_expr::ScalarUDF;
238241
use rstest::rstest;
242+
use sedona_raster::array::RasterStructArray;
239243
use sedona_schema::datatypes::{RASTER, WKB_GEOMETRY};
240244
use sedona_testing::compare::assert_array_equal;
241245
use sedona_testing::create::create_array;
@@ -368,4 +372,112 @@ mod tests {
368372
.unwrap();
369373
assert_array_equal(&result, &expected);
370374
}
375+
376+
#[rstest]
377+
fn udf_invoke_xy_with_scalar_raster_array_coords(#[values(Coord::Y, Coord::X)] coord: Coord) {
378+
let udf = match coord {
379+
Coord::X => rs_worldtorastercoordx_udf(),
380+
Coord::Y => rs_worldtorastercoordy_udf(),
381+
};
382+
let tester = ScalarUdfTester::new(
383+
udf.into(),
384+
vec![
385+
RASTER,
386+
SedonaType::Arrow(DataType::Float64),
387+
SedonaType::Arrow(DataType::Float64),
388+
],
389+
);
390+
391+
// Use raster 1 (invertible) as scalar.
392+
let rasters = generate_test_rasters(2, Some(0)).unwrap();
393+
let raster_struct = rasters.as_any().downcast_ref::<StructArray>().unwrap();
394+
let scalar_raster = ScalarValue::try_from_array(raster_struct, 1).unwrap();
395+
396+
let raster_ref = RasterStructArray::new(raster_struct).get(1).unwrap();
397+
398+
let world_x = Arc::new(arrow_array::Float64Array::from(vec![2.0, 2.5, 3.25]));
399+
let world_y = Arc::new(arrow_array::Float64Array::from(vec![3.0, 2.5, 1.75]));
400+
401+
let expected_coords: Vec<Option<i64>> = world_x
402+
.iter()
403+
.zip(world_y.iter())
404+
.map(|(x, y)| match (x, y) {
405+
(Some(x), Some(y)) => {
406+
let (rx, ry) = to_raster_coordinate(&raster_ref, x, y).unwrap();
407+
Some(match coord {
408+
Coord::X => rx,
409+
Coord::Y => ry,
410+
})
411+
}
412+
_ => None,
413+
})
414+
.collect();
415+
416+
let result = tester
417+
.invoke(vec![
418+
ColumnarValue::Scalar(scalar_raster),
419+
ColumnarValue::Array(world_x.clone()),
420+
ColumnarValue::Array(world_y.clone()),
421+
])
422+
.unwrap();
423+
424+
let array = match result {
425+
ColumnarValue::Array(array) => array,
426+
ColumnarValue::Scalar(_) => panic!("Expected array result"),
427+
};
428+
429+
let expected: Arc<dyn arrow_array::Array> =
430+
Arc::new(arrow_array::Int64Array::from(expected_coords));
431+
assert_array_equal(&array, &expected);
432+
}
433+
434+
#[test]
435+
fn udf_invoke_pt_with_scalar_raster_array_coords() {
436+
let udf = rs_worldtorastercoord_udf();
437+
let tester = ScalarUdfTester::new(
438+
udf.into(),
439+
vec![
440+
RASTER,
441+
SedonaType::Arrow(DataType::Float64),
442+
SedonaType::Arrow(DataType::Float64),
443+
],
444+
);
445+
446+
let rasters = generate_test_rasters(2, Some(0)).unwrap();
447+
let raster_struct = rasters.as_any().downcast_ref::<StructArray>().unwrap();
448+
let scalar_raster = ScalarValue::try_from_array(raster_struct, 1).unwrap();
449+
450+
let raster_ref = RasterStructArray::new(raster_struct).get(1).unwrap();
451+
452+
let world_x = Arc::new(arrow_array::Float64Array::from(vec![2.0, 2.5, 3.25]));
453+
let world_y = Arc::new(arrow_array::Float64Array::from(vec![3.0, 2.5, 1.75]));
454+
455+
let result = tester
456+
.invoke(vec![
457+
ColumnarValue::Scalar(scalar_raster),
458+
ColumnarValue::Array(world_x.clone()),
459+
ColumnarValue::Array(world_y.clone()),
460+
])
461+
.unwrap();
462+
463+
let array = match result {
464+
ColumnarValue::Array(array) => array,
465+
ColumnarValue::Scalar(_) => panic!("Expected array result"),
466+
};
467+
468+
let expected_wkt: Vec<Option<String>> = world_x
469+
.iter()
470+
.zip(world_y.iter())
471+
.map(|(x, y)| match (x, y) {
472+
(Some(x), Some(y)) => {
473+
let (rx, ry) = to_raster_coordinate(&raster_ref, x, y).unwrap();
474+
Some(format!("POINT ({} {})", rx, ry))
475+
}
476+
_ => None,
477+
})
478+
.collect();
479+
let expected_refs: Vec<Option<&str>> = expected_wkt.iter().map(|v| v.as_deref()).collect();
480+
let expected = create_array(&expected_refs, &WKB_GEOMETRY);
481+
assert_array_equal(&array, &expected);
482+
}
371483
}

rust/sedona-raster-functions/src/rs_worldcoordinate.rs

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ impl SedonaScalarKernel for RsCoordinateMapper {
153153

154154
match (raster_opt, x_opt, y_opt) {
155155
(Some(raster), Some(x), Some(y)) => {
156-
let (world_x, world_y) = to_world_coordinate(&raster, x, y);
156+
let (world_x, world_y) = to_world_coordinate(raster, x, y);
157157
match self.coord {
158158
Coord::X => builder.append_value(world_x),
159159
Coord::Y => builder.append_value(world_y),
@@ -214,7 +214,7 @@ impl SedonaScalarKernel for RsCoordinatePoint {
214214

215215
match (raster_opt, x_opt, y_opt) {
216216
(Some(raster), Some(x), Some(y)) => {
217-
let (world_x, world_y) = to_world_coordinate(&raster, x, y);
217+
let (world_x, world_y) = to_world_coordinate(raster, x, y);
218218
item[5..13].copy_from_slice(&world_x.to_le_bytes());
219219
item[13..21].copy_from_slice(&world_y.to_le_bytes());
220220
builder.append_value(item);
@@ -232,8 +232,11 @@ impl SedonaScalarKernel for RsCoordinatePoint {
232232
mod tests {
233233
use super::*;
234234
use arrow_array::Array;
235+
use arrow_array::StructArray;
236+
use datafusion_common::ScalarValue;
235237
use datafusion_expr::ScalarUDF;
236238
use rstest::rstest;
239+
use sedona_raster::array::RasterStructArray;
237240
use sedona_schema::datatypes::{RASTER, WKB_GEOMETRY};
238241
use sedona_testing::compare::assert_array_equal;
239242
use sedona_testing::create::create_array;
@@ -357,4 +360,109 @@ mod tests {
357360
}
358361
}
359362
}
363+
364+
#[rstest]
365+
fn udf_invoke_xy_with_scalar_raster_array_coords(#[values(Coord::Y, Coord::X)] coord: Coord) {
366+
let udf = match coord {
367+
Coord::X => rs_rastertoworldcoordx_udf(),
368+
Coord::Y => rs_rastertoworldcoordy_udf(),
369+
};
370+
let tester = ScalarUdfTester::new(
371+
udf.into(),
372+
vec![
373+
RASTER,
374+
SedonaType::Arrow(DataType::Int32),
375+
SedonaType::Arrow(DataType::Int32),
376+
],
377+
);
378+
379+
// Use raster 0 as scalar; it has scale/skew of 0, so all pixels map to upper left.
380+
let rasters = generate_test_rasters(1, None).unwrap();
381+
let raster_struct = rasters.as_any().downcast_ref::<StructArray>().unwrap();
382+
let scalar_raster = ScalarValue::try_from_array(raster_struct, 0).unwrap();
383+
384+
let raster_ref = RasterStructArray::new(raster_struct).get(0).unwrap();
385+
386+
let x_vals = vec![0_i32, 1_i32, 2_i32];
387+
let y_vals = vec![0_i32, 1_i32, 2_i32];
388+
let x_coords: Arc<dyn Array> = Arc::new(arrow_array::Int32Array::from(x_vals.clone()));
389+
let y_coords: Arc<dyn Array> = Arc::new(arrow_array::Int32Array::from(y_vals.clone()));
390+
391+
let result = tester
392+
.invoke(vec![
393+
ColumnarValue::Scalar(scalar_raster),
394+
ColumnarValue::Array(x_coords),
395+
ColumnarValue::Array(y_coords),
396+
])
397+
.unwrap();
398+
399+
let array = match result {
400+
ColumnarValue::Array(array) => array,
401+
ColumnarValue::Scalar(_) => panic!("Expected array result"),
402+
};
403+
404+
let expected_values: Vec<Option<f64>> = x_vals
405+
.iter()
406+
.zip(y_vals.iter())
407+
.map(|(x, y)| {
408+
let (wx, wy) = to_world_coordinate(&raster_ref, *x as i64, *y as i64);
409+
Some(match coord {
410+
Coord::X => wx,
411+
Coord::Y => wy,
412+
})
413+
})
414+
.collect();
415+
let expected: Arc<dyn arrow_array::Array> =
416+
Arc::new(arrow_array::Float64Array::from(expected_values));
417+
assert_array_equal(&array, &expected);
418+
}
419+
420+
#[test]
421+
fn udf_invoke_pt_with_scalar_raster_array_coords() {
422+
let udf = rs_rastertoworldcoord_udf();
423+
let tester = ScalarUdfTester::new(
424+
udf.into(),
425+
vec![
426+
RASTER,
427+
SedonaType::Arrow(DataType::Int32),
428+
SedonaType::Arrow(DataType::Int32),
429+
],
430+
);
431+
432+
let rasters = generate_test_rasters(1, None).unwrap();
433+
let raster_struct = rasters.as_any().downcast_ref::<StructArray>().unwrap();
434+
let scalar_raster = ScalarValue::try_from_array(raster_struct, 0).unwrap();
435+
436+
let raster_ref = RasterStructArray::new(raster_struct).get(0).unwrap();
437+
438+
let x_vals = vec![0_i32, 1_i32, 2_i32];
439+
let y_vals = vec![0_i32, 1_i32, 2_i32];
440+
let x_coords: Arc<dyn Array> = Arc::new(arrow_array::Int32Array::from(x_vals.clone()));
441+
let y_coords: Arc<dyn Array> = Arc::new(arrow_array::Int32Array::from(y_vals.clone()));
442+
443+
let result = tester
444+
.invoke(vec![
445+
ColumnarValue::Scalar(scalar_raster),
446+
ColumnarValue::Array(x_coords),
447+
ColumnarValue::Array(y_coords),
448+
])
449+
.unwrap();
450+
451+
let array = match result {
452+
ColumnarValue::Array(array) => array,
453+
ColumnarValue::Scalar(_) => panic!("Expected array result"),
454+
};
455+
456+
let expected_wkt: Vec<Option<String>> = x_vals
457+
.iter()
458+
.zip(y_vals.iter())
459+
.map(|(x, y)| {
460+
let (wx, wy) = to_world_coordinate(&raster_ref, *x as i64, *y as i64);
461+
Some(format!("POINT ({} {})", wx, wy))
462+
})
463+
.collect();
464+
let expected_refs: Vec<Option<&str>> = expected_wkt.iter().map(|v| v.as_deref()).collect();
465+
let expected = create_array(&expected_refs, &WKB_GEOMETRY);
466+
assert_array_equal(&array, &expected);
467+
}
360468
}

0 commit comments

Comments
 (0)