diff --git a/specutils/fitting/spline.py b/specutils/fitting/spline.py new file mode 100644 index 000000000..1072f5fa3 --- /dev/null +++ b/specutils/fitting/spline.py @@ -0,0 +1,278 @@ +from __future__ import print_function, division, absolute_import + +import numpy as np +from scipy import interpolate +import warnings + +from astropy.modeling.core import FittableModel, Fittable1DModel, Model +from astropy.modeling.functional_models import Shift +from astropy.modeling.parameters import Parameter +from astropy.modeling.utils import poly_map_domain, comb +from astropy.modeling.fitting import _FitterMeta, fitter_unit_support +from astropy.utils import indent, check_broadcast +from astropy.units import Quantity + +__all__ = [] + + +class SplineModel(Fittable1DModel): + """ + Wrapper around scipy.interpolate.splrep and splev + + Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified, + and scipy.interpolate.LSQUnivariateSpline if knots are specified + + There are two ways to make a spline model. + 1. you have the spline auto-determine knots from the data + 2. you specify the knots + """ + + def __init__(self, degree=3, smoothing=None, knots=None, + extrapolate_mode=0, *args, **kwargs): + """ + Set up a spline model. + + degree: degree of the spline (default 3) + In scipy fitpack, this is "k" + + smoothing (optional): smoothing value for automatically determining knots + In scipy fitpack, this is "s" + By default, uses s = len(w) (see scipy.interpolate.UnivariateSpline) + + knots (optional): spline knots (boundaries of piecewise polynomial) + If not specified, will automatically determine knots based on + degree + smoothing. The fit is identical to scipy.interpolate.UnivariateSpline. + If specified, analogous to scipy.interpolate.LSQUnivariateSpline. + + extrapolate_mode (optional): how to deal with solution outside of interval. + (see scipy.interpolate.splev) + if 0 (default): return the extrapolated value + if 1, return 0 + if 2, raise a ValueError + if 3, return the boundary value + """ + + self._param_names = () + + self._degree = degree + self._smoothing = smoothing + self._knots = self.verify_knots(knots) + self.extrapolate_mode = extrapolate_mode + + # This is used to evaluate the spline. When None, raises an error when + # trying to evaluate the spline. + self._tck = None + + super().__init__(*args, **kwargs) + + def verify_knots(self, knots): + """ + Basic knot array vetting. The goal of having this is to enable more + useful error messages than scipy (if needed). + """ + if knots is None: + return None + + knots = np.array(knots) + assert len(knots.shape) == 1, knots.shape + knots = np.sort(knots) + assert len(np.unique(knots)) == len(knots), knots + + return knots + + # Getters + @property + def degree(self): + """ Spline degree (k in FITPACK) """ + return self._degree + + @property + def smoothing(self): + """ Spline smoothing (s in FITPACK) """ + return self._smoothing + + @property + def knots(self): + """ Spline knots (t in FITPACK) """ + return self._knots + + @property + def coeffs(self): + """ Spline coefficients (c in FITPACK) """ + if self._tck is not None: + return self._tck[1] + else: + raise RuntimeError("SplineModel has not been fit yet.") + + # Setters + def reset_model(self): + """ Resets model so it needs to be refit to be valid """ + self._tck = None + self._param_names = () + + @degree.setter + def degree(self, degree): + """ Spline degree (k in FITPACK) """ + self._degree = degree + self.reset_model() + + @smoothing.setter + def smoothing(self, smoothing): + """ Spline smoothing (s in FITPACK) """ + self._smoothing = smoothing + self.reset_model() + + @knots.setter + def knots(self, knots): + """ Spline knots (t in FITPACK) """ + self._knots = self.verify_knots(knots) + self.reset_model() + + def set_model_from_tck(self, tck): + """ + Main way to update model + Use output of scipy.interpolate.splrep + """ + t, c, k = tck + self.degree = k + self.knots = t[k:-k] + self._tck = tck + self._param_names = self._generate_coeff_names() + + # Spline methods + def derivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + else: + ext = 1 if self.extrapolate_mode == 3 else self.extrapolate_mode + new_tck = interpolate.fitpack.splder(self._tck, n) + newmodel = SplineModel(degree=self.degree, smoothing=self.smoothing, + knots=self.knots, extrapolate_mode=ext) + newmodel.set_model_from_tck(new_tck) + return newmodel + + def antiderivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + else: + new_tck = interpolate.fitpack.splantider(self._tck, n) + newmodel = SplineModel(degree=self.degree, smoothing=self.smoothing, + knots=self.knots, extrapolate_mode=self.extrapolate_mode) + newmodel.set_model_from_tck(new_tck) + return newmodel + + def integral(self, a, b): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + else: + t, c, k = self._tck + return interpolate.dfitpack.splint(t, c, k, a, b) + + def derivatives(self, x): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + else: + t, c, k = self._tck + d, ier = interpolate.dfitpack.spalde(t, c, k, x) + if not ier == 0: + raise ValueError("Error code returned by spalde: %s" % ier) + return d + + def roots(self): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + t, c, k = self._tck + if k == 3: + z, m, ier = interpolate.dfitpack.sproot(t, c) + if not ier == 0: + raise ValueError("Error code returned by spalde: %s" % ier) + return z[:m] + raise NotImplementedError('finding roots unsupported for ' + 'non-cubic splines') + + def __call__(self, x, der=0): + """ + Evaluate the model with the given inputs. + der is passed to scipy.interpolate.splev + """ + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + + return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode) + + # Stuff below here is stubs + # TODO: fill out methods + @property + def param_names(self): + """ + Coefficient names generated based on the model's knots and polynomial degree. + Not Implemented + """ + #raise NotImplementedError("SplineModel does not currently expose parameters") + warnings.warn("SplineModel does not currently expose parameters\n" + "Will only work with SplineFitter") + try: + return self._param_names + except AttributeError: + return () + + def _generate_coeff_names(self): + names = [] + degree, Nknots = self._degree, len(self._knots) + for i in range(Nknots): + for j in range(degree+1): + names.append("k{}_c{}".format(i,j)) + return tuple(names) + + def evaluate(self, *args, **kwargs): + return self(*args, **kwargs) + + +class SplineFitter(metaclass=_FitterMeta): + """ + Run a spline fit. + """ + def __init__(self): + self.fit_info = {"fp": None, + "ier": None, + "msg": None} + super().__init__() + + def validate_model(self, model): + if not isinstance(model, SplineModel): + raise ValueError("model must be of type SplineModel (currently is {})".format( + type(model))) + + # TODO do something about units + # @fitter_unit_support + def __call__(self, model, x, y, w=None): + """ + Fit a spline model to data. + Internally uses scipy.interpolate.splrep. + + """ + + self.validate_model(model) + + # Case (1): fit smoothing spline + if model.knots is None: + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, + t=None, + k=model.degree, + s=model.smoothing, + task=0, full_output=True + ) + # Case (2): leastsq spline + else: + knots = model.knots + ## TODO some sort of validation that the knots are internal, since + ## this procedure automatically adds knots at the two endpoints + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, + t=knots, + k=model.degree, + s=model.smoothing, + task=-1, full_output=True + ) + + model.set_model_from_tck(tck) + self.fit_info.update({"fp": fp, "ier": ier, "msg": msg}) diff --git a/specutils/fitting/spline_continuum.py b/specutils/fitting/spline_continuum.py new file mode 100644 index 000000000..617946c8c --- /dev/null +++ b/specutils/fitting/spline_continuum.py @@ -0,0 +1,485 @@ +from astropy import modeling +from astropy.modeling import models, fitting +from astropy.nddata import StdDevUncertainty +from specutils.spectra import Spectrum1D + +import numpy as np +from scipy import interpolate + +import logging +import warnings + +__all__ = ['fit_continuum_generic', 'fit_continuum_linetools'] + + +def fit_continuum_generic(spectrum, + model=None, fitter=None, + sigma=3.0, sigma_lower=None, + sigma_upper=None, iters=5, + exclude_regions=None, + full_output=False): + """ + Fit a generic continuum model to a spectrum. + + The default algorithm is iterative sigma clipping + + Parameters + ---------- + spectrum : `~specutils.Spectrum1D` + The `~specutils.Spectrum1D` object to which a continuum model is fit + model : `astropy.modeling.FittableModel` + The type of model to use for the continuum. + Must be astropy.modeling.FittableModel + See astropy.modeling.models + Default: models.Chebyshev1D(3) + TODO add a spline option (since this is not currently implemented) + fitter : `astropy.modeling.fitting.Fitter` + The type of fitter to use for the continuum. + See astropy.modeling.fitting for valid choices + TODO currently does not typecheck because fitters do not subclass fitting.Fitter + Default: fitting.LevMarLSQFitter() + sigma : float, optional + The number of standard deviations to use for both lower and upper clipping limit. + Defaults to 3.0 + sigma_lower : float or None, optional + Number of standard deviations for lower bound clipping limit. + If None (default), then `sigma` is used. + + sigma_upper : float or None, optional + Number of standard deviations for upper bound clipping limit. + If None (default), then `sigma` is used. + iters : int or None, optional + Number of iterations to perform sigma clipping. + If None, clips until convergence achieved. + Defaults to 5 + exclude_regions : list of tuples, optional + A list of dispersion regions to exclude. + Each tuple must be sorted. + e.g. [(6555,6575)] + full_output : bool, optional + If True, return more information. + Currently, just the model and the pixels-used boolean array + + Returns + ------- + continuum_model : `astropy.modeling.FittableModel` + Output a model for the continuum + + Raises + ------ + ValueError + If: spectrum is not the correct type, + the exclude regions do not satisfy a list of sorted tuples, + the model and/or fitter are of the wrong type, + + Examples + -------- + TODO: add more and unit tests + + See https://github.com/spacetelescope/dat_pyinthesky/blob/master/pyinthesky_specutils_fitting.ipynb + + """ + + # Parameter checks + if not isinstance(spectrum, Spectrum1D): + raise ValueError('The spectrum parameter must be a Spectrum1D object') + + exclude_regions = [] if exclude_regions is None else exclude_regions + + for exclude_region in exclude_regions: + if len(exclude_region) != 2: + raise ValueError('All exclusion regions must be of length 2') + if exclude_region[0] >= exclude_region[1]: + raise ValueError('All exclusion regions must be (low, high)') + + # Set default model and fitter + if model is None: + logging.info("Using Chebyshev1D(3) as default continuum model") + model = models.Chebyshev1D(3) + + if fitter is None: + fitter = fitting.LevMarLSQFitter() + + if not isinstance(model, modeling.FittableModel): + raise ValueError('The model parameter must be a astropy.modeling.FittableModel object') + # TODO: this is waiting on a refactor in modeling.fitting to work + + # if not isinstance(fitter, fitting.Fitter): + # raise ValueError("The model parameter must be an " + # "astropy.modeling.fitting.Fitter object.") + + # Get input spectrum data + x = spectrum.spectral_axis.value + y = spectrum.flux.value + + # Set up valid pixels mask. Exclude non-finite values. + good = np.isfinite(y) + + # Exclude regions + for (excl1, excl2) in exclude_regions: + good[np.logical_and(x > excl1, x < excl2)] = False + + # Set up sigma clipping + if sigma_lower is None: + sigma_lower = sigma + + if sigma_upper is None: + sigma_upper = sigma + + # Set the model as the default continuum in cases where the sigma + # clipping iterations == 0 + continuum_model = model + + for i_iter in range(iters): + logging.info("Iter {}: Fitting {}/{} pixels".format( + i_iter, good.sum(), len(good))) + + # Fit model + # TODO: include data uncertainties + continuum_model = fitter(model, x[good], y[good]) + + # Sigma clip + difference = continuum_model(x) - y + finite = np.isfinite(difference) + sigma_difference = difference / np.std(difference[np.logical_and(good, finite)]) + good[sigma_difference > sigma_upper] = False + good[sigma_difference < -sigma_lower] = False + + if full_output: + return continuum_model, good + + return continuum_model + + +def fit_continuum_linetools(spec, edges=None, ax=None, debug=False, kind="QSO", **kwargs): + """ + A direct port of the linetools continuum normalization algorithm by + X Prochaska (https://github.com/linetools/linetools/blob/master/linetools/analysis/continuum.py) + + The only changes are switching to Scipy's Akima1D interpolator and + changing the relevant syntax. + """ + assert kind in ["QSO"], kind + + if not isinstance(spec, Spectrum1D): + raise ValueError('The spectrum parameter must be a Spectrum1D object') + + # To start, we define all the functions here to avoid namespace bloat, + # but this can be fixed later. The goal is to have the same algorithm but + # with flexible wavelength chunks for other object types + + def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, + debug=False): + """ + Generate a series of wavelength chunks for use by + prepare_knots, assuming a QSO spectrum. + """ + cond = np.isnan(wa) + + if np.any(cond): + warnings.warn('Some wavelengths are NaN, ignoring these pixels.') + wa = wa[~cond] + + assert len(wa) > 0 + + zp1 = 1 + redshift + div = np.rec.fromrecords([(200. , 500. , 25), + (500. , 800. , 25), + (800. , 1190., 25), + (1190., 1213., 4), + (1213., 1230., 6), + (1230., 1263., 6), + (1263., 1290., 5), + (1290., 1340., 5), + (1340., 1370., 2), + (1370., 1410., 5), + (1410., 1515., 5), + (1515., 1600., 15), + (1600., 1800., 8), + (1800., 1900., 5), + (1900., 1940., 5), + (1940., 2240., 15), + (2240., 3000., 25), + (3000., 6000., 80), + (6000., 20000., 100), + ], names=str('left,right,num')) + + div.num[2:] = np.ceil(div.num[2:] * divmult) + div.num[:2] = np.ceil(div.num[:2] * forest_divmult) + div.left *= zp1 + div.right *= zp1 + + if debug: + logging.info(div.tolist()) + + temp = [np.linspace(left, right, n+1)[:-1] for left, right, n in div] + edges = np.concatenate(temp) + + i0, i1, i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]]) + + if debug: + logging.info(i0, i1, i2) + + return edges[i0:i2] + + def update_knots(knots, indices, fl, masked): + """ + Calculate the y position of each knot. + + Updates `knots` inplace. + + Parameters + ---------- + knots: list of [xpos, ypos, bool] with length N + bool says whether the knot should kept unchanged. + indices: list of (i0,i1) index pairs + The start and end indices into fl and masked of each + spectrum chunk (xpos of each knot are the chunk centres). + fl, masked: arrays shape (M,) + The flux, and boolean arrays showing which pixels are + masked. + """ + iy, iflag = 1, 2 + + for iknot, (i1, i2) in enumerate(indices): + if knots[iknot][iflag]: + continue + + f0 = fl[i1:i2] + m0 = masked[i1:i2] + f1 = f0[~m0] + knots[iknot][iy] = np.median(f1) + + def linear_co(wa, knots): + """ + linear interpolation through the spline knots. + + Add extra points on either end to give + a nice slope at the end points. + """ + wavc, mfl = list(zip(*knots))[:2] + extwavc = ([wavc[0] - (wavc[1] - wavc[0])] + list(wavc) + + [wavc[-1] + (wavc[-1] - wavc[-2])]) + extmfl = ([mfl[0] - (mfl[1] - mfl[0])] + list(mfl) + + [mfl[-1] + (mfl[-1] - mfl[-2])]) + co = np.interp(wa, extwavc, extmfl) + + return co + + def Akima_co(wa, knots): + """Akima interpolation through the spline knots.""" + x, y, _ = zip(*knots) + spl = interpolate.Akima1DInterpolator(x, y) + + return spl(wa) + + def remove_bad_knots(knots, indices, masked, fl, er, debug=False): + """ + Remove knots in chunks without any good pixels. Modifies + inplace. + """ + idelknot = [] + + for iknot, (i, j) in enumerate(indices): + if np.all(masked[i:j]) or np.median(fl[i:j]) <= 2*np.median(er[i:j]): + if debug: + print('Deleting knot', iknot, 'near {:.1f} Angstroms'.format( + knots[iknot][0])) + idelknot.append(iknot) + + for i in reversed(idelknot): + del knots[i] + del indices[i] + + def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5): + """ + Calc chisq per chunk, update knots flags inplace if chisq is + acceptable. + """ + chisq = [] + FLAG = 2 + + for iknot, (i1, i2) in enumerate(indices): + if knots[iknot][FLAG]: + continue + + f0 = fl[i1:i2] + e0 = er[i1:i2] + m0 = masked[i1:i2] + f1 = f0[~m0] + e1 = e0[~m0] + mod0 = model[i1:i2] + mod1 = mod0[~m0] + resid = (mod1 - f1) / e1 + chisq = np.sum(resid*resid) + rchisq = chisq / len(f1) + + if rchisq < chithresh: + knots[iknot][FLAG] = True + + def prepare_knots(wa, fl, er, edges, ax=None, debug=False): + """ + Make initial knots for the continuum estimation. + + Parameters + ---------- + wa, fl, er : arrays + Wavelength, flux, error. + edges : The edges of the wavelength chunks. Splines knots are to be + places at the centre of these chunks. + ax : Matplotlib Axes + If not None, use to plot debugging info. + + Returns + ------- + knots, indices, masked + * knots: A list of [x, y, flag] lists giving the x and y position + of each knot. + * indices: A list of tuples (i,j) giving the start and end index + of each chunk. + * masked: An array the same shape as wa. + """ + indices = wa.searchsorted(edges) + indices = [(i0,i1) for i0,i1 in zip(indices[:-1],indices[1:])] + wavc = [0.5*(w1 + w2) for w1,w2 in zip(edges[:-1],edges[1:])] + + knots = [[wavc[i], 0, False] for i in range(len(wavc))] + + masked = np.zeros(len(wa), bool) + masked[~(er > 0)] = True + + # remove bad knots + remove_bad_knots(knots, indices, masked, fl, er, debug=debug) + + if ax is not None: + yedge = np.interp(edges, wa, fl) + ax.vlines(edges, 0, yedge + 100, color='c', zorder=10) + + # set the knot flux values + update_knots(knots, indices, fl, masked) + + if ax is not None: + x, y = list(zip(*knots))[:2] + ax.plot(x, y, 'o', mfc='none', mec='c', ms=10, mew=1, zorder=10) + + return knots, indices, masked + + def unmask(masked, indices, wa, fl, er, minpix=3): + """ + Forces each chunk to use at least minpix pixels. + + Sometimes all pixels can become masked in a chunk. We don't want + this! This forces there to be at least minpix pixels used in each + chunk. + """ + for iknot, (i, j) in enumerate(indices): + if np.sum(~masked[i:j]) < minpix: + # Need to unmask minpix + f0 = fl[i:j] + e0 = er[i:j] + ind = np.arange(i,j) + f1 = f0[e0 > 0] + isort = np.argsort(f1) + ind1 = ind[e0 > 0][isort[-minpix:]] + + masked[ind1] = False + + def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000, + nsig=1.5, debug=False): + """ + Iterate to estimate the continuum. + """ + count = 0 + + while True: + if debug: + logging.info('iteration', count) + update_knots(knots, indices, s.fl, masked) + model = linear_co(s.wa, knots) + model_a = Akima_co(s.wa, knots) + chisq_chunk(model_a, s.fl, s.er, masked, + indices, knots, chithresh=1) + flags = list(zip(*knots))[-1] + if np.all(flags): + if debug: + logging.info('All regions have satisfactory fit, stopping') + break + + # remove outliers + c0 = ~masked + resid = (model - s.fl) / s.er + oldmasked = masked.copy() + masked[(resid > nsig) & ~masked] = True + unmask(masked, indices, s.wa, s.fl, s.er) + + if np.all(oldmasked == masked): + if debug: + print('No further points masked, stopping') + break + if count > maxiter: + raise RuntimeError('Exceeded maximum iterations') + + count +=1 + + co = Akima_co(s.wa, knots) + c0 = co <= 0 + co[c0] = 0 + + if ax is not None: + ax.plot(s.wa, linear_co(s.wa, knots), color='0.7', lw=2) + ax.plot(s.wa, co, 'k', lw=2, zorder=10) + x,y = list(zip(*knots))[:2] + ax.plot(x, y, 'o', mfc='none', mec='k', ms=10, mew=1, zorder=10) + + return co + + # Here starts the actual fitting. Pull uncertainty from spectrum. + # TODO: this is very hacky right now + if not hasattr(spec, "uncertainty"): + logging.info("No uncertainty, assuming all are equal (continuum will probably fail)") + error = np.ones(len(spec.wavelength.value)) + else: + if isinstance(spec.uncertainty, StdDevUncertainty): + error = spec.uncertainty.array + else: + raise ValueError("Could not understand uncertainty type: {}".format( + spec.uncertainty)) + + s = np.rec.fromarrays([spec.wavelength.value, + spec.flux.value, + error], names=["wa", "fl", "er"]) + + if edges is not None: + edges = list(edges) + elif kind.upper() == 'QSO': + if 'redshift' in kwargs: + z = kwargs['redshift'] + elif 'redshift' in spec.meta: + z = spec.meta['redshift'] + else: + raise RuntimeError( + "I need the emission redshift for kind='qso'; please\ + provide redshift using `redshift` keyword.") + + divmult = kwargs.get('divmult', 2) + forest_divmult = kwargs.get('forest_divmult', 2) + edges = make_chunks_qso( + s.wa, z, debug=debug, divmult=divmult, + forest_divmult=forest_divmult) + + if ax is not None: + ax.plot(s.wa, s.fl, '-', color='0.4', drawstyle='steps-mid') + ax.plot(s.wa, s.er, 'g') + + knots, indices, masked = prepare_knots(s.wa, s.fl, s.er, edges, + ax=ax, debug=debug) + + # Note this modifies knots and masked inplace + co = estimate_continuum(s, knots, indices, masked, ax=ax, debug=debug) + + if ax is not None: + ax.plot(s.wa[~masked], s.fl[~masked], '.y') + ymax = np.percentile(s.fl[~np.isnan(s.fl)], 95) + ax.set_ylim(-0.02*ymax, 1.1*ymax) + + return co, [k[:2] for k in knots] diff --git a/specutils/tests/test_spline.py b/specutils/tests/test_spline.py new file mode 100644 index 000000000..782cd1e9e --- /dev/null +++ b/specutils/tests/test_spline.py @@ -0,0 +1,78 @@ +import astropy.units as u +import numpy as np + +from astropy.modeling import models, fitting +from specutils.fitting.spline import SplineModel, SplineFitter + +from scipy import interpolate + + +def make_data(with_errs=True): + """ Arbitrary data """ + np.random.seed(348957) + x = np.linspace(0, 10, 200) + y = (x+1) - (x-5)**2. + 10.*np.exp(-0.5 * ((x-7.)/.5)**2.) + y = (y - np.min(y) + 10.)*10. + if with_errs: + ey = np.sqrt(y) + y = y + np.random.normal(0., ey, y.shape) + w = 1./y + return x, y, w + + +def test_spline_fit(): + x, y, w = make_data() + make_plot = False + + # Construct three sets of splines and their scipy equivalents + knots = np.arange(1, 10) + print(len(x)) + models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots), + SplineModel(smoothing=0)] + labels = ["Deg 3", "Deg 5", "Knots", "Interpolated"] + scipyfit = [interpolate.UnivariateSpline(x, y, w), + interpolate.UnivariateSpline(x, y, w, k=5), + interpolate.LSQUnivariateSpline(x, y, knots, w=w), + interpolate.InterpolatedUnivariateSpline(x, y, w)] + + fitter = SplineFitter() + for model, label, scipymodel in zip(models, labels, scipyfit): + fitter(model, x, y, w) + my_y = model(x) + my_dy = model.derivative()(x) + my_ady = model.antiderivative()(x) + my_int = model.integral(x[0],x[-1]) + sci_y = scipymodel(x) + sci_dy = scipymodel.derivative()(x) + sci_ady = scipymodel.antiderivative()(x) + sci_int = scipymodel.integral(x[0],x[-1]) + assert np.allclose(my_y, sci_y, atol=1e-6), label + assert np.allclose(my_dy, sci_dy, atol=1e-6), label + assert np.allclose(my_ady, sci_ady, atol=1e-6), label + assert np.allclose(my_int, sci_int, atol=1e-6), label + + my_ders = model.derivatives(x) + sci_ders = scipymodel.derivatives(x) + assert np.allclose(my_ders, sci_ders, atol=1e-6), label + if model.degree == 3: + my_roots = model.roots() + sci_roots = scipymodel.roots() + assert np.allclose(my_roots, sci_roots, atol=1e-6), label + + if make_plot: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.plot(x, y, 'k.') + ymin, ymax = np.min(y), np.max(y) + for i, (model, label) in enumerate(zip(models, labels)): + l, = ax.plot(x, model(x), lw=1, label=label) + knots = model.knots + # Hack for now + if knots is None: + knots = model._tck[0] + + dy = (ymax-ymin)/10. + dy /= i+1. + ax.vlines(knots, ymin, ymin + dy, color=l.get_color(), lw=1) + ax.legend() + plt.show()