Skip to content

Commit 68aa763

Browse files
authored
Multispectral tools: convert data to float32 dtype before doing calculations (#755)
* convert to f4 dtype before doing calculations * add tests
1 parent b0c87a9 commit 68aa763

File tree

2 files changed

+138
-12
lines changed

2 files changed

+138
-12
lines changed

xrspatial/multispectral.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def arvi(nir_agg: xr.DataArray,
146146
cupy_func=_arvi_cupy,
147147
dask_cupy_func=_arvi_dask_cupy)
148148

149-
out = mapper(red_agg)(nir_agg.data, red_agg.data, blue_agg.data)
149+
out = mapper(red_agg)(
150+
nir_agg.data.astype('f4'), red_agg.data.astype('f4'), blue_agg.data.astype('f4')
151+
)
150152

151153
return DataArray(out,
152154
name=name,
@@ -311,8 +313,10 @@ def evi(nir_agg: xr.DataArray,
311313
cupy_func=_evi_cupy,
312314
dask_cupy_func=_evi_dask_cupy)
313315

314-
out = mapper(red_agg)(nir_agg.data, red_agg.data, blue_agg.data, c1, c2,
315-
soil_factor, gain)
316+
out = mapper(red_agg)(
317+
nir_agg.data.astype('f4'), red_agg.data.astype('f4'), blue_agg.data.astype('f4'),
318+
c1, c2, soil_factor, gain
319+
)
316320

317321
return DataArray(out,
318322
name=name,
@@ -431,7 +435,7 @@ def gci(nir_agg: xr.DataArray,
431435
cupy_func=_gci_cupy,
432436
dask_cupy_func=_gci_dask_cupy)
433437

434-
out = mapper(nir_agg)(nir_agg.data, green_agg.data)
438+
out = mapper(nir_agg)(nir_agg.data.astype('f4'), green_agg.data.astype('f4'))
435439

436440
return DataArray(out,
437441
name=name,
@@ -510,7 +514,7 @@ def nbr(nir_agg: xr.DataArray,
510514
dask_cupy_func=_run_normalized_ratio_dask_cupy,
511515
)
512516

513-
out = mapper(nir_agg)(nir_agg.data, swir2_agg.data)
517+
out = mapper(nir_agg)(nir_agg.data.astype('f4'), swir2_agg.data.astype('f4'))
514518

515519
return DataArray(out,
516520
name=name,
@@ -594,7 +598,7 @@ def nbr2(swir1_agg: xr.DataArray,
594598
dask_cupy_func=_run_normalized_ratio_dask_cupy,
595599
)
596600

597-
out = mapper(swir1_agg)(swir1_agg.data, swir2_agg.data)
601+
out = mapper(swir1_agg)(swir1_agg.data.astype('f4'), swir2_agg.data.astype('f4'))
598602

599603
return DataArray(out,
600604
name=name,
@@ -671,7 +675,7 @@ def ndvi(nir_agg: xr.DataArray,
671675
dask_cupy_func=_run_normalized_ratio_dask_cupy,
672676
)
673677

674-
out = mapper(nir_agg)(nir_agg.data, red_agg.data)
678+
out = mapper(nir_agg)(nir_agg.data.astype('f4'), red_agg.data.astype('f4'))
675679

676680
return DataArray(out,
677681
name=name,
@@ -753,7 +757,7 @@ def ndmi(nir_agg: xr.DataArray,
753757
dask_cupy_func=_run_normalized_ratio_dask_cupy,
754758
)
755759

756-
out = mapper(nir_agg)(nir_agg.data, swir1_agg.data)
760+
out = mapper(nir_agg)(nir_agg.data.astype('f4'), swir1_agg.data.astype('f4'))
757761

758762
return DataArray(out,
759763
name=name,
@@ -937,7 +941,7 @@ def savi(nir_agg: xr.DataArray,
937941
cupy_func=_savi_cupy,
938942
dask_cupy_func=_savi_dask_cupy)
939943

940-
out = mapper(red_agg)(nir_agg.data, red_agg.data, soil_factor)
944+
out = mapper(red_agg)(nir_agg.data.astype('f4'), red_agg.data.astype('f4'), soil_factor)
941945

942946
return DataArray(out,
943947
name=name,
@@ -1071,7 +1075,9 @@ def sipi(nir_agg: xr.DataArray,
10711075
cupy_func=_sipi_cupy,
10721076
dask_cupy_func=_sipi_dask_cupy)
10731077

1074-
out = mapper(red_agg)(nir_agg.data, red_agg.data, blue_agg.data)
1078+
out = mapper(red_agg)(
1079+
nir_agg.data.astype('f4'), red_agg.data.astype('f4'), blue_agg.data.astype('f4')
1080+
)
10751081

10761082
return DataArray(out,
10771083
name=name,
@@ -1238,7 +1244,9 @@ def ebbi(red_agg: xr.DataArray,
12381244
cupy_func=_ebbi_cupy,
12391245
dask_cupy_func=_ebbi_dask_cupy)
12401246

1241-
out = mapper(red_agg)(red_agg.data, swir_agg.data, tir_agg.data)
1247+
out = mapper(red_agg)(
1248+
red_agg.data.astype('f4'), swir_agg.data.astype('f4'), tir_agg.data.astype('f4')
1249+
)
12421250

12431251
return DataArray(out,
12441252
name=name,
@@ -1298,7 +1306,7 @@ def _normalize_data(agg, pixel_max, c, th):
12981306
dask_func=_normalize_data_dask,
12991307
cupy_func=_normalize_data_cupy,
13001308
dask_cupy_func=_normalize_data_dask_cupy)
1301-
out = mapper(agg)(agg.data, pixel_max, c, th)
1309+
out = mapper(agg)(agg.data.astype('f4'), pixel_max, c, th)
13021310
return out
13031311

13041312

xrspatial/tests/test_multispectral.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,60 @@ def qgis_ebbi():
281281
return result
282282

283283

284+
@pytest.fixture
285+
def data_uint_dtype_normalized_ratio(dtype):
286+
# test data for input data array of uint dtype
287+
# normalized ratio is applied with different bands for NBR, NBR2, NDVI, NDMI.
288+
band1 = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
289+
band2 = xr.DataArray(np.array([[0, 2], [1, 2]], dtype=dtype))
290+
result = np.array([[1, -0.33333334], [0, -0.33333334]], dtype=np.float32)
291+
return band1, band2, result
292+
293+
294+
@pytest.fixture
295+
def data_uint_dtype_arvi(dtype):
296+
nir = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
297+
red = xr.DataArray(np.array([[0, 1], [0, 2]], dtype=dtype))
298+
blue = xr.DataArray(np.array([[0, 2], [1, 2]], dtype=dtype))
299+
result = np.array([[1, 0.2], [1, -0.14285715]], dtype=np.float32)
300+
return nir, red, blue, result
301+
302+
303+
@pytest.fixture
304+
def data_uint_dtype_evi(dtype):
305+
nir = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
306+
red = xr.DataArray(np.array([[0, 1], [0, 2]], dtype=dtype))
307+
blue = xr.DataArray(np.array([[0, 2], [1, 2]], dtype=dtype))
308+
result = np.array([[1.25, 0.], [-0.45454547, 2.5]], dtype=np.float32)
309+
return nir, red, blue, result
310+
311+
312+
@pytest.fixture
313+
def data_uint_dtype_savi(dtype):
314+
nir = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
315+
red = xr.DataArray(np.array([[0, 1], [0, 2]], dtype=dtype))
316+
result = np.array([[0.25, 0.], [0.25, -0.125]], dtype=np.float32)
317+
return nir, red, result
318+
319+
320+
@pytest.fixture
321+
def data_uint_dtype_sipi(dtype):
322+
nir = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
323+
red = xr.DataArray(np.array([[0, 0], [0, 2]], dtype=dtype))
324+
blue = xr.DataArray(np.array([[0, 2], [1, 2]], dtype=dtype))
325+
result = np.array([[1, -1], [0, 1]], dtype=np.float32)
326+
return nir, red, blue, result
327+
328+
329+
@pytest.fixture
330+
def data_uint_dtype_ebbi(dtype):
331+
red = xr.DataArray(np.array([[0, 0], [0, 2]], dtype=dtype))
332+
swir = xr.DataArray(np.array([[1, 1], [1, 1]], dtype=dtype))
333+
tir = xr.DataArray(np.array([[0, 2], [1, 2]], dtype=dtype))
334+
result = np.array([[0.1, 0.05773503], [0.07071068, -0.05773503]], dtype=np.float32)
335+
return red, swir, tir, result
336+
337+
284338
# NDVI -------------
285339
def test_ndvi_data_contains_valid_values():
286340
_x = np.mgrid[1:0:21j]
@@ -310,6 +364,13 @@ def test_ndvi_cpu_against_qgis(nir_data, red_data, qgis_ndvi):
310364
general_output_checks(nir_data, result, qgis_ndvi, verify_dtype=True)
311365

312366

367+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
368+
def test_ndvi_uint_dtype(data_uint_dtype_normalized_ratio):
369+
nir_data, red_data, result_ndvi = data_uint_dtype_normalized_ratio
370+
result = ndvi(nir_data, red_data)
371+
general_output_checks(nir_data, result, result_ndvi, verify_dtype=True)
372+
373+
313374
@cuda_and_cupy_available
314375
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
315376
def test_ndvi_gpu(nir_data, red_data, qgis_ndvi):
@@ -325,6 +386,7 @@ def test_savi_zero_soil_factor_cpu_against_qgis(nir_data, red_data, qgis_ndvi):
325386
general_output_checks(nir_data, qgis_savi, qgis_ndvi, verify_dtype=True)
326387

327388

389+
@cuda_and_cupy_available
328390
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
329391
def test_savi_zero_soil_factor_gpu(nir_data, red_data, qgis_ndvi):
330392
# savi should be same as ndvi at soil_factor=0
@@ -339,6 +401,13 @@ def test_savi_cpu_against_qgis(nir_data, red_data, qgis_savi):
339401
general_output_checks(nir_data, result, qgis_savi)
340402

341403

404+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
405+
def test_savi_uint_dtype(data_uint_dtype_savi):
406+
nir_data, red_data, result_savi = data_uint_dtype_savi
407+
result = savi(nir_data, red_data)
408+
general_output_checks(nir_data, result, result_savi, verify_dtype=True)
409+
410+
342411
@cuda_and_cupy_available
343412
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
344413
def test_savi_gpu(nir_data, red_data, qgis_savi):
@@ -354,6 +423,13 @@ def test_arvi_cpu_against_qgis(nir_data, red_data, blue_data, qgis_arvi):
354423
general_output_checks(nir_data, result, qgis_arvi)
355424

356425

426+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
427+
def test_arvi_uint_dtype(data_uint_dtype_arvi):
428+
nir_data, red_data, blue_data, result_arvi = data_uint_dtype_arvi
429+
result = arvi(nir_data, red_data, blue_data)
430+
general_output_checks(nir_data, result, result_arvi, verify_dtype=True)
431+
432+
357433
@cuda_and_cupy_available
358434
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
359435
def test_arvi_gpu(nir_data, red_data, blue_data, qgis_arvi):
@@ -368,6 +444,13 @@ def test_evi_cpu_against_qgis(nir_data, red_data, blue_data, qgis_evi):
368444
general_output_checks(nir_data, result, qgis_evi)
369445

370446

447+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
448+
def test_evi_uint_dtype(data_uint_dtype_evi):
449+
nir_data, red_data, blue_data, result_evi = data_uint_dtype_evi
450+
result = evi(nir_data, red_data, blue_data)
451+
general_output_checks(nir_data, result, result_evi, verify_dtype=True)
452+
453+
371454
@cuda_and_cupy_available
372455
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
373456
def test_evi_gpu(nir_data, red_data, blue_data, qgis_evi):
@@ -396,6 +479,13 @@ def test_sipi_cpu_against_qgis(nir_data, red_data, blue_data, qgis_sipi):
396479
general_output_checks(nir_data, result, qgis_sipi)
397480

398481

482+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
483+
def test_sipi_uint_dtype(data_uint_dtype_sipi):
484+
nir_data, red_data, blue_data, result_sipi = data_uint_dtype_sipi
485+
result = sipi(nir_data, red_data, blue_data)
486+
general_output_checks(nir_data, result, result_sipi, verify_dtype=True)
487+
488+
399489
@cuda_and_cupy_available
400490
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
401491
def test_sipi_gpu(nir_data, red_data, blue_data, qgis_sipi):
@@ -410,6 +500,13 @@ def test_nbr_cpu_against_qgis(nir_data, swir2_data, qgis_nbr):
410500
general_output_checks(nir_data, result, qgis_nbr)
411501

412502

503+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
504+
def test_nbr_uint_dtype(data_uint_dtype_normalized_ratio):
505+
nir_data, red_data, result_nbr = data_uint_dtype_normalized_ratio
506+
result = nbr(nir_data, red_data)
507+
general_output_checks(nir_data, result, result_nbr, verify_dtype=True)
508+
509+
413510
@cuda_and_cupy_available
414511
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
415512
def test_nbr_gpu(nir_data, swir2_data, qgis_nbr):
@@ -424,6 +521,13 @@ def test_nbr2_cpu_against_qgis(swir1_data, swir2_data, qgis_nbr2):
424521
general_output_checks(swir1_data, result, qgis_nbr2)
425522

426523

524+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
525+
def test_nbr2_uint_dtype(data_uint_dtype_normalized_ratio):
526+
nir_data, red_data, result_nbr2 = data_uint_dtype_normalized_ratio
527+
result = nbr2(nir_data, red_data)
528+
general_output_checks(nir_data, result, result_nbr2, verify_dtype=True)
529+
530+
427531
@cuda_and_cupy_available
428532
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
429533
def test_nbr2_gpu(swir1_data, swir2_data, qgis_nbr2):
@@ -438,6 +542,13 @@ def test_ndmi_cpu_against_qgis(nir_data, swir1_data, qgis_ndmi):
438542
general_output_checks(nir_data, result, qgis_ndmi)
439543

440544

545+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
546+
def test_ndmi_uint_dtype(data_uint_dtype_normalized_ratio):
547+
nir_data, red_data, result_ndmi = data_uint_dtype_normalized_ratio
548+
result = ndmi(nir_data, red_data)
549+
general_output_checks(nir_data, result, result_ndmi, verify_dtype=True)
550+
551+
441552
@cuda_and_cupy_available
442553
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
443554
def test_ndmi_gpu(nir_data, swir1_data, qgis_ndmi):
@@ -452,6 +563,13 @@ def test_ebbi_cpu_against_qgis(red_data, swir1_data, tir_data, qgis_ebbi):
452563
general_output_checks(red_data, result, qgis_ebbi)
453564

454565

566+
@pytest.mark.parametrize("dtype", ["uint8", "uint16"])
567+
def test_ebbi_uint_dtype(data_uint_dtype_ebbi):
568+
red_data, swir_data, tir_data, result_ebbi = data_uint_dtype_ebbi
569+
result = ebbi(red_data, swir_data, tir_data)
570+
general_output_checks(red_data, result, result_ebbi, verify_dtype=True)
571+
572+
455573
@cuda_and_cupy_available
456574
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
457575
def test_ebbi_gpu(red_data, swir1_data, tir_data, qgis_ebbi):

0 commit comments

Comments
 (0)