11
11
from astropy .utils import indent , check_broadcast
12
12
from astropy .units import Quantity
13
13
14
-
15
14
__all__ = []
16
15
16
+
17
17
class SplineModel (FittableModel ):
18
18
"""
19
19
Wrapper around scipy.interpolate.splrep and splev
20
-
20
+
21
21
Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,
22
22
and scipy.interpolate.LSQUnivariateSpline if knots are specified
23
-
23
+
24
24
There are two ways to make a spline model.
25
- (1) you have the spline auto-determine knots from the data
26
- (2) you specify the knots
27
-
25
+ 1. you have the spline auto-determine knots from the data
26
+ 2. you specify the knots
28
27
"""
29
-
30
- linear = False # I think? I have no idea?
31
- col_fit_deriv = False # Not sure what this is
32
-
33
- def __init__ (self , degree = 3 , smoothing = None , knots = None , extrapolate_mode = 0 ):
28
+
29
+ linear = False # I think? I have no idea?
30
+ col_fit_deriv = False # Not sure what this is
31
+
32
+ def __init__ (self , degree = 3 , smoothing = None , knots = None ,
33
+ extrapolate_mode = 0 , * args , ** kwargs ):
34
34
"""
35
35
Set up a spline model.
36
-
36
+
37
37
degree: degree of the spline (default 3)
38
38
In scipy fitpack, this is "k"
39
-
39
+
40
40
smoothing (optional): smoothing value for automatically determining knots
41
41
In scipy fitpack, this is "s"
42
- By default, uses a
43
-
42
+ By default, uses a
43
+
44
44
knots (optional): spline knots (boundaries of piecewise polynomial)
45
45
If not specified, will automatically determine knots based on
46
46
degree + smoothing
47
-
47
+
48
48
extrapolate_mode (optional): how to deal with solution outside of interval.
49
49
(see scipy.interpolate.splev)
50
50
if 0 (default): return the extrapolated value
51
51
if 1, return 0
52
52
if 2, raise a ValueError
53
53
if 3, return the boundary value
54
54
"""
55
+ super ().__init__ (* args , ** kwargs )
56
+
55
57
self ._degree = degree
56
58
self ._smoothing = smoothing
57
59
self ._knots = self .verify_knots (knots )
58
60
self .extrapolate_mode = extrapolate_mode
59
-
60
- ## This is used to evaluate the spline
61
- ## When None, raises an error when trying to evaluate the spline
61
+
62
+ # This is used to evaluate the spline. When None, raises an error when
63
+ # trying to evaluate the spline.
62
64
self ._tck = None
63
-
65
+
64
66
self ._param_names = ()
65
-
67
+
66
68
def verify_knots (self , knots ):
67
69
"""
68
- Basic knot array vetting.
69
- The goal of having this is to enable more useful error messages
70
- than scipy (if needed).
70
+ Basic knot array vetting. The goal of having this is to enable more
71
+ useful error messages than scipy (if needed).
71
72
"""
72
- if knots is None : return None
73
+ if knots is None :
74
+ return None
75
+
73
76
knots = np .array (knots )
74
77
assert len (knots .shape ) == 1 , knots .shape
75
78
knots = np .sort (knots )
76
79
assert len (np .unique (knots )) == len (knots ), knots
80
+
77
81
return knots
78
-
79
- ############
80
- ## Getters
81
- ############
82
- def get_degree (self ):
82
+
83
+ # Getters
84
+ @property
85
+ def degree (self ):
83
86
""" Spline degree (k in FITPACK) """
84
87
return self ._degree
85
- def get_smoothing (self ):
88
+
89
+ @property
90
+ def smoothing (self ):
86
91
""" Spline smoothing (s in FITPACK) """
87
92
return self ._smoothing
88
- def get_knots (self ):
93
+
94
+ @property
95
+ def knots (self ):
89
96
""" Spline knots (t in FITPACK) """
90
97
return self ._knots
91
- def get_coeffs (self ):
98
+
99
+ @property
100
+ def coeffs (self ):
92
101
""" Spline coefficients (c in FITPACK) """
93
102
if self ._tck is not None :
94
103
return self ._tck [1 ]
95
104
else :
96
- raise RuntimeError ("SplineModel has not been fit yet" )
97
-
98
- ############
99
- ## Spline methods: not tested at all
100
- ############
101
- def derivative (self , n = 1 ):
102
- if self ._tck is None :
103
- raise RuntimeError ("SplineModel has not been fit yet" )
104
- else :
105
- t , c , k = self ._tck
106
- return scipy .interpolate .BSpline .construct_fast (
107
- t ,c ,k ,extrapolate = (self .extrapolate_mode == 0 )).derivative (n )
108
- def antiderivative (self , n = 1 ):
109
- if self ._tck is None :
110
- raise RuntimeError ("SplineModel has not been fit yet" )
111
- else :
112
- t , c , k = self ._tck
113
- return scipy .interpolate .BSpline .construct_fast (
114
- t ,c ,k ,extrapolate = (self .extrapolate_mode == 0 )).antiderivative (n )
115
- def integral (self , a , b ):
116
- if self ._tck is None :
117
- raise RuntimeError ("SplineModel has not been fit yet" )
118
- else :
119
- t , c , k = self ._tck
120
- return scipy .interpolate .BSpline .construct_fast (
121
- t ,c ,k ,extrapolate = (self .extrapolate_mode == 0 )).integral (a ,b )
122
- def derivatives (self , x ):
123
- raise NotImplementedError
124
- def roots (self ):
125
- raise NotImplementedError
126
-
127
- ############
128
- ## Setters: not really implemented or tested
129
- ############
105
+ raise RuntimeError ("SplineModel has not been fit yet." )
106
+
107
+ # Setters
108
+ # TODO: not really implemented or tested
130
109
def reset_model (self ):
131
110
""" Resets model so it needs to be refit to be valid """
132
111
self ._tck = None
133
- def set_degree (self , degree ):
112
+
113
+ @degree .setter
114
+ def degree (self , degree ):
134
115
""" Spline degree (k in FITPACK) """
135
116
raise NotImplementedError
136
117
self ._degree = degree
137
118
self .reset_model ()
138
- def set_smoothing (self , smoothing ):
119
+
120
+ @smoothing .setter
121
+ def smoothing (self , smoothing ):
139
122
""" Spline smoothing (s in FITPACK) """
140
123
raise NotImplementedError
141
124
self ._smoothing = smoothing
142
125
self .reset_model ()
143
- def set_knots (self , knots ):
126
+
127
+ @knots .setter
128
+ def knots (self , knots ):
144
129
""" Spline knots (t in FITPACK) """
145
130
raise NotImplementedError
146
131
self ._knots = self .verify_knots (knots )
147
132
self .reset_model ()
148
-
133
+
149
134
def set_model_from_tck (self , tck ):
150
135
"""
151
136
Use output of scipy.interpolate.splrep
152
137
"""
153
138
self ._tck = tck
154
139
140
+ # Spline methods
141
+ # TODO: not tested at all
142
+ def derivative (self , n = 1 ):
143
+ if self ._tck is None :
144
+ raise RuntimeError ("SplineModel has not been fit yet" )
145
+ else :
146
+ t , c , k = self ._tck
147
+ return scipy .interpolate .BSpline .construct_fast (
148
+ t , c , k , extrapolate = (self .extrapolate_mode == 0 )).derivative (n )
149
+
150
+ def antiderivative (self , n = 1 ):
151
+ if self ._tck is None :
152
+ raise RuntimeError ("SplineModel has not been fit yet." )
153
+ else :
154
+ t , c , k = self ._tck
155
+ return scipy .interpolate .BSpline .construct_fast (
156
+ t , c , k , extrapolate = (self .extrapolate_mode == 0 )).antiderivative (n )
157
+
158
+ def integral (self , a , b ):
159
+ if self ._tck is None :
160
+ raise RuntimeError ("SplineModel has not been fit yet." )
161
+ else :
162
+ t , c , k = self ._tck
163
+ return scipy .interpolate .BSpline .construct_fast (
164
+ t , c , k , extrapolate = (self .extrapolate_mode == 0 )).integral (a , b )
165
+
166
+ def derivatives (self , x ):
167
+ raise NotImplementedError
168
+
169
+ def roots (self ):
170
+ raise NotImplementedError
171
+
155
172
def __call__ (self , x , der = 0 ):
156
173
"""
157
174
Evaluate the model with the given inputs.
158
175
der is passed to scipy.interpolate.splev
159
176
"""
160
177
if self ._tck is None :
161
- raise RuntimeError ("SplineModel has not been fit yet" )
178
+ raise RuntimeError ("SplineModel has not been fit yet." )
179
+
162
180
return interpolate .splev (x , self ._tck , der = der , ext = self .extrapolate_mode )
163
-
164
- ####################################
165
- ######### Stuff below here is stubs
181
+
182
+ # Stuff below here is stubs
183
+ # TODO: fill out methods
166
184
@property
167
185
def param_names (self ):
168
186
"""
@@ -201,11 +219,10 @@ def _generate_coeff_names(self):
201
219
for j in range (degree + 1 ):
202
220
names .append ("k{}_c{}" .format (i ,j ))
203
221
return tuple (names )
204
-
222
+
205
223
def evaluate (self , * args , ** kwargs ):
206
224
return self (* args , ** kwargs )
207
225
208
-
209
226
210
227
class SplineFitter (metaclass = _FitterMeta ):
211
228
"""
@@ -216,43 +233,42 @@ def __init__(self):
216
233
"ier" : None ,
217
234
"msg" : None }
218
235
super ().__init__ ()
219
-
236
+
220
237
def validate_model (self , model ):
221
238
if not isinstance (model , SplineModel ):
222
239
raise ValueError ("model must be of type SplineModel (currently is {})" .format (
223
240
type (model )))
224
-
225
- ## TODO do something about units
226
- #@fitter_unit_support
241
+
242
+ # TODO do something about units
243
+ # @fitter_unit_support
227
244
def __call__ (self , model , x , y , w = None ):
228
245
"""
229
246
Fit a spline model to data.
230
247
Internally uses scipy.interpolate.splrep.
231
-
248
+
232
249
"""
233
-
250
+
234
251
self .validate_model (model )
235
-
236
- ## Case (1): fit smoothing spline
252
+
253
+ # Case (1): fit smoothing spline
237
254
if model .get_knots () is None :
238
255
tck , fp , ier , msg = interpolate .splrep (x , y , w = w ,
239
256
t = None ,
240
- k = model .get_degree (),
257
+ k = model .get_degree (),
241
258
s = model .get_smoothing (),
242
259
task = 0 , full_output = True
243
260
)
244
- ## Case (2): leastsq spline
261
+ # Case (2): leastsq spline
245
262
else :
246
263
knots = model .get_knots ()
247
264
## TODO some sort of validation that the knots are internal, since
248
265
## this procedure automatically adds knots at the two endpoints
249
266
tck , fp , ier , msg = interpolate .splrep (x , y , w = w ,
250
267
t = knots ,
251
- k = model .get_degree (),
268
+ k = model .get_degree (),
252
269
s = model .get_smoothing (),
253
270
task = - 1 , full_output = True
254
271
)
255
-
272
+
256
273
model .set_model_from_tck (tck )
257
- self .fit_info .update ({"fp" :fp , "ier" :ier , "msg" :msg })
258
-
274
+ self .fit_info .update ({"fp" : fp , "ier" : ier , "msg" : msg })
0 commit comments