Skip to content
111 changes: 89 additions & 22 deletions specutils/manipulation/extract_spectral_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@
__all__ = ['extract_region', 'extract_bounding_spectral_region', 'spectral_slab']


def _edge_value_to_pixel(edge_value, spectrum, order, side):
spectral_axis = spectrum.spectral_axis

def _edge_value_to_pixel(edge_value, spectrum, order, side, axis=None):
spectral_axis = spectrum.spectral_axis if axis is None else axis
try:
edge_value = edge_value.to(spectral_axis.unit, u.spectral())
except u.UnitConversionError:
pass
if order == 'ascending':
index = np.searchsorted(spectral_axis, edge_value, side=side)
if side == 'right':
index = np.searchsorted(spectral_axis, edge_value, side='right')
if np.isclose(spectral_axis[index-1].value, edge_value.value):
index += 1
else:
index = np.searchsorted(spectral_axis, edge_value, side='left')
return index

elif order == 'descending':
Expand Down Expand Up @@ -45,14 +53,6 @@ def _subregion_to_edge_pixels(subregion, spectrum):

"""
spectral_axis = spectrum.spectral_axis
if spectral_axis[-1] > spectral_axis[0]:
order = "ascending"
left_func = min
right_func = max
else:
order = "descending"
left_func = max
right_func = min

# Left/lower side of sub region
if subregion[0].unit.is_equivalent(u.pix):
Expand All @@ -72,10 +72,16 @@ def _subregion_to_edge_pixels(subregion, spectrum):
if (spectral_axis[left_index] > subregion[0]) and (left_index >= 1):
left_index -= 1
else:
# Convert lower value to spectrum spectral_axis units
left_reg_in_spec_unit = left_func(subregion).to(spectral_axis.unit,
u.spectral())
left_index = _edge_value_to_pixel(left_reg_in_spec_unit, spectrum, order, "left")
# Convert lower value to the appropriate axis and compute order on that axis
try:
axis_to_use = spectral_axis
left_reg_in_spec_unit = subregion[0].to(axis_to_use.unit, u.spectral())
except u.UnitConversionError:
axis_to_use = _get_axis_in_matching_unit(subregion[0].unit, spectrum)
left_reg_in_spec_unit = subregion[0].to(axis_to_use.unit, u.spectral())

order_left = "ascending" if axis_to_use[-1] > axis_to_use[0] else "descending"
left_index = _edge_value_to_pixel(left_reg_in_spec_unit, spectrum, order_left, "left", axis=axis_to_use)

# Right/upper side of sub region
if subregion[1].unit.is_equivalent(u.pix):
Expand All @@ -95,11 +101,16 @@ def _subregion_to_edge_pixels(subregion, spectrum):
if (spectral_axis[right_index] < subregion[1]) and (right_index < len(spectral_axis)):
right_index += 1
else:
# Convert upper value to spectrum spectral_axis units
right_reg_in_spec_unit = right_func(subregion).to(spectral_axis.unit,
u.spectral())
# Convert upper value to the appropriate axis and compute order on that axis
try:
axis_to_use_r = spectral_axis
right_reg_in_spec_unit = subregion[1].to(axis_to_use_r.unit, u.spectral())
except u.UnitConversionError:
axis_to_use_r = _get_axis_in_matching_unit(subregion[1].unit, spectrum)
right_reg_in_spec_unit = subregion[1].to(axis_to_use_r.unit, u.spectral())

right_index = _edge_value_to_pixel(right_reg_in_spec_unit, spectrum, order, "right")
order_right = "ascending" if axis_to_use_r[-1] > axis_to_use_r[0] else "descending"
right_index = _edge_value_to_pixel(right_reg_in_spec_unit, spectrum, order_right, "right", axis=axis_to_use_r)

# If the spectrum is in wavelength and region is in Hz (for example), these still might be reversed
if left_index < right_index:
Expand All @@ -108,7 +119,7 @@ def _subregion_to_edge_pixels(subregion, spectrum):
return right_index, left_index


def extract_region(spectrum, region, return_single_spectrum=False):
def extract_region(spectrum, region, return_single_spectrum=False, preserve_wcs=False):
"""
Extract a region from the input `~specutils.Spectrum`
defined by the lower and upper bounds defined by the ``region``
Expand All @@ -128,6 +139,10 @@ def extract_region(spectrum, region, return_single_spectrum=False):
instead of multiple `~specutils.Spectrum` objects. The returned spectrum
will be a unique, concatenated, spectrum of all sub-regions.

preserve_wcs: `bool`
If True, the WCS will be adjusted and retained in the output spectrum(s).
If False (default), WCS will be dropped.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current False behavior does keep a WCS, it just turns it into a lookuptable WCS - so it doesn't exactly drop the WCS, it changes it.


Returns
-------
spectrum: `~specutils.Spectrum` or list of `~specutils.Spectrum`
Expand Down Expand Up @@ -170,7 +185,23 @@ def extract_region(spectrum, region, return_single_spectrum=False):
slices = slices[0]
else:
slices = tuple(slices)
extracted_spectrum.append(spectrum[slices])
sliced = spectrum[slices]

# Adjust WCS properly
if preserve_wcs and spectrum.wcs is not None:
new_wcs = spectrum.wcs.deepcopy()

# Set CRPIX = 1.0 (FITS convention: reference pixel is 1-indexed)
new_wcs.wcs.crpix[0] = 1.0

# Set CRVAL to match the first spectral axis value in the sliced spectrum
new_wcs.wcs.crval[0] = sliced.spectral_axis[0].to_value(new_wcs.wcs.cunit[0])

sliced._wcs = new_wcs
else:
sliced._wcs = None

extracted_spectrum.append(sliced)

# If there is only one subregion in the region then we will
# just return a spectrum.
Expand Down Expand Up @@ -285,3 +316,39 @@ def extract_bounding_spectral_region(spectrum, region):
single_region = SpectralRegion(min(min_list), max(max_list))

return extract_region(spectrum, single_region)


def _get_axis_in_matching_unit(unit, spectrum):
"""
Return the appropriate spectral axis (wavelength, frequency, or velocity)
from the input Spectrum object that matches the given unit.

Parameters
----------
unit : astropy.units.Unit
The unit to match (e.g., km/s, Hz, micron).

spectrum : specutils.Spectrum
The spectrum from which to select the appropriate axis.

Returns
-------
Quantity
The corresponding axis: one of spectrum.spectral_axis, spectrum.velocity,
or spectrum.frequency.

Raises
------
UnitConversionError
If the unit is not compatible with any of the known spectral axes.
"""
if unit.is_equivalent(spectrum.spectral_axis.unit):
return spectrum.spectral_axis
elif unit.is_equivalent(u.km / u.s):
return spectrum.velocity
elif unit.is_equivalent(u.Hz):
return spectrum.frequency
else:
raise u.UnitConversionError(
f"Cannot convert subregion unit {unit} to any known spectral axis"
)
77 changes: 77 additions & 0 deletions specutils/tests/test_spectral_region.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import astropy.units as u
from numpy.testing import assert_allclose
import pytest
from astropy.wcs import WCS

from specutils.spectra.spectrum import Spectrum
from specutils.spectra.spectral_region import SpectralRegion
from specutils.manipulation import extract_region

@pytest.fixture
def frequency_spectrum():
# Create basic frequency WCS
w = WCS(naxis=1)
w.wcs.crval = [1.410e9] # starting frequency (Hz)
w.wcs.cdelt = [1.0e6] # 1 MHz per channel
w.wcs.crpix = [1] # reference pixel
w.wcs.cunit = ['Hz']
w.wcs.restfrq = 1.420e9 # rest frequency in Hz

# Build spectral axis and flux
freqs = np.arange(1.410e9, 1.431e9, 1.0e6) * u.Hz
flux = np.arange(1, len(freqs) + 1, dtype=float) * u.Jy

return Spectrum(spectral_axis=freqs, flux=flux, wcs=w, velocity_convention='radio')


def test_extract_region_velocity_on_frequency_axis(frequency_spectrum):
spec = frequency_spectrum

# Define velocity range
region = SpectralRegion(-500 * u.km / u.s, 500 * u.km / u.s)

# Extract region with WCS preservation
sub = extract_region(spec, region, preserve_wcs=True)

# Determine expected frequency channels based on velocity condition
velocities = spec.velocity.to(u.km / u.s)
mask = (velocities >= -500 * u.km / u.s) & (velocities <= 500 * u.km / u.s)
expected_freqs = spec.spectral_axis[mask]
expected_flux = spec.flux[mask]

# Assertions
assert len(sub.spectral_axis) == len(expected_freqs)
assert_allclose(sub.spectral_axis.to_value(u.Hz),
expected_freqs.to_value(u.Hz),
rtol=0, atol=1e-12)

assert_allclose(sub.flux.to_value(u.Jy),
expected_flux.to_value(u.Jy),
rtol=0, atol=0)

assert np.isclose(sub.wcs.wcs.crval[0], expected_freqs[0].value)
assert np.isclose(sub.wcs.wcs.crpix[0], 1)
assert np.isclose(sub.wcs.wcs.cdelt[0], spec.wcs.wcs.cdelt[0])
assert sub.wcs.wcs.restfrq == spec.wcs.wcs.restfrq

def test_extract_region_drops_wcs_when_disabled(frequency_spectrum):
spec = frequency_spectrum

# Define velocity range
region = SpectralRegion(-500 * u.km / u.s, 500 * u.km / u.s)

# Extract region without WCS preservation
sub = extract_region(spec, region, preserve_wcs=False)

# Basic content check
velocities = spec.velocity.to(u.km / u.s)
mask = (velocities >= -500 * u.km / u.s) & (velocities <= 500 * u.km / u.s)
expected_flux = spec.flux[mask]

assert_allclose(sub.flux.to_value(u.Jy),
expected_flux.to_value(u.Jy),
rtol=0, atol=0)

# Ensure WCS is removed
assert not hasattr(sub, "wcs") or sub.wcs is None
Loading