Skip to content

Commit 1c88ccf

Browse files
committed
Draft fitter for spline models
1 parent 5640b7b commit 1c88ccf

File tree

2 files changed

+318
-0
lines changed

2 files changed

+318
-0
lines changed

specutils/fitting/spline.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from __future__ import print_function, division, absolute_import
2+
3+
import numpy as np
4+
from scipy import interpolate
5+
6+
from astropy.modeling.core import FittableModel, Model
7+
from astropy.modeling.functional_models import Shift
8+
from astropy.modeling.parameters import Parameter
9+
from astropy.modeling.utils import poly_map_domain, comb
10+
from astropy.modeling.fitting import _FitterMeta, fitter_unit_support
11+
from astropy.utils import indent, check_broadcast
12+
from astropy.units import Quantity
13+
14+
15+
__all__ = []
16+
17+
class SplineModel(FittableModel):
18+
"""
19+
Wrapper around scipy.interpolate.splrep and splev
20+
21+
Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,
22+
and scipy.interpolate.LSQUnivariateSpline if knots are specified
23+
24+
There are two ways to make a spline model.
25+
(1) you have the spline auto-determine knots from the data
26+
(2) you specify the knots
27+
28+
"""
29+
30+
linear = False # I think? I have no idea?
31+
col_fit_deriv = False # Not sure what this is
32+
33+
def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):
34+
"""
35+
Set up a spline model.
36+
37+
degree: degree of the spline (default 3)
38+
In scipy fitpack, this is "k"
39+
40+
smoothing (optional): smoothing value for automatically determining knots
41+
In scipy fitpack, this is "s"
42+
By default, uses a
43+
44+
knots (optional): spline knots (boundaries of piecewise polynomial)
45+
If not specified, will automatically determine knots based on
46+
degree + smoothing
47+
48+
extrapolate_mode (optional): how to deal with solution outside of interval.
49+
(see scipy.interpolate.splev)
50+
if 0 (default): return the extrapolated value
51+
if 1, return 0
52+
if 2, raise a ValueError
53+
if 3, return the boundary value
54+
"""
55+
self._degree = degree
56+
self._smoothing = smoothing
57+
self._knots = self.verify_knots(knots)
58+
self.extrapolate_mode = extrapolate_mode
59+
60+
## This is used to evaluate the spline
61+
## When None, raises an error when trying to evaluate the spline
62+
self._tck = None
63+
64+
self._param_names = ()
65+
66+
def verify_knots(self, knots):
67+
"""
68+
Basic knot array vetting.
69+
The goal of having this is to enable more useful error messages
70+
than scipy (if needed).
71+
"""
72+
if knots is None: return None
73+
knots = np.array(knots)
74+
assert len(knots.shape) == 1, knots.shape
75+
knots = np.sort(knots)
76+
assert len(np.unique(knots)) == len(knots), knots
77+
return knots
78+
79+
############
80+
## Getters
81+
############
82+
def get_degree(self):
83+
""" Spline degree (k in FITPACK) """
84+
return self._degree
85+
def get_smoothing(self):
86+
""" Spline smoothing (s in FITPACK) """
87+
return self._smoothing
88+
def get_knots(self):
89+
""" Spline knots (t in FITPACK) """
90+
return self._knots
91+
def get_coeffs(self):
92+
""" Spline coefficients (c in FITPACK) """
93+
if self._tck is not None:
94+
return self._tck[1]
95+
else:
96+
raise RuntimeError("SplineModel has not been fit yet")
97+
98+
############
99+
## Spline methods: not tested at all
100+
############
101+
def derivative(self, n=1):
102+
if self._tck is None:
103+
raise RuntimeError("SplineModel has not been fit yet")
104+
else:
105+
t, c, k = self._tck
106+
return scipy.interpolate.BSpline.construct_fast(
107+
t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n)
108+
def antiderivative(self, n=1):
109+
if self._tck is None:
110+
raise RuntimeError("SplineModel has not been fit yet")
111+
else:
112+
t, c, k = self._tck
113+
return scipy.interpolate.BSpline.construct_fast(
114+
t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n)
115+
def integral(self, a, b):
116+
if self._tck is None:
117+
raise RuntimeError("SplineModel has not been fit yet")
118+
else:
119+
t, c, k = self._tck
120+
return scipy.interpolate.BSpline.construct_fast(
121+
t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b)
122+
def derivatives(self, x):
123+
raise NotImplementedError
124+
def roots(self):
125+
raise NotImplementedError
126+
127+
############
128+
## Setters: not really implemented or tested
129+
############
130+
def reset_model(self):
131+
""" Resets model so it needs to be refit to be valid """
132+
self._tck = None
133+
def set_degree(self, degree):
134+
""" Spline degree (k in FITPACK) """
135+
raise NotImplementedError
136+
self._degree = degree
137+
self.reset_model()
138+
def set_smoothing(self, smoothing):
139+
""" Spline smoothing (s in FITPACK) """
140+
raise NotImplementedError
141+
self._smoothing = smoothing
142+
self.reset_model()
143+
def set_knots(self, knots):
144+
""" Spline knots (t in FITPACK) """
145+
raise NotImplementedError
146+
self._knots = self.verify_knots(knots)
147+
self.reset_model()
148+
149+
def set_model_from_tck(self, tck):
150+
"""
151+
Use output of scipy.interpolate.splrep
152+
"""
153+
self._tck = tck
154+
155+
def __call__(self, x, der=0):
156+
"""
157+
Evaluate the model with the given inputs.
158+
der is passed to scipy.interpolate.splev
159+
"""
160+
if self._tck is None:
161+
raise RuntimeError("SplineModel has not been fit yet")
162+
return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode)
163+
164+
####################################
165+
######### Stuff below here is stubs
166+
@property
167+
def param_names(self):
168+
"""
169+
Coefficient names generated based on the model's knots and polynomial degree.
170+
Not Implemented
171+
"""
172+
raise NotImplementedError("SplineModel does not currently expose parameters")
173+
return self._param_names
174+
175+
#def __getattr__(self, attr):
176+
# """
177+
# Fails right now. Future code:
178+
# # From astropy.modeling.polynomial.PolynomialBase
179+
# if self._param_names and attr in self._param_names:
180+
# return Parameter(attr, default=0.0, model=self)
181+
# raise AttributeError(attr)
182+
# """
183+
# raise NotImplementedError("SplineModel does not currently expose parameters")
184+
185+
#def __setattr__(self, attr, value):
186+
# """
187+
# Fails right now. Future code:
188+
# # From astropy.modeling.polynomial.PolynomialBase
189+
# if attr[0] != '_' and self._param_names and attr in self._param_names:
190+
# param = Parameter(attr, default=0.0, model=self)
191+
# param.__set__(self, value)
192+
# else:
193+
# super().__setattr__(attr, value)
194+
# """
195+
# raise NotImplementedError("SplineModel does not currently expose parameters")
196+
197+
def _generate_coeff_names(self):
198+
names = []
199+
degree, Nknots = self._degree, len(self._knots)
200+
for i in range(Nknots):
201+
for j in range(degree+1):
202+
names.append("k{}_c{}".format(i,j))
203+
return tuple(names)
204+
205+
def evaluate(self, *args, **kwargs):
206+
return self(*args, **kwargs)
207+
208+
209+
210+
class SplineFitter(metaclass=_FitterMeta):
211+
"""
212+
Run a spline fit.
213+
"""
214+
def __init__(self):
215+
self.fit_info = {"fp": None,
216+
"ier": None,
217+
"msg": None}
218+
super().__init__()
219+
220+
def validate_model(self, model):
221+
if not isinstance(model, SplineModel):
222+
raise ValueError("model must be of type SplineModel (currently is {})".format(
223+
type(model)))
224+
225+
## TODO do something about units
226+
#@fitter_unit_support
227+
def __call__(self, model, x, y, w=None):
228+
"""
229+
Fit a spline model to data.
230+
Internally uses scipy.interpolate.splrep.
231+
232+
"""
233+
234+
self.validate_model(model)
235+
236+
## Case (1): fit smoothing spline
237+
if model.get_knots() is None:
238+
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
239+
t=None,
240+
k=model.get_degree(),
241+
s=model.get_smoothing(),
242+
task=0, full_output=True
243+
)
244+
## Case (2): leastsq spline
245+
else:
246+
knots = model.get_knots()
247+
## TODO some sort of validation that the knots are internal, since
248+
## this procedure automatically adds knots at the two endpoints
249+
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
250+
t=knots,
251+
k=model.get_degree(),
252+
s=model.get_smoothing(),
253+
task=-1, full_output=True
254+
)
255+
256+
model.set_model_from_tck(tck)
257+
self.fit_info.update({"fp":fp, "ier":ier, "msg":msg})
258+

specutils/tests/test_spline.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import astropy.units as u
2+
import numpy as np
3+
4+
from astropy.modeling import models, fitting
5+
from specutils.fitting.spline import SplineModel, SplineFitter
6+
7+
from scipy import interpolate
8+
9+
def make_data(with_errs=True):
10+
""" Arbitrary data """
11+
np.random.seed(348957)
12+
x = np.linspace(0, 10, 200)
13+
y = (x+1) - (x-5)**2. + 10.*np.exp(-0.5 * ((x-7.)/.5)**2.)
14+
y = (y - np.min(y) + 10.)*10.
15+
if with_errs:
16+
ey = np.sqrt(y)
17+
y = y + np.random.normal(0., ey, y.shape)
18+
w = 1./y
19+
return x, y, w
20+
21+
def test_spline_fit():
22+
x, y, w = make_data()
23+
make_plot=False
24+
25+
# Construct three sets of splines and their scipy equivalents
26+
knots = np.arange(1,10)
27+
models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots), SplineModel(smoothing=0)]
28+
labels = ["Deg 3", "Deg 5", "Knots", "Interpolated"]
29+
scipyfit = [interpolate.UnivariateSpline(x,y,w),
30+
interpolate.UnivariateSpline(x,y,w,k=5),
31+
interpolate.LSQUnivariateSpline(x,y,knots,w=w),
32+
interpolate.InterpolatedUnivariateSpline(x,y,w)]
33+
34+
fitter = SplineFitter()
35+
for model, label, scipymodel in zip(models, labels, scipyfit):
36+
fitter(model, x, y, w)
37+
my_y = model(x)
38+
sci_y = scipymodel(x)
39+
assert np.allclose(my_y, sci_y, atol=1e-6)
40+
41+
if make_plot:
42+
import matplotlib.pyplot as plt
43+
fig, ax = plt.subplots()
44+
ax.plot(x,y,'k.')
45+
ymin, ymax = np.min(y), np.max(y)
46+
for i,(model, label) in enumerate(zip(models, labels)):
47+
l, = ax.plot(x, model(x), lw=1, label=label)
48+
knots = model.get_knots()
49+
# Hack for now
50+
if knots is None: knots = model._tck[0]
51+
print(knots)
52+
dy = (ymax-ymin)/10.
53+
dy /= i+1.
54+
ax.vlines(knots, ymin, ymin + dy, color=l.get_color(), lw=1)
55+
ax.legend()
56+
plt.show()
57+
58+
if __name__=="__main__":
59+
test_spline_fit()
60+

0 commit comments

Comments
 (0)