Skip to content

Commit e03dbcb

Browse files
committed
Add optional namespace argument to several functions
1 parent 84b7f1f commit e03dbcb

File tree

1 file changed

+61
-21
lines changed

1 file changed

+61
-21
lines changed

ccdproc/core.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _is_array(arr):
9393

9494
# Ideally this would eventually be covered by tests. Looks like Sparse
9595
# could be used to test this, since it has no percentile...
96-
def _percentile_fallback(array, percentiles): # pragma: no cover
96+
def _percentile_fallback(array, percentiles, xp=None): # pragma: no cover
9797
"""
9898
Try calculating percentile using namespace, otherwise fall back to
9999
an implmentation that uses sort. As of the 2023 version of the array API
@@ -107,12 +107,16 @@ def _percentile_fallback(array, percentiles): # pragma: no cover
107107
percentiles : float or list-like
108108
Percentile to calculate.
109109
110+
xp : array namespace, optional
111+
Array namespace to use for calculations. If not provided, the
112+
namespace will be determined from the array.
113+
110114
Returns
111115
-------
112116
percentile : float or list-like
113117
Calculated percentile.
114118
"""
115-
xp = array_api_compat.array_namespace(array)
119+
xp = xp or array_api_compat.array_namespace(array)
116120
try:
117121
return xp.percentile(array, percentiles)
118122
except AttributeError:
@@ -285,11 +289,11 @@ def ccd_process(
285289
# apply the overscan correction
286290
if isinstance(oscan, CCDData):
287291
nccd = subtract_overscan(
288-
nccd, overscan=oscan, median=oscan_median, model=oscan_model
292+
nccd, overscan=oscan, median=oscan_median, model=oscan_model, xp=xp
289293
)
290294
elif isinstance(oscan, str):
291295
nccd = subtract_overscan(
292-
nccd, fits_section=oscan, median=oscan_median, model=oscan_model
296+
nccd, fits_section=oscan, median=oscan_median, model=oscan_model, xp=xp
293297
)
294298
elif oscan is None:
295299
pass
@@ -298,6 +302,7 @@ def ccd_process(
298302

299303
# apply the trim correction
300304
if isinstance(trim, str):
305+
# No xp=... here because slicing can be done without knowing the array namespace
301306
nccd = trim_image(nccd, fits_section=trim)
302307
elif trim is None:
303308
pass
@@ -306,7 +311,7 @@ def ccd_process(
306311

307312
# create the error frame
308313
if error and gain is not None and readnoise is not None:
309-
nccd = create_deviation(nccd, gain=gain, readnoise=readnoise)
314+
nccd = create_deviation(nccd, gain=gain, readnoise=readnoise, xp=xp)
310315
elif error and (gain is None or readnoise is None):
311316
raise ValueError("gain and readnoise must be specified to create error frame.")
312317

@@ -324,10 +329,12 @@ def ccd_process(
324329
raise TypeError("gain is not None or astropy.units.Quantity.")
325330

326331
if gain is not None and gain_corrected:
332+
# No need for xp here because gain_correct does not need the namespace
327333
nccd = gain_correct(nccd, gain)
328334

329335
# subtracting the master bias
330336
if isinstance(master_bias, CCDData):
337+
# No need for xp here because subtract_bias does not need the namespace
331338
nccd = subtract_bias(nccd, master_bias)
332339
elif master_bias is None:
333340
pass
@@ -336,6 +343,7 @@ def ccd_process(
336343

337344
# subtract the dark frame
338345
if isinstance(dark_frame, CCDData):
346+
# No need for xp here because subtract_dark does not need the namespace
339347
nccd = subtract_dark(
340348
nccd,
341349
dark_frame,
@@ -352,21 +360,22 @@ def ccd_process(
352360

353361
# test dividing the master flat
354362
if isinstance(master_flat, CCDData):
355-
nccd = flat_correct(nccd, master_flat, min_value=min_value)
363+
nccd = flat_correct(nccd, master_flat, min_value=min_value, xp=xp)
356364
elif master_flat is None:
357365
pass
358366
else:
359367
raise TypeError("master_flat is not None or a CCDData object.")
360368

361369
# apply the gain correction only at the end if gain_corrected is False
362370
if gain is not None and not gain_corrected:
371+
# No need for xp here because gain_correct does not need the namespace
363372
nccd = gain_correct(nccd, gain)
364373

365374
return nccd
366375

367376

368377
@log_to_metadata
369-
def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
378+
def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False, xp=None):
370379
"""
371380
Create a uncertainty frame. The function will update the uncertainty
372381
plane which gives the standard deviation for the data. Gain is used in
@@ -393,6 +402,10 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
393402
If ``True``, any value of nan in the output array will be replaced by
394403
the readnoise.
395404
405+
xp : array namespace, optional
406+
Array namespace to use for calculations. If not provided, the
407+
namespace will be determined from the array.
408+
396409
{log}
397410
398411
Raises
@@ -409,7 +422,7 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
409422
410423
"""
411424
# Get array namespace
412-
xp = array_api_compat.array_namespace(ccd_data.data)
425+
xp = xp or array_api_compat.array_namespace(ccd_data.data)
413426
if gain is not None and not isinstance(gain, Quantity):
414427
raise TypeError("gain must be a astropy.units.Quantity.")
415428

@@ -462,7 +475,13 @@ def create_deviation(ccd_data, gain=None, readnoise=None, disregard_nan=False):
462475

463476
@log_to_metadata
464477
def subtract_overscan(
465-
ccd, overscan=None, overscan_axis=1, fits_section=None, median=False, model=None
478+
ccd,
479+
overscan=None,
480+
overscan_axis=1,
481+
fits_section=None,
482+
median=False,
483+
model=None,
484+
xp=None,
466485
):
467486
"""
468487
Subtract the overscan region from an image.
@@ -502,6 +521,10 @@ def subtract_overscan(
502521
by the median or the mean.
503522
Default is ``None``.
504523
524+
xp : array namespace, optional
525+
Array namespace to use for calculations. If not provided, the
526+
namespace will be determined from the array.
527+
505528
{log}
506529
507530
Raises
@@ -555,7 +578,7 @@ def subtract_overscan(
555578
raise TypeError("ccddata is not a CCDData object.")
556579

557580
# Set array namespace
558-
xp = array_api_compat.array_namespace(ccd.data)
581+
xp = xp or array_api_compat.array_namespace(ccd.data)
559582

560583
if (overscan is not None and fits_section is not None) or (
561584
overscan is None and fits_section is None
@@ -857,7 +880,7 @@ def gain_correct(ccd, gain, gain_unit=None):
857880

858881

859882
@log_to_metadata
860-
def flat_correct(ccd, flat, min_value=None, norm_value=None):
883+
def flat_correct(ccd, flat, min_value=None, norm_value=None, xp=None):
861884
"""Correct the image for flat fielding.
862885
863886
The flat field image is normalized by its mean or a user-supplied value
@@ -883,6 +906,10 @@ def flat_correct(ccd, flat, min_value=None, norm_value=None):
883906
have the same scale. If this value is negative or 0, a ``ValueError``
884907
is raised. Default is ``None``.
885908
909+
xp : array namespace, optional
910+
Array namespace to use for calculations. If not provided, the
911+
namespace will be determined from the array.
912+
886913
{log}
887914
888915
Returns
@@ -891,7 +918,7 @@ def flat_correct(ccd, flat, min_value=None, norm_value=None):
891918
CCDData object with flat corrected.
892919
"""
893920
# Get the array namespace
894-
xp = array_api_compat.array_namespace(ccd.data)
921+
xp = xp or array_api_compat.array_namespace(ccd.data)
895922
# Use the min_value to replace any values in the flat
896923
use_flat = flat
897924
if min_value is not None:
@@ -1008,7 +1035,7 @@ def transform_image(ccd, transform_func, **kwargs):
10081035

10091036

10101037
@log_to_metadata
1011-
def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear"):
1038+
def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear", xp=None):
10121039
"""
10131040
Given a CCDData image with WCS, project it onto a target WCS and
10141041
return the reprojected data as a new CCDData image.
@@ -1039,6 +1066,10 @@ def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear"):
10391066
10401067
Default is ``'bilinear'``.
10411068
1069+
xp : array namespace, optional
1070+
Array namespace to use for calculations. If not provided, the
1071+
namespace will be determined from the array.
1072+
10421073
{log}
10431074
10441075
Returns
@@ -1050,7 +1081,7 @@ def wcs_project(ccd, target_wcs, target_shape=None, order="bilinear"):
10501081
from reproject import reproject_interp
10511082

10521083
# Set array namespace
1053-
xp = array_api_compat.array_namespace(ccd.data)
1084+
xp = xp or array_api_compat.array_namespace(ccd.data)
10541085

10551086
if not (ccd.wcs.is_celestial and target_wcs.is_celestial):
10561087
raise ValueError("one or both WCS is not celestial.")
@@ -1354,10 +1385,10 @@ def rebin(ccd, newshape):
13541385
return result
13551386

13561387

1357-
def block_reduce(ccd, block_size, func=None):
1388+
def block_reduce(ccd, block_size, func=None, xp=None):
13581389
"""Thin wrapper around `astropy.nddata.block_reduce`."""
13591390
if func is None:
1360-
xp = array_api_compat.array_namespace(ccd.data)
1391+
xp = xp or array_api_compat.array_namespace(ccd.data)
13611392
func = xp.sum
13621393
data = nddata.block_reduce(ccd, block_size, func)
13631394
if isinstance(ccd, CCDData):
@@ -1367,10 +1398,10 @@ def block_reduce(ccd, block_size, func=None):
13671398
return data
13681399

13691400

1370-
def block_average(ccd, block_size):
1401+
def block_average(ccd, block_size, xp=None):
13711402
"""Like `block_reduce` but with predefined ``func=np.mean``."""
13721403

1373-
xp = array_api_compat.array_namespace(ccd.data)
1404+
xp = xp or array_api_compat.array_namespace(ccd.data)
13741405

13751406
data = nddata.block_reduce(ccd, block_size, xp.mean)
13761407
# Like in block_reduce:
@@ -1822,7 +1853,7 @@ def _astroscrappy_gain_apply_helper(cleaned_data, gain, gain_apply, old_interfac
18221853
return cleaned_data
18231854

18241855

1825-
def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0):
1856+
def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0, xp=None):
18261857
"""
18271858
Identify cosmic rays through median technique. The median technique
18281859
identifies cosmic rays by identifying pixels by subtracting a median image
@@ -1856,6 +1887,10 @@ def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0):
18561887
be replaced.
18571888
Default is ``0``.
18581889
1890+
xp : array namespace, optional
1891+
The array namespace to use for the calculations. If not provided, the
1892+
array namespace of the input data will be used.
1893+
18591894
Notes
18601895
-----
18611896
Similar implementation to crmedian in iraf.imred.crutil.crmedian.
@@ -1896,7 +1931,7 @@ def cosmicray_median(ccd, error_image=None, thresh=5, mbox=11, gbox=0, rbox=0):
18961931
updated with the detected cosmic rays.
18971932
"""
18981933
if _is_array(ccd):
1899-
xp = array_api_compat.array_namespace(ccd)
1934+
xp = xp or array_api_compat.array_namespace(ccd)
19001935

19011936
# Masked data is not part of the array API so remove mask if present.
19021937
# Only look at the data array, guessing that if there is a .mask then
@@ -1979,6 +2014,7 @@ def ccdmask(
19792014
lsigma=9,
19802015
hsigma=9,
19812016
ngood=5,
2017+
xp=None,
19822018
):
19832019
"""
19842020
Uses method based on the IRAF ccdmask task to generate a mask based on the
@@ -2035,6 +2071,10 @@ def ccdmask(
20352071
pixels masked in that column.
20362072
Default is ``5``.
20372073
2074+
xp : array namespace, optional
2075+
The array namespace to use for the calculations. If not provided, the
2076+
array namespace of the input data will be used.
2077+
20382078
Returns
20392079
-------
20402080
mask : `numpy.ndarray`
@@ -2092,7 +2132,7 @@ def ccdmask(
20922132
raise ValueError('"ratio" should be a "CCDData".') from err
20932133

20942134
# Get array namespace
2095-
xp = array_api_compat.array_namespace(ratio.data)
2135+
xp = xp or array_api_compat.array_namespace(ratio.data)
20962136

20972137
def _sigma_mask(baseline, one_sigma_value, lower_sigma, upper_sigma):
20982138
"""Helper function to mask values outside of the specified sigma range."""

0 commit comments

Comments
 (0)