Skip to content

Commit 96fd450

Browse files
committed
Rebase and clean up to better match current specutils
1 parent db53940 commit 96fd450

File tree

3 files changed

+280
-232
lines changed

3 files changed

+280
-232
lines changed

specutils/fitting/spline.py

Lines changed: 107 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -11,158 +11,176 @@
1111
from astropy.utils import indent, check_broadcast
1212
from astropy.units import Quantity
1313

14-
1514
__all__ = []
1615

16+
1717
class SplineModel(FittableModel):
1818
"""
1919
Wrapper around scipy.interpolate.splrep and splev
20-
20+
2121
Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,
2222
and scipy.interpolate.LSQUnivariateSpline if knots are specified
23-
23+
2424
There are two ways to make a spline model.
25-
(1) you have the spline auto-determine knots from the data
26-
(2) you specify the knots
27-
25+
1. you have the spline auto-determine knots from the data
26+
2. you specify the knots
2827
"""
29-
30-
linear = False # I think? I have no idea?
31-
col_fit_deriv = False # Not sure what this is
32-
33-
def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):
28+
29+
linear = False # I think? I have no idea?
30+
col_fit_deriv = False # Not sure what this is
31+
32+
def __init__(self, degree=3, smoothing=None, knots=None,
33+
extrapolate_mode=0, *args, **kwargs):
3434
"""
3535
Set up a spline model.
36-
36+
3737
degree: degree of the spline (default 3)
3838
In scipy fitpack, this is "k"
39-
39+
4040
smoothing (optional): smoothing value for automatically determining knots
4141
In scipy fitpack, this is "s"
42-
By default, uses a
43-
42+
By default, uses a
43+
4444
knots (optional): spline knots (boundaries of piecewise polynomial)
4545
If not specified, will automatically determine knots based on
4646
degree + smoothing
47-
47+
4848
extrapolate_mode (optional): how to deal with solution outside of interval.
4949
(see scipy.interpolate.splev)
5050
if 0 (default): return the extrapolated value
5151
if 1, return 0
5252
if 2, raise a ValueError
5353
if 3, return the boundary value
5454
"""
55+
super().__init__(*args, **kwargs)
56+
5557
self._degree = degree
5658
self._smoothing = smoothing
5759
self._knots = self.verify_knots(knots)
5860
self.extrapolate_mode = extrapolate_mode
59-
60-
## This is used to evaluate the spline
61-
## When None, raises an error when trying to evaluate the spline
61+
62+
# This is used to evaluate the spline. When None, raises an error when
63+
# trying to evaluate the spline.
6264
self._tck = None
63-
65+
6466
self._param_names = ()
65-
67+
6668
def verify_knots(self, knots):
6769
"""
68-
Basic knot array vetting.
69-
The goal of having this is to enable more useful error messages
70-
than scipy (if needed).
70+
Basic knot array vetting. The goal of having this is to enable more
71+
useful error messages than scipy (if needed).
7172
"""
72-
if knots is None: return None
73+
if knots is None:
74+
return None
75+
7376
knots = np.array(knots)
7477
assert len(knots.shape) == 1, knots.shape
7578
knots = np.sort(knots)
7679
assert len(np.unique(knots)) == len(knots), knots
80+
7781
return knots
78-
79-
############
80-
## Getters
81-
############
82-
def get_degree(self):
82+
83+
# Getters
84+
@property
85+
def degree(self):
8386
""" Spline degree (k in FITPACK) """
8487
return self._degree
85-
def get_smoothing(self):
88+
89+
@property
90+
def smoothing(self):
8691
""" Spline smoothing (s in FITPACK) """
8792
return self._smoothing
88-
def get_knots(self):
93+
94+
@property
95+
def knots(self):
8996
""" Spline knots (t in FITPACK) """
9097
return self._knots
91-
def get_coeffs(self):
98+
99+
@property
100+
def coeffs(self):
92101
""" Spline coefficients (c in FITPACK) """
93102
if self._tck is not None:
94103
return self._tck[1]
95104
else:
96-
raise RuntimeError("SplineModel has not been fit yet")
97-
98-
############
99-
## Spline methods: not tested at all
100-
############
101-
def derivative(self, n=1):
102-
if self._tck is None:
103-
raise RuntimeError("SplineModel has not been fit yet")
104-
else:
105-
t, c, k = self._tck
106-
return scipy.interpolate.BSpline.construct_fast(
107-
t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n)
108-
def antiderivative(self, n=1):
109-
if self._tck is None:
110-
raise RuntimeError("SplineModel has not been fit yet")
111-
else:
112-
t, c, k = self._tck
113-
return scipy.interpolate.BSpline.construct_fast(
114-
t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n)
115-
def integral(self, a, b):
116-
if self._tck is None:
117-
raise RuntimeError("SplineModel has not been fit yet")
118-
else:
119-
t, c, k = self._tck
120-
return scipy.interpolate.BSpline.construct_fast(
121-
t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b)
122-
def derivatives(self, x):
123-
raise NotImplementedError
124-
def roots(self):
125-
raise NotImplementedError
126-
127-
############
128-
## Setters: not really implemented or tested
129-
############
105+
raise RuntimeError("SplineModel has not been fit yet.")
106+
107+
# Setters
108+
# TODO: not really implemented or tested
130109
def reset_model(self):
131110
""" Resets model so it needs to be refit to be valid """
132111
self._tck = None
133-
def set_degree(self, degree):
112+
113+
@degree.setter
114+
def degree(self, degree):
134115
""" Spline degree (k in FITPACK) """
135116
raise NotImplementedError
136117
self._degree = degree
137118
self.reset_model()
138-
def set_smoothing(self, smoothing):
119+
120+
@smoothing.setter
121+
def smoothing(self, smoothing):
139122
""" Spline smoothing (s in FITPACK) """
140123
raise NotImplementedError
141124
self._smoothing = smoothing
142125
self.reset_model()
143-
def set_knots(self, knots):
126+
127+
@knots.setter
128+
def knots(self, knots):
144129
""" Spline knots (t in FITPACK) """
145130
raise NotImplementedError
146131
self._knots = self.verify_knots(knots)
147132
self.reset_model()
148-
133+
149134
def set_model_from_tck(self, tck):
150135
"""
151136
Use output of scipy.interpolate.splrep
152137
"""
153138
self._tck = tck
154139

140+
# Spline methods
141+
# TODO: not tested at all
142+
def derivative(self, n=1):
143+
if self._tck is None:
144+
raise RuntimeError("SplineModel has not been fit yet")
145+
else:
146+
t, c, k = self._tck
147+
return scipy.interpolate.BSpline.construct_fast(
148+
t, c, k, extrapolate=(self.extrapolate_mode == 0)).derivative(n)
149+
150+
def antiderivative(self, n=1):
151+
if self._tck is None:
152+
raise RuntimeError("SplineModel has not been fit yet.")
153+
else:
154+
t, c, k = self._tck
155+
return scipy.interpolate.BSpline.construct_fast(
156+
t, c, k, extrapolate=(self.extrapolate_mode == 0)).antiderivative(n)
157+
158+
def integral(self, a, b):
159+
if self._tck is None:
160+
raise RuntimeError("SplineModel has not been fit yet.")
161+
else:
162+
t, c, k = self._tck
163+
return scipy.interpolate.BSpline.construct_fast(
164+
t, c, k, extrapolate=(self.extrapolate_mode == 0)).integral(a, b)
165+
166+
def derivatives(self, x):
167+
raise NotImplementedError
168+
169+
def roots(self):
170+
raise NotImplementedError
171+
155172
def __call__(self, x, der=0):
156173
"""
157174
Evaluate the model with the given inputs.
158175
der is passed to scipy.interpolate.splev
159176
"""
160177
if self._tck is None:
161-
raise RuntimeError("SplineModel has not been fit yet")
178+
raise RuntimeError("SplineModel has not been fit yet.")
179+
162180
return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode)
163-
164-
####################################
165-
######### Stuff below here is stubs
181+
182+
# Stuff below here is stubs
183+
# TODO: fill out methods
166184
@property
167185
def param_names(self):
168186
"""
@@ -201,11 +219,10 @@ def _generate_coeff_names(self):
201219
for j in range(degree+1):
202220
names.append("k{}_c{}".format(i,j))
203221
return tuple(names)
204-
222+
205223
def evaluate(self, *args, **kwargs):
206224
return self(*args, **kwargs)
207225

208-
209226

210227
class SplineFitter(metaclass=_FitterMeta):
211228
"""
@@ -216,43 +233,42 @@ def __init__(self):
216233
"ier": None,
217234
"msg": None}
218235
super().__init__()
219-
236+
220237
def validate_model(self, model):
221238
if not isinstance(model, SplineModel):
222239
raise ValueError("model must be of type SplineModel (currently is {})".format(
223240
type(model)))
224-
225-
## TODO do something about units
226-
#@fitter_unit_support
241+
242+
# TODO do something about units
243+
# @fitter_unit_support
227244
def __call__(self, model, x, y, w=None):
228245
"""
229246
Fit a spline model to data.
230247
Internally uses scipy.interpolate.splrep.
231-
248+
232249
"""
233-
250+
234251
self.validate_model(model)
235-
236-
## Case (1): fit smoothing spline
252+
253+
# Case (1): fit smoothing spline
237254
if model.get_knots() is None:
238255
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
239256
t=None,
240-
k=model.get_degree(),
257+
k=model.get_degree(),
241258
s=model.get_smoothing(),
242259
task=0, full_output=True
243260
)
244-
## Case (2): leastsq spline
261+
# Case (2): leastsq spline
245262
else:
246263
knots = model.get_knots()
247264
## TODO some sort of validation that the knots are internal, since
248265
## this procedure automatically adds knots at the two endpoints
249266
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
250267
t=knots,
251-
k=model.get_degree(),
268+
k=model.get_degree(),
252269
s=model.get_smoothing(),
253270
task=-1, full_output=True
254271
)
255-
272+
256273
model.set_model_from_tck(tck)
257-
self.fit_info.update({"fp":fp, "ier":ier, "msg":msg})
258-
274+
self.fit_info.update({"fp": fp, "ier": ier, "msg": msg})

0 commit comments

Comments
 (0)