@@ -93,7 +93,7 @@ def _is_array(arr):
93
93
94
94
# Ideally this would eventually be covered by tests. Looks like Sparse
95
95
# could be used to test this, since it has no percentile...
96
- def _percentile_fallback (array , percentiles ): # pragma: no cover
96
+ def _percentile_fallback (array , percentiles , xp = None ): # pragma: no cover
97
97
"""
98
98
Try calculating percentile using namespace, otherwise fall back to
99
99
an implmentation that uses sort. As of the 2023 version of the array API
@@ -107,12 +107,16 @@ def _percentile_fallback(array, percentiles): # pragma: no cover
107
107
percentiles : float or list-like
108
108
Percentile to calculate.
109
109
110
+ xp : array namespace, optional
111
+ Array namespace to use for calculations. If not provided, the
112
+ namespace will be determined from the array.
113
+
110
114
Returns
111
115
-------
112
116
percentile : float or list-like
113
117
Calculated percentile.
114
118
"""
115
- xp = array_api_compat .array_namespace (array )
119
+ xp = xp or array_api_compat .array_namespace (array )
116
120
try :
117
121
return xp .percentile (array , percentiles )
118
122
except AttributeError :
@@ -285,11 +289,11 @@ def ccd_process(
285
289
# apply the overscan correction
286
290
if isinstance (oscan , CCDData ):
287
291
nccd = subtract_overscan (
288
- nccd , overscan = oscan , median = oscan_median , model = oscan_model
292
+ nccd , overscan = oscan , median = oscan_median , model = oscan_model , xp = xp
289
293
)
290
294
elif isinstance (oscan , str ):
291
295
nccd = subtract_overscan (
292
- nccd , fits_section = oscan , median = oscan_median , model = oscan_model
296
+ nccd , fits_section = oscan , median = oscan_median , model = oscan_model , xp = xp
293
297
)
294
298
elif oscan is None :
295
299
pass
@@ -298,6 +302,7 @@ def ccd_process(
298
302
299
303
# apply the trim correction
300
304
if isinstance (trim , str ):
305
+ # No xp=... here because slicing can be done without knowing the array namespace
301
306
nccd = trim_image (nccd , fits_section = trim )
302
307
elif trim is None :
303
308
pass
@@ -306,7 +311,7 @@ def ccd_process(
306
311
307
312
# create the error frame
308
313
if error and gain is not None and readnoise is not None :
309
- nccd = create_deviation (nccd , gain = gain , readnoise = readnoise )
314
+ nccd = create_deviation (nccd , gain = gain , readnoise = readnoise , xp = xp )
310
315
elif error and (gain is None or readnoise is None ):
311
316
raise ValueError ("gain and readnoise must be specified to create error frame." )
312
317
@@ -324,10 +329,12 @@ def ccd_process(
324
329
raise TypeError ("gain is not None or astropy.units.Quantity." )
325
330
326
331
if gain is not None and gain_corrected :
332
+ # No need for xp here because gain_correct does not need the namespace
327
333
nccd = gain_correct (nccd , gain )
328
334
329
335
# subtracting the master bias
330
336
if isinstance (master_bias , CCDData ):
337
+ # No need for xp here because subtract_bias does not need the namespace
331
338
nccd = subtract_bias (nccd , master_bias )
332
339
elif master_bias is None :
333
340
pass
@@ -336,6 +343,7 @@ def ccd_process(
336
343
337
344
# subtract the dark frame
338
345
if isinstance (dark_frame , CCDData ):
346
+ # No need for xp here because subtract_dark does not need the namespace
339
347
nccd = subtract_dark (
340
348
nccd ,
341
349
dark_frame ,
@@ -352,21 +360,22 @@ def ccd_process(
352
360
353
361
# test dividing the master flat
354
362
if isinstance (master_flat , CCDData ):
355
- nccd = flat_correct (nccd , master_flat , min_value = min_value )
363
+ nccd = flat_correct (nccd , master_flat , min_value = min_value , xp = xp )
356
364
elif master_flat is None :
357
365
pass
358
366
else :
359
367
raise TypeError ("master_flat is not None or a CCDData object." )
360
368
361
369
# apply the gain correction only at the end if gain_corrected is False
362
370
if gain is not None and not gain_corrected :
371
+ # No need for xp here because gain_correct does not need the namespace
363
372
nccd = gain_correct (nccd , gain )
364
373
365
374
return nccd
366
375
367
376
368
377
@log_to_metadata
369
- def create_deviation (ccd_data , gain = None , readnoise = None , disregard_nan = False ):
378
+ def create_deviation (ccd_data , gain = None , readnoise = None , disregard_nan = False , xp = None ):
370
379
"""
371
380
Create a uncertainty frame. The function will update the uncertainty
372
381
plane which gives the standard deviation for the data. Gain is used in
@@ -393,6 +402,10 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
393
402
If ``True``, any value of nan in the output array will be replaced by
394
403
the readnoise.
395
404
405
+ xp : array namespace, optional
406
+ Array namespace to use for calculations. If not provided, the
407
+ namespace will be determined from the array.
408
+
396
409
{log}
397
410
398
411
Raises
@@ -409,7 +422,7 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
409
422
410
423
"""
411
424
# Get array namespace
412
- xp = array_api_compat .array_namespace (ccd_data .data )
425
+ xp = xp or array_api_compat .array_namespace (ccd_data .data )
413
426
if gain is not None and not isinstance (gain , Quantity ):
414
427
raise TypeError ("gain must be a astropy.units.Quantity." )
415
428
@@ -462,7 +475,13 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
462
475
463
476
@log_to_metadata
464
477
def subtract_overscan (
465
- ccd , overscan = None , overscan_axis = 1 , fits_section = None , median = False , model = None
478
+ ccd ,
479
+ overscan = None ,
480
+ overscan_axis = 1 ,
481
+ fits_section = None ,
482
+ median = False ,
483
+ model = None ,
484
+ xp = None ,
466
485
):
467
486
"""
468
487
Subtract the overscan region from an image.
@@ -502,6 +521,10 @@ def subtract_overscan(
502
521
by the median or the mean.
503
522
Default is ``None``.
504
523
524
+ xp : array namespace, optional
525
+ Array namespace to use for calculations. If not provided, the
526
+ namespace will be determined from the array.
527
+
505
528
{log}
506
529
507
530
Raises
@@ -555,7 +578,7 @@ def subtract_overscan(
555
578
raise TypeError ("ccddata is not a CCDData object." )
556
579
557
580
# Set array namespace
558
- xp = array_api_compat .array_namespace (ccd .data )
581
+ xp = xp or array_api_compat .array_namespace (ccd .data )
559
582
560
583
if (overscan is not None and fits_section is not None ) or (
561
584
overscan is None and fits_section is None
@@ -857,7 +880,7 @@ def gain_correct(ccd, gain, gain_unit=None):
857
880
858
881
859
882
@log_to_metadata
860
- def flat_correct (ccd , flat , min_value = None , norm_value = None ):
883
+ def flat_correct (ccd , flat , min_value = None , norm_value = None , xp = None ):
861
884
"""Correct the image for flat fielding.
862
885
863
886
The flat field image is normalized by its mean or a user-supplied value
@@ -883,6 +906,10 @@ def flat_correct(ccd, flat, min_value=None, norm_value=None):
883
906
have the same scale. If this value is negative or 0, a ``ValueError``
884
907
is raised. Default is ``None``.
885
908
909
+ xp : array namespace, optional
910
+ Array namespace to use for calculations. If not provided, the
911
+ namespace will be determined from the array.
912
+
886
913
{log}
887
914
888
915
Returns
@@ -891,7 +918,7 @@ def flat_correct(ccd, flat, min_value=None, norm_value=None):
891
918
CCDData object with flat corrected.
892
919
"""
893
920
# Get the array namespace
894
- xp = array_api_compat .array_namespace (ccd .data )
921
+ xp = xp or array_api_compat .array_namespace (ccd .data )
895
922
# Use the min_value to replace any values in the flat
896
923
use_flat = flat
897
924
if min_value is not None :
@@ -1008,7 +1035,7 @@ def transform_image(ccd, transform_func, **kwargs):
1008
1035
1009
1036
1010
1037
@log_to_metadata
1011
- def wcs_project (ccd , target_wcs , target_shape = None , order = "bilinear" ):
1038
+ def wcs_project (ccd , target_wcs , target_shape = None , order = "bilinear" , xp = None ):
1012
1039
"""
1013
1040
Given a CCDData image with WCS, project it onto a target WCS and
1014
1041
return the reprojected data as a new CCDData image.
@@ -1039,6 +1066,10 @@ def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear"):
1039
1066
1040
1067
Default is ``'bilinear'``.
1041
1068
1069
+ xp : array namespace, optional
1070
+ Array namespace to use for calculations. If not provided, the
1071
+ namespace will be determined from the array.
1072
+
1042
1073
{log}
1043
1074
1044
1075
Returns
@@ -1050,7 +1081,7 @@ def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear"):
1050
1081
from reproject import reproject_interp
1051
1082
1052
1083
# Set array namespace
1053
- xp = array_api_compat .array_namespace (ccd .data )
1084
+ xp = xp or array_api_compat .array_namespace (ccd .data )
1054
1085
1055
1086
if not (ccd .wcs .is_celestial and target_wcs .is_celestial ):
1056
1087
raise ValueError ("one or both WCS is not celestial." )
@@ -1354,10 +1385,10 @@ def rebin(ccd, newshape):
1354
1385
return result
1355
1386
1356
1387
1357
- def block_reduce (ccd , block_size , func = None ):
1388
+ def block_reduce (ccd , block_size , func = None , xp = None ):
1358
1389
"""Thin wrapper around `astropy.nddata.block_reduce`."""
1359
1390
if func is None :
1360
- xp = array_api_compat .array_namespace (ccd .data )
1391
+ xp = xp or array_api_compat .array_namespace (ccd .data )
1361
1392
func = xp .sum
1362
1393
data = nddata .block_reduce (ccd , block_size , func )
1363
1394
if isinstance (ccd , CCDData ):
@@ -1367,10 +1398,10 @@ def block_reduce(ccd, block_size, func=None):
1367
1398
return data
1368
1399
1369
1400
1370
- def block_average (ccd , block_size ):
1401
+ def block_average (ccd , block_size , xp = None ):
1371
1402
"""Like `block_reduce` but with predefined ``func=np.mean``."""
1372
1403
1373
- xp = array_api_compat .array_namespace (ccd .data )
1404
+ xp = xp or array_api_compat .array_namespace (ccd .data )
1374
1405
1375
1406
data = nddata .block_reduce (ccd , block_size , xp .mean )
1376
1407
# Like in block_reduce:
@@ -1822,7 +1853,7 @@ def _astroscrappy_gain_apply_helper(cleaned_data, gain, gain_apply, old_interfac
1822
1853
return cleaned_data
1823
1854
1824
1855
1825
- def cosmicray_median (ccd , error_image = None , thresh = 5 , mbox = 11 , gbox = 0 , rbox = 0 ):
1856
+ def cosmicray_median (ccd , error_image = None , thresh = 5 , mbox = 11 , gbox = 0 , rbox = 0 , xp = None ):
1826
1857
"""
1827
1858
Identify cosmic rays through median technique. The median technique
1828
1859
identifies cosmic rays by identifying pixels by subtracting a median image
@@ -1856,6 +1887,10 @@ def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0):
1856
1887
be replaced.
1857
1888
Default is ``0``.
1858
1889
1890
+ xp : array namespace, optional
1891
+ The array namespace to use for the calculations. If not provided, the
1892
+ array namespace of the input data will be used.
1893
+
1859
1894
Notes
1860
1895
-----
1861
1896
Similar implementation to crmedian in iraf.imred.crutil.crmedian.
@@ -1896,7 +1931,7 @@ def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0):
1896
1931
updated with the detected cosmic rays.
1897
1932
"""
1898
1933
if _is_array (ccd ):
1899
- xp = array_api_compat .array_namespace (ccd )
1934
+ xp = xp or array_api_compat .array_namespace (ccd )
1900
1935
1901
1936
# Masked data is not part of the array API so remove mask if present.
1902
1937
# Only look at the data array, guessing that if there is a .mask then
@@ -1979,6 +2014,7 @@ def ccdmask(
1979
2014
lsigma = 9 ,
1980
2015
hsigma = 9 ,
1981
2016
ngood = 5 ,
2017
+ xp = None ,
1982
2018
):
1983
2019
"""
1984
2020
Uses method based on the IRAF ccdmask task to generate a mask based on the
@@ -2035,6 +2071,10 @@ def ccdmask(
2035
2071
pixels masked in that column.
2036
2072
Default is ``5``.
2037
2073
2074
+ xp : array namespace, optional
2075
+ The array namespace to use for the calculations. If not provided, the
2076
+ array namespace of the input data will be used.
2077
+
2038
2078
Returns
2039
2079
-------
2040
2080
mask : `numpy.ndarray`
@@ -2092,7 +2132,7 @@ def ccdmask(
2092
2132
raise ValueError ('"ratio" should be a "CCDData".' ) from err
2093
2133
2094
2134
# Get array namespace
2095
- xp = array_api_compat .array_namespace (ratio .data )
2135
+ xp = xp or array_api_compat .array_namespace (ratio .data )
2096
2136
2097
2137
def _sigma_mask (baseline , one_sigma_value , lower_sigma , upper_sigma ):
2098
2138
"""Helper function to mask values outside of the specified sigma range."""
0 commit comments