@@ -1378,7 +1378,18 @@ def _interp_axis(
13781378 elif component == "real" :
13791379 values = values .real
13801380
1381- return values .sum (axis = - 1 ).reshape (eps_shape )
1381+ vjp_array = values .sum (axis = - 1 ).reshape (eps_shape )
1382+
1383+ # match derivative dtype to the underlying dataset
1384+ target_array = getattr (spatial_data , "values" , None )
1385+ if target_array is None and hasattr (spatial_data , "data" ):
1386+ target_array = spatial_data .data
1387+ if target_array is not None :
1388+ target_dtype = np .asarray (target_array ).dtype
1389+ if not np .issubdtype (target_dtype , np .complexfloating ):
1390+ vjp_array = np .real (vjp_array ).astype (target_dtype , copy = False )
1391+
1392+ return vjp_array
13821393
13831394
13841395""" Dispersionless Medium """
@@ -3591,7 +3602,9 @@ def is_spatially_uniform(self) -> bool:
35913602 return True
35923603
35933604 @staticmethod
3594- def _sorted_spatial_data (data : CustomSpatialDataTypeAnnotated ):
3605+ def _sorted_spatial_data (
3606+ data : CustomSpatialDataTypeAnnotated ,
3607+ ) -> CustomSpatialDataTypeAnnotated :
35953608 """Return spatial data sorted along its coordinates if applicable."""
35963609 if isinstance (data , SpatialDataArray ):
35973610 return data ._spatially_sorted
0 commit comments