|
| 1 | +from __future__ import print_function, division, absolute_import |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from scipy import interpolate |
| 5 | + |
| 6 | +from astropy.modeling.core import FittableModel, Model |
| 7 | +from astropy.modeling.functional_models import Shift |
| 8 | +from astropy.modeling.parameters import Parameter |
| 9 | +from astropy.modeling.utils import poly_map_domain, comb |
| 10 | +from astropy.modeling.fitting import _FitterMeta, fitter_unit_support |
| 11 | +from astropy.utils import indent, check_broadcast |
| 12 | +from astropy.units import Quantity |
| 13 | + |
| 14 | + |
| 15 | +__all__ = [] |
| 16 | + |
| 17 | +class SplineModel(FittableModel): |
| 18 | + """ |
| 19 | + Wrapper around scipy.interpolate.splrep and splev |
| 20 | + |
| 21 | + Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified, |
| 22 | + and scipy.interpolate.LSQUnivariateSpline if knots are specified |
| 23 | + |
| 24 | + There are two ways to make a spline model. |
| 25 | + (1) you have the spline auto-determine knots from the data |
| 26 | + (2) you specify the knots |
| 27 | + |
| 28 | + """ |
| 29 | + |
| 30 | + linear = False # I think? I have no idea? |
| 31 | + col_fit_deriv = False # Not sure what this is |
| 32 | + |
| 33 | + def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0): |
| 34 | + """ |
| 35 | + Set up a spline model. |
| 36 | + |
| 37 | + degree: degree of the spline (default 3) |
| 38 | + In scipy fitpack, this is "k" |
| 39 | + |
| 40 | + smoothing (optional): smoothing value for automatically determining knots |
| 41 | + In scipy fitpack, this is "s" |
| 42 | + By default, uses a |
| 43 | + |
| 44 | + knots (optional): spline knots (boundaries of piecewise polynomial) |
| 45 | + If not specified, will automatically determine knots based on |
| 46 | + degree + smoothing |
| 47 | + |
| 48 | + extrapolate_mode (optional): how to deal with solution outside of interval. |
| 49 | + (see scipy.interpolate.splev) |
| 50 | + if 0 (default): return the extrapolated value |
| 51 | + if 1, return 0 |
| 52 | + if 2, raise a ValueError |
| 53 | + if 3, return the boundary value |
| 54 | + """ |
| 55 | + self._degree = degree |
| 56 | + self._smoothing = smoothing |
| 57 | + self._knots = self.verify_knots(knots) |
| 58 | + self.extrapolate_mode = extrapolate_mode |
| 59 | + |
| 60 | + ## This is used to evaluate the spline |
| 61 | + ## When None, raises an error when trying to evaluate the spline |
| 62 | + self._tck = None |
| 63 | + |
| 64 | + self._param_names = () |
| 65 | + |
| 66 | + def verify_knots(self, knots): |
| 67 | + """ |
| 68 | + Basic knot array vetting. |
| 69 | + The goal of having this is to enable more useful error messages |
| 70 | + than scipy (if needed). |
| 71 | + """ |
| 72 | + if knots is None: return None |
| 73 | + knots = np.array(knots) |
| 74 | + assert len(knots.shape) == 1, knots.shape |
| 75 | + knots = np.sort(knots) |
| 76 | + assert len(np.unique(knots)) == len(knots), knots |
| 77 | + return knots |
| 78 | + |
| 79 | + ############ |
| 80 | + ## Getters |
| 81 | + ############ |
| 82 | + def get_degree(self): |
| 83 | + """ Spline degree (k in FITPACK) """ |
| 84 | + return self._degree |
| 85 | + def get_smoothing(self): |
| 86 | + """ Spline smoothing (s in FITPACK) """ |
| 87 | + return self._smoothing |
| 88 | + def get_knots(self): |
| 89 | + """ Spline knots (t in FITPACK) """ |
| 90 | + return self._knots |
| 91 | + def get_coeffs(self): |
| 92 | + """ Spline coefficients (c in FITPACK) """ |
| 93 | + if self._tck is not None: |
| 94 | + return self._tck[1] |
| 95 | + else: |
| 96 | + raise RuntimeError("SplineModel has not been fit yet") |
| 97 | + |
| 98 | + ############ |
| 99 | + ## Spline methods: not tested at all |
| 100 | + ############ |
| 101 | + def derivative(self, n=1): |
| 102 | + if self._tck is None: |
| 103 | + raise RuntimeError("SplineModel has not been fit yet") |
| 104 | + else: |
| 105 | + t, c, k = self._tck |
| 106 | + return scipy.interpolate.BSpline.construct_fast( |
| 107 | + t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n) |
| 108 | + def antiderivative(self, n=1): |
| 109 | + if self._tck is None: |
| 110 | + raise RuntimeError("SplineModel has not been fit yet") |
| 111 | + else: |
| 112 | + t, c, k = self._tck |
| 113 | + return scipy.interpolate.BSpline.construct_fast( |
| 114 | + t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n) |
| 115 | + def integral(self, a, b): |
| 116 | + if self._tck is None: |
| 117 | + raise RuntimeError("SplineModel has not been fit yet") |
| 118 | + else: |
| 119 | + t, c, k = self._tck |
| 120 | + return scipy.interpolate.BSpline.construct_fast( |
| 121 | + t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b) |
| 122 | + def derivatives(self, x): |
| 123 | + raise NotImplementedError |
| 124 | + def roots(self): |
| 125 | + raise NotImplementedError |
| 126 | + |
| 127 | + ############ |
| 128 | + ## Setters: not really implemented or tested |
| 129 | + ############ |
| 130 | + def reset_model(self): |
| 131 | + """ Resets model so it needs to be refit to be valid """ |
| 132 | + self._tck = None |
| 133 | + def set_degree(self, degree): |
| 134 | + """ Spline degree (k in FITPACK) """ |
| 135 | + raise NotImplementedError |
| 136 | + self._degree = degree |
| 137 | + self.reset_model() |
| 138 | + def set_smoothing(self, smoothing): |
| 139 | + """ Spline smoothing (s in FITPACK) """ |
| 140 | + raise NotImplementedError |
| 141 | + self._smoothing = smoothing |
| 142 | + self.reset_model() |
| 143 | + def set_knots(self, knots): |
| 144 | + """ Spline knots (t in FITPACK) """ |
| 145 | + raise NotImplementedError |
| 146 | + self._knots = self.verify_knots(knots) |
| 147 | + self.reset_model() |
| 148 | + |
| 149 | + def set_model_from_tck(self, tck): |
| 150 | + """ |
| 151 | + Use output of scipy.interpolate.splrep |
| 152 | + """ |
| 153 | + self._tck = tck |
| 154 | + |
| 155 | + def __call__(self, x, der=0): |
| 156 | + """ |
| 157 | + Evaluate the model with the given inputs. |
| 158 | + der is passed to scipy.interpolate.splev |
| 159 | + """ |
| 160 | + if self._tck is None: |
| 161 | + raise RuntimeError("SplineModel has not been fit yet") |
| 162 | + return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode) |
| 163 | + |
| 164 | + #################################### |
| 165 | + ######### Stuff below here is stubs |
| 166 | + @property |
| 167 | + def param_names(self): |
| 168 | + """ |
| 169 | + Coefficient names generated based on the model's knots and polynomial degree. |
| 170 | + Not Implemented |
| 171 | + """ |
| 172 | + raise NotImplementedError("SplineModel does not currently expose parameters") |
| 173 | + return self._param_names |
| 174 | + |
| 175 | + #def __getattr__(self, attr): |
| 176 | + # """ |
| 177 | + # Fails right now. Future code: |
| 178 | + # # From astropy.modeling.polynomial.PolynomialBase |
| 179 | + # if self._param_names and attr in self._param_names: |
| 180 | + # return Parameter(attr, default=0.0, model=self) |
| 181 | + # raise AttributeError(attr) |
| 182 | + # """ |
| 183 | + # raise NotImplementedError("SplineModel does not currently expose parameters") |
| 184 | + |
| 185 | + #def __setattr__(self, attr, value): |
| 186 | + # """ |
| 187 | + # Fails right now. Future code: |
| 188 | + # # From astropy.modeling.polynomial.PolynomialBase |
| 189 | + # if attr[0] != '_' and self._param_names and attr in self._param_names: |
| 190 | + # param = Parameter(attr, default=0.0, model=self) |
| 191 | + # param.__set__(self, value) |
| 192 | + # else: |
| 193 | + # super().__setattr__(attr, value) |
| 194 | + # """ |
| 195 | + # raise NotImplementedError("SplineModel does not currently expose parameters") |
| 196 | + |
| 197 | + def _generate_coeff_names(self): |
| 198 | + names = [] |
| 199 | + degree, Nknots = self._degree, len(self._knots) |
| 200 | + for i in range(Nknots): |
| 201 | + for j in range(degree+1): |
| 202 | + names.append("k{}_c{}".format(i,j)) |
| 203 | + return tuple(names) |
| 204 | + |
| 205 | + def evaluate(self, *args, **kwargs): |
| 206 | + return self(*args, **kwargs) |
| 207 | + |
| 208 | + |
| 209 | + |
| 210 | +class SplineFitter(metaclass=_FitterMeta): |
| 211 | + """ |
| 212 | + Run a spline fit. |
| 213 | + """ |
| 214 | + def __init__(self): |
| 215 | + self.fit_info = {"fp": None, |
| 216 | + "ier": None, |
| 217 | + "msg": None} |
| 218 | + super().__init__() |
| 219 | + |
| 220 | + def validate_model(self, model): |
| 221 | + if not isinstance(model, SplineModel): |
| 222 | + raise ValueError("model must be of type SplineModel (currently is {})".format( |
| 223 | + type(model))) |
| 224 | + |
| 225 | + ## TODO do something about units |
| 226 | + #@fitter_unit_support |
| 227 | + def __call__(self, model, x, y, w=None): |
| 228 | + """ |
| 229 | + Fit a spline model to data. |
| 230 | + Internally uses scipy.interpolate.splrep. |
| 231 | + |
| 232 | + """ |
| 233 | + |
| 234 | + self.validate_model(model) |
| 235 | + |
| 236 | + ## Case (1): fit smoothing spline |
| 237 | + if model.get_knots() is None: |
| 238 | + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, |
| 239 | + t=None, |
| 240 | + k=model.get_degree(), |
| 241 | + s=model.get_smoothing(), |
| 242 | + task=0, full_output=True |
| 243 | + ) |
| 244 | + ## Case (2): leastsq spline |
| 245 | + else: |
| 246 | + knots = model.get_knots() |
| 247 | + ## TODO some sort of validation that the knots are internal, since |
| 248 | + ## this procedure automatically adds knots at the two endpoints |
| 249 | + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, |
| 250 | + t=knots, |
| 251 | + k=model.get_degree(), |
| 252 | + s=model.get_smoothing(), |
| 253 | + task=-1, full_output=True |
| 254 | + ) |
| 255 | + |
| 256 | + model.set_model_from_tck(tck) |
| 257 | + self.fit_info.update({"fp":fp, "ier":ier, "msg":msg}) |
| 258 | + |
0 commit comments