diff --git a/specutils/fitting/spline.py b/specutils/fitting/spline.py
index da1759642..fcaacb281 100644
--- a/specutils/fitting/spline.py
+++ b/specutils/fitting/spline.py
@@ -11,40 +11,40 @@ from astropy.utils import indent, check_broadcast
 from astropy.units import Quantity
-
 __all__ = []
+
 class SplineModel(FittableModel):
     """
     Wrapper around scipy.interpolate.splrep and splev
-    
+
     Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,
     and scipy.interpolate.LSQUnivariateSpline if knots are specified
-    
+
     There are two ways to make a spline model.
-    (1) you have the spline auto-determine knots from the data
-    (2) you specify the knots
-    
+    1. you have the spline auto-determine knots from the data
+    2. you specify the knots
     """
-    
-    linear = False # I think? I have no idea?
-    col_fit_deriv = False # Not sure what this is
-    
-    def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):
+
+    linear = False  # TODO: confirm whether astropy should treat this as a linear model
+    col_fit_deriv = False  # TODO: confirm how fit derivatives should be ordered
+
+    def __init__(self, degree=3, smoothing=None, knots=None,
+                 extrapolate_mode=0, *args, **kwargs):
         """
         Set up a spline model.
-        
+
         degree: degree of the spline (default 3)
             In scipy fitpack, this is "k"
-        
+
         smoothing (optional): smoothing value for automatically determining knots
            In scipy fitpack, this is "s"
-           By default, uses a
-        
+           By default, scipy uses s = m - sqrt(2*m), where m is the number of
+           data points.
+
         knots (optional): spline knots (boundaries of piecewise polynomial)
             If not specified, will automatically determine knots based on
             degree + smoothing
-        
+
         extrapolate_mode (optional): how to deal with solution outside of interval.
             (see scipy.interpolate.splev)
             if 0 (default): return the extrapolated value
@@ -52,117 +52,135 @@ def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):
         if 2, raise a ValueError
         if 3, return the boundary value
         """
+        super().__init__(*args, **kwargs)
+
         self._degree = degree
         self._smoothing = smoothing
         self._knots = self.verify_knots(knots)
         self.extrapolate_mode = extrapolate_mode
-        
-        ## This is used to evaluate the spline
-        ## When None, raises an error when trying to evaluate the spline
+
+        # This is used to evaluate the spline. When None, raises an error when
+        # trying to evaluate the spline.
         self._tck = None
-        
+
         self._param_names = ()
-    
+
     def verify_knots(self, knots):
         """
-        Basic knot array vetting.
-        The goal of having this is to enable more useful error messages
-        than scipy (if needed).
+        Basic knot array vetting. The goal of having this is to enable more
+        useful error messages than scipy (if needed).
""" - if knots is None: return None + if knots is None: + return None + knots = np.array(knots) assert len(knots.shape) == 1, knots.shape knots = np.sort(knots) assert len(np.unique(knots)) == len(knots), knots + return knots - - ############ - ## Getters - ############ - def get_degree(self): + + # Getters + @property + def degree(self): """ Spline degree (k in FITPACK) """ return self._degree - def get_smoothing(self): + + @property + def smoothing(self): """ Spline smoothing (s in FITPACK) """ return self._smoothing - def get_knots(self): + + @property + def knots(self): """ Spline knots (t in FITPACK) """ return self._knots - def get_coeffs(self): + + @property + def coeffs(self): """ Spline coefficients (c in FITPACK) """ if self._tck is not None: return self._tck[1] else: - raise RuntimeError("SplineModel has not been fit yet") - - ############ - ## Spline methods: not tested at all - ############ - def derivative(self, n=1): - if self._tck is None: - raise RuntimeError("SplineModel has not been fit yet") - else: - t, c, k = self._tck - return scipy.interpolate.BSpline.construct_fast( - t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n) - def antiderivative(self, n=1): - if self._tck is None: - raise RuntimeError("SplineModel has not been fit yet") - else: - t, c, k = self._tck - return scipy.interpolate.BSpline.construct_fast( - t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n) - def integral(self, a, b): - if self._tck is None: - raise RuntimeError("SplineModel has not been fit yet") - else: - t, c, k = self._tck - return scipy.interpolate.BSpline.construct_fast( - t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b) - def derivatives(self, x): - raise NotImplementedError - def roots(self): - raise NotImplementedError - - ############ - ## Setters: not really implemented or tested - ############ + raise RuntimeError("SplineModel has not been fit yet.") + + # Setters + # TODO: not really implemented or tested def reset_model(self): """ Resets model so it needs to be refit to be valid """ self._tck = None - def set_degree(self, degree): + + @degree.setter + def degree(self, degree): """ Spline degree (k in FITPACK) """ raise NotImplementedError self._degree = degree self.reset_model() - def set_smoothing(self, smoothing): + + @smoothing.setter + def smoothing(self, smoothing): """ Spline smoothing (s in FITPACK) """ raise NotImplementedError self._smoothing = smoothing self.reset_model() - def set_knots(self, knots): + + @knots.setter + def knots(self, knots): """ Spline knots (t in FITPACK) """ raise NotImplementedError self._knots = self.verify_knots(knots) self.reset_model() - + def set_model_from_tck(self, tck): """ Use output of scipy.interpolate.splrep """ self._tck = tck + # Spline methods + # TODO: not tested at all + def derivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t, c, k, extrapolate=(self.extrapolate_mode == 0)).derivative(n) + + def antiderivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t, c, k, extrapolate=(self.extrapolate_mode == 0)).antiderivative(n) + + def integral(self, a, b): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet.") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t, c, k, 
+                extrapolate=(self.extrapolate_mode == 0)).integral(a, b)
+
+    def derivatives(self, x):
+        raise NotImplementedError
+
+    def roots(self):
+        raise NotImplementedError
+
     def __call__(self, x, der=0):
         """
         Evaluate the model with the given inputs.
         der is passed to scipy.interpolate.splev
         """
         if self._tck is None:
-            raise RuntimeError("SplineModel has not been fit yet")
+            raise RuntimeError("SplineModel has not been fit yet.")
+
         return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode)
-    
-    ####################################
-    ######### Stuff below here is stubs
+
+    # Stuff below here is stubs
+    # TODO: fill out methods
     @property
     def param_names(self):
         """
@@ -201,11 +219,10 @@ def _generate_coeff_names(self):
         for j in range(degree+1):
             names.append("k{}_c{}".format(i,j))
         return tuple(names)
-    
+
     def evaluate(self, *args, **kwargs):
         return self(*args, **kwargs)
-
 
 class SplineFitter(metaclass=_FitterMeta):
     """
@@ -216,43 +233,42 @@ def __init__(self):
                          "ier": None, "msg": None}
         super().__init__()
-    
+
     def validate_model(self, model):
         if not isinstance(model, SplineModel):
             raise ValueError("model must be of type SplineModel (currently is {})".format(
                 type(model)))
-    
-    ## TODO do something about units
-    #@fitter_unit_support
+
+    # TODO do something about units
+    # @fitter_unit_support
     def __call__(self, model, x, y, w=None):
         """
         Fit a spline model to data. Internally uses scipy.interpolate.splrep.
-        
+
         """
-        
+
         self.validate_model(model)
-        
-        ## Case (1): fit smoothing spline
-        if model.get_knots() is None:
+
+        # Case (1): fit smoothing spline
+        if model.knots is None:
             tck, fp, ier, msg = interpolate.splrep(x, y, w=w, t=None,
-                                                   k=model.get_degree(),
-                                                   s=model.get_smoothing(),
+                                                   k=model.degree,
+                                                   s=model.smoothing,
                                                    task=0, full_output=True
             )
-        ## Case (2): leastsq spline
+        # Case (2): leastsq spline
         else:
-            knots = model.get_knots()
+            knots = model.knots
             ## TODO some sort of validation that the knots are internal, since
             ## this procedure automatically adds knots at the two endpoints
             tck, fp, ier, msg = interpolate.splrep(x, y, w=w, t=knots,
-                                                   k=model.get_degree(),
-                                                   s=model.get_smoothing(),
+                                                   k=model.degree,
+                                                   s=model.smoothing,
                                                    task=-1, full_output=True
             )
-        
+
         model.set_model_from_tck(tck)
-        self.fit_info.update({"fp":fp, "ier":ier, "msg":msg})
-        
+        self.fit_info.update({"fp": fp, "ier": ier, "msg": msg})
diff --git a/specutils/manipulation/continuum.py b/specutils/fitting/spline_continuum.py
similarity index 80%
rename from specutils/manipulation/continuum.py
rename to specutils/fitting/spline_continuum.py
index 345e4c5a4..617946c8c 100644
--- a/specutils/manipulation/continuum.py
+++ b/specutils/fitting/spline_continuum.py
@@ -1,9 +1,7 @@
-from __future__ import print_function, division, absolute_import
-
 from astropy import modeling
 from astropy.modeling import models, fitting
 from astropy.nddata import StdDevUncertainty
-from ..spectra import Spectrum1D
+from specutils.spectra import Spectrum1D
 
 import numpy as np
 from scipy import interpolate
@@ -13,158 +11,177 @@
 
 __all__ = ['fit_continuum_generic', 'fit_continuum_linetools']
 
+
 def fit_continuum_generic(spectrum, model=None, fitter=None,
-                          sigma=3.0, sigma_lower=None, sigma_upper=None, iters=5,
-                          exclude_regions=[],
+                          sigma=3.0, sigma_lower=None,
+                          sigma_upper=None, iters=5,
+                          exclude_regions=None,
                           full_output=False):
     """
     Fit a generic continuum model to a spectrum.
-    
+
     The default algorithm is iterative sigma clipping
-    
+
     Parameters
     ----------
     spectrum : `~specutils.Spectrum1D`
         The `~specutils.Spectrum1D` object to which a continuum model is fit
-    
     model : `astropy.modeling.FittableModel`
         The type of model to use for the continuum.
         Must be astropy.modeling.FittableModel
         See astropy.modeling.models
         Default: models.Chebyshev1D(3)
        TODO add a spline option (since this is not currently implemented)
-    
     fitter : `astropy.modeling.fitting.Fitter`
         The type of fitter to use for the continuum.
         See astropy.modeling.fitting for valid choices
        TODO currently does not typecheck because fitters do not subclass fitting.Fitter
        Default: fitting.LevMarLSQFitter()
-    
     sigma : float, optional
         The number of standard deviations to use for both lower and upper clipping limit.
         Defaults to 3.0
-    
     sigma_lower : float or None, optional
         Number of standard deviations for lower bound clipping limit.
         If None (default), then `sigma` is used.
-    
+
     sigma_upper : float or None, optional
         Number of standard deviations for upper bound clipping limit.
         If None (default), then `sigma` is used.
-    
     iters : int or None, optional
         Number of iterations to perform sigma clipping.
         If None, clips until convergence achieved. Defaults to 5
-    
     exclude_regions : list of tuples, optional
         A list of dispersion regions to exclude. Each tuple must be sorted.
         e.g. [(6555,6575)]
-    
     full_output : bool, optional
         If True, return more information.
         Currently, just the model and the pixels-used boolean array
-    
+
     Returns
     -------
     continuum_model : `astropy.modeling.FittableModel`
         Output a model for the continuum
-    
+
     Raises
     ------
     ValueError
        If: spectrum is not the correct type,
        the exclude regions do not satisfy a list of sorted tuples,
        the model and/or fitter are of the wrong type,
-    
+
     Examples
     --------
-    TODO add more and unit tests
-    
+    TODO: add more and unit tests
+
     See https://github.com/spacetelescope/dat_pyinthesky/blob/master/pyinthesky_specutils_fitting.ipynb
-    
+
     """
-    
-    ## Parameter checks
+
+    # Parameter checks
     if not isinstance(spectrum, Spectrum1D):
         raise ValueError('The spectrum parameter must be a Spectrum1D object')
+
+    exclude_regions = [] if exclude_regions is None else exclude_regions
+
     for exclude_region in exclude_regions:
         if len(exclude_region) != 2:
             raise ValueError('All exclusion regions must be of length 2')
         if exclude_region[0] >= exclude_region[1]:
             raise ValueError('All exclusion regions must be (low, high)')
-    
-    ## Set default model and fitter
+
+    # Set default model and fitter
     if model is None:
         logging.info("Using Chebyshev1D(3) as default continuum model")
         model = models.Chebyshev1D(3)
+
     if fitter is None:
         fitter = fitting.LevMarLSQFitter()
+
     if not isinstance(model, modeling.FittableModel):
         raise ValueError('The model parameter must be a astropy.modeling.FittableModel object')
-    ## TODO this is waiting on a refactor in modeling.fitting to work
-    #if not isinstance(fitter, fitting.Fitter):
-    #    raise ValueError('The model parameter must be a astropy.modeling.fitting.Fitter object')
-    
-    ## Get input spectrum data
+    # TODO: this is waiting on a refactor in modeling.fitting to work
+
+    # if not isinstance(fitter, fitting.Fitter):
+    #     raise ValueError("The model parameter must be an "
+    #                      "astropy.modeling.fitting.Fitter object.")
+
+    # Get input spectrum data
     x = spectrum.spectral_axis.value
     y = spectrum.flux.value
-    
-    ## Set up valid pixels mask
-    ## Exclude non-finite values
+
+    # Set up valid pixels mask. Exclude non-finite values.
    good = np.isfinite(y)
-    ## Exclude regions
+
+    # Exclude regions
     for (excl1, excl2) in exclude_regions:
         good[np.logical_and(x > excl1, x < excl2)] = False
-    
-    ## Set up sigma clipping
-    if sigma_lower is None: sigma_lower = sigma
-    if sigma_upper is None: sigma_upper = sigma
+
+    # Set up sigma clipping
+    if sigma_lower is None:
+        sigma_lower = sigma
+
+    if sigma_upper is None:
+        sigma_upper = sigma
+
+    # Set the model as the default continuum in cases where the sigma
+    # clipping iterations == 0
+    continuum_model = model
 
     for i_iter in range(iters):
-        logging.info("Iter {}: Fitting {}/{} pixels".format(i_iter, good.sum(), len(good)))
-        ## Fit model
-        ## TODO include data uncertainties
+        logging.info("Iter {}: Fitting {}/{} pixels".format(
+            i_iter, good.sum(), len(good)))
+
+        # Fit model
+        # TODO: include data uncertainties
         continuum_model = fitter(model, x[good], y[good])
-        
-        ## Sigma clip
+
+        # Sigma clip
         difference = continuum_model(x) - y
         finite = np.isfinite(difference)
         sigma_difference = difference / np.std(difference[np.logical_and(good, finite)])
         good[sigma_difference > sigma_upper] = False
         good[sigma_difference < -sigma_lower] = False
-    
+
     if full_output:
         return continuum_model, good
+
     return continuum_model
 
+
 def fit_continuum_linetools(spec, edges=None, ax=None, debug=False, kind="QSO",
                             **kwargs):
     """
-    A direct port of the linetools continuum normalization algorithm by X Prochaska
-    https://github.com/linetools/linetools/blob/master/linetools/analysis/continuum.py
-    
-    The only changes are switching to Scipy's Akima1D interpolator and changing the relevant syntax
+    A direct port of the linetools continuum normalization algorithm by
+    X Prochaska (https://github.com/linetools/linetools/blob/master/linetools/analysis/continuum.py)
+
+    The only changes are switching to Scipy's Akima1D interpolator and
+    changing the relevant syntax.
     """
     assert kind in ["QSO"], kind
+
    if not isinstance(spec, Spectrum1D):
         raise ValueError('The spectrum parameter must be a Spectrum1D object')
-    
-    ### To start, we define all the functions here to avoid namespace bloat, but this can be fixed later
-    ### The goal is to have the same algorithm but with flexible wavelength chunks for other object types
-    
-    def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, debug=False):
-        """ Generate a series of wavelength chunks for use by
+
+    # To start, we define all the functions here to avoid namespace bloat,
+    # but this can be fixed later. The goal is to have the same algorithm but
+    # with flexible wavelength chunks for other object types
+
+    def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1,
+                        debug=False):
+        """
+        Generate a series of wavelength chunks for use by
         prepare_knots, assuming a QSO spectrum.
         """
-
         cond = np.isnan(wa)
+
         if np.any(cond):
             warnings.warn('Some wavelengths are NaN, ignoring these pixels.')
             wa = wa[~cond]
+
         assert len(wa) > 0
-        
+
         zp1 = 1 + redshift
         div = np.rec.fromrecords([(200. , 500. , 25),
                                   (500. , 800. , 25),
@@ -186,26 +203,31 @@ def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, debug=False):
                                   (3000., 6000., 80),
                                   (6000., 20000., 100),
                                   ], names=str('left,right,num'))
-        
+
         div.num[2:] = np.ceil(div.num[2:] * divmult)
         div.num[:2] = np.ceil(div.num[:2] * forest_divmult)
         div.left *= zp1
         div.right *= zp1
+
         if debug:
-            print(div.tolist())
-        temp = [np.linspace(left, right, n+1)[:-1] for left,right,n in div]
+            logging.info(div.tolist())
+
+        temp = [np.linspace(left, right, n+1)[:-1] for left, right, n in div]
         edges = np.concatenate(temp)
-        
-        i0,i1,i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]])
+
+        i0, i1, i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]])
+
         if debug:
-            print(i0,i1,i2)
+            logging.info("%s %s %s", i0, i1, i2)
+
         return edges[i0:i2]
-    
+
     def update_knots(knots, indices, fl, masked):
-        """ Calculate the y position of each knot.
-        
+        """
+        Calculate the y position of each knot.
+
         Updates `knots` inplace.
-        
+
         Parameters
         ----------
         knots: list of [xpos, ypos, bool] with length N
@@ -217,60 +239,70 @@ def update_knots(knots, indices, fl, masked):
            The flux, and boolean arrays showing which pixels are masked.
         """
-
         iy, iflag = 1, 2
-        for iknot,(i1,i2) in enumerate(indices):
+
+        for iknot, (i1, i2) in enumerate(indices):
             if knots[iknot][iflag]:
                 continue
-            
+
             f0 = fl[i1:i2]
             m0 = masked[i1:i2]
             f1 = f0[~m0]
             knots[iknot][iy] = np.median(f1)
-    
+
     def linear_co(wa, knots):
-        """linear interpolation through the spline knots.
-        
+        """
+        linear interpolation through the spline knots.
+
         Add extra points on either end to give
-        a nice slope at the end points."""
+        a nice slope at the end points.
+        """
         wavc, mfl = list(zip(*knots))[:2]
         extwavc = ([wavc[0] - (wavc[1] - wavc[0])] + list(wavc) +
                    [wavc[-1] + (wavc[-1] - wavc[-2])])
         extmfl = ([mfl[0] - (mfl[1] - mfl[0])] + list(mfl) +
                   [mfl[-1] + (mfl[-1] - mfl[-2])])
         co = np.interp(wa, extwavc, extmfl)
+
         return co
-    
+
     def Akima_co(wa, knots):
         """Akima interpolation through the spline knots."""
-        x,y,_ = zip(*knots)
+        x, y, _ = zip(*knots)
         spl = interpolate.Akima1DInterpolator(x, y)
+
         return spl(wa)
-    
+
     def remove_bad_knots(knots, indices, masked, fl, er, debug=False):
-        """ Remove knots in chunks without any good pixels. Modifies
-        inplace."""
+        """
+        Remove knots in chunks without any good pixels. Modifies
+        inplace.
+        """
         idelknot = []
-        for iknot,(i,j) in enumerate(indices):
+
+        for iknot, (i, j) in enumerate(indices):
             if np.all(masked[i:j]) or np.median(fl[i:j]) <= 2*np.median(er[i:j]):
                 if debug:
                     print('Deleting knot', iknot, 'near {:.1f} Angstroms'.format(
                         knots[iknot][0]))
                 idelknot.append(iknot)
-        
+
         for i in reversed(idelknot):
             del knots[i]
             del indices[i]
-    
+
     def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5):
-        """ Calc chisq per chunk, update knots flags inplace if chisq is
-        acceptable. """
+        """
+        Calc chisq per chunk, update knots flags inplace if chisq is
+        acceptable.
+        """
         chisq = []
         FLAG = 2
-        for iknot,(i1,i2) in enumerate(indices):
+
+        for iknot, (i1, i2) in enumerate(indices):
             if knots[iknot][FLAG]:
                 continue
-            
+
             f0 = fl[i1:i2]
             e0 = er[i1:i2]
             m0 = masked[i1:i2]
@@ -281,13 +313,14 @@ def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5):
             resid = (mod1 - f1) / e1
             chisq = np.sum(resid*resid)
             rchisq = chisq / len(f1)
+
             if rchisq < chithresh:
-                #print (good reduced chisq in knot', iknot)
                 knots[iknot][FLAG] = True
-    
+
     def prepare_knots(wa, fl, er, edges, ax=None, debug=False):
-        """ Make initial knots for the continuum estimation.
-        
+        """
+        Make initial knots for the continuum estimation.
+
         Parameters
         ----------
         wa, fl, er : arrays
         edges : The edges of the wavelength chunks. Splines knots are to be
             places at the centre of these chunks.
         ax : Matplotlib Axes
             If not None, use to plot debugging info.
-        
+
         Returns
         -------
         knots, indices, masked
         """
         indices = wa.searchsorted(edges)
         indices = [(i0,i1) for i0,i1 in zip(indices[:-1],indices[1:])]
         wavc = [0.5*(w1 + w2) for w1,w2 in zip(edges[:-1],edges[1:])]
-        
+
         knots = [[wavc[i], 0, False] for i in range(len(wavc))]
-        
+
         masked = np.zeros(len(wa), bool)
         masked[~(er > 0)] = True
-        
+
         # remove bad knots
         remove_bad_knots(knots, indices, masked, fl, er, debug=debug)
-        
+
         if ax is not None:
             yedge = np.interp(edges, wa, fl)
             ax.vlines(edges, 0, yedge + 100, color='c', zorder=10)
-        
+
         # set the knot flux values
         update_knots(knots, indices, fl, masked)
-        
+
         if ax is not None:
-            x,y = list(zip(*knots))[:2]
+            x, y = list(zip(*knots))[:2]
             ax.plot(x, y, 'o', mfc='none', mec='c', ms=10, mew=1, zorder=10)
-        
+
         return knots, indices, masked
-    
-    
+
     def unmask(masked, indices, wa, fl, er, minpix=3):
-        """ Forces each chunk to use at least minpix pixels.
-        
+        """
+        Forces each chunk to use at least minpix pixels.
+
         Sometimes all pixels can become masked in a chunk. We don't want
         this! This forces there to be at least minpix pixels used in each
         chunk.
         """
-        for iknot,(i,j) in enumerate(indices):
-            #print(iknot, wa[i], wa[j], (~masked[i:j]).sum())
+        for iknot, (i, j) in enumerate(indices):
             if np.sum(~masked[i:j]) < minpix:
-                #print('unmasking pixels')
-                # need to unmask minpix
+                # Need to unmask minpix
                 f0 = fl[i:j]
                 e0 = er[i:j]
                 ind = np.arange(i,j)
                 f1 = f0[e0 > 0]
                 isort = np.argsort(f1)
                 ind1 = ind[e0 > 0][isort[-minpix:]]
-                # print(wa[i], wa[j])
-                # print(wa[ind1])
+
                 masked[ind1] = False
-
     def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
                            nsig=1.5, debug=False):
-        """ Iterate to estimate the continuum.
+        """
+        Iterate to estimate the continuum.
         """
         count = 0
+
         while True:
             if debug:
-                print('iteration', count)
+                logging.info('iteration %s', count)
             update_knots(knots, indices, s.fl, masked)
             model = linear_co(s.wa, knots)
             model_a = Akima_co(s.wa, knots)
@@ -371,38 +402,39 @@ def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
             flags = list(zip(*knots))[-1]
             if np.all(flags):
                 if debug:
-                    print('All regions have satisfactory fit, stopping')
+                    logging.info('All regions have satisfactory fit, stopping')
                 break
+
             # remove outliers
             c0 = ~masked
             resid = (model - s.fl) / s.er
             oldmasked = masked.copy()
             masked[(resid > nsig) & ~masked] = True
             unmask(masked, indices, s.wa, s.fl, s.er)
+
             if np.all(oldmasked == masked):
                 if debug:
                     print('No further points masked, stopping')
                 break
             if count > maxiter:
                 raise RuntimeError('Exceeded maximum iterations')
-            
+
             count +=1
 
         co = Akima_co(s.wa, knots)
         c0 = co <= 0
         co[c0] = 0
-        
+
         if ax is not None:
             ax.plot(s.wa, linear_co(s.wa, knots), color='0.7', lw=2)
             ax.plot(s.wa, co, 'k', lw=2, zorder=10)
             x,y = list(zip(*knots))[:2]
             ax.plot(x, y, 'o', mfc='none', mec='k', ms=10, mew=1, zorder=10)
-        
+
         return co
 
-    ### Here starts the actual fitting
-    ## Pull uncertainty from spectrum
-    ## TODO this is very hacky right now
+    # Here starts the actual fitting. Pull uncertainty from spectrum.
+    # TODO: this is very hacky right now
     if not hasattr(spec, "uncertainty"):
         logging.info("No uncertainty, assuming all are equal (continuum will probably fail)")
         error = np.ones(len(spec.wavelength.value))
@@ -412,10 +444,10 @@ def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
     else:
         raise ValueError("Could not understand uncertainty type: {}".format(
             spec.uncertainty))
-    
+
     s = np.rec.fromarrays([spec.wavelength.value,
                            spec.flux.value,
-                           error], names=["wa","fl","er"])
+                           error], names=["wa", "fl", "er"])
 
     if edges is not None:
         edges = list(edges)
@@ -434,8 +466,7 @@ def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
         edges = make_chunks_qso(
             s.wa, z, debug=debug, divmult=divmult,
             forest_divmult=forest_divmult)
-    
-    
+
     if ax is not None:
         ax.plot(s.wa, s.fl, '-', color='0.4', drawstyle='steps-mid')
         ax.plot(s.wa, s.er, 'g')
diff --git a/specutils/tests/test_spline.py b/specutils/tests/test_spline.py
index 84579d75a..68f1b14e9 100644
--- a/specutils/tests/test_spline.py
+++ b/specutils/tests/test_spline.py
@@ -6,6 +6,7 @@
 from scipy import interpolate
 
+
 def make_data(with_errs=True):
     """ Arbitrary data """
     np.random.seed(348957)
@@ -17,44 +18,44 @@ def make_data(with_errs=True):
     y = y + np.random.normal(0., ey, y.shape)
     w = 1./y
     return x, y, w
-    
+
+
 def test_spline_fit():
     x, y, w = make_data()
-    make_plot=False
-    
+    make_plot = False
+
     # Construct three sets of splines and their scipy equivalents
-    knots = np.arange(1,10)
-    models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots), SplineModel(smoothing=0)]
+    knots = np.arange(1, 10)
+    models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots),
+              SplineModel(smoothing=0)]
     labels = ["Deg 3", "Deg 5", "Knots", "Interpolated"]
-    scipyfit = [interpolate.UnivariateSpline(x,y,w),
-                interpolate.UnivariateSpline(x,y,w,k=5),
-                interpolate.LSQUnivariateSpline(x,y,knots,w=w),
-                interpolate.InterpolatedUnivariateSpline(x,y,w)]
-    
+    scipyfit = [interpolate.UnivariateSpline(x, y, w),
+                interpolate.UnivariateSpline(x, y, w, k=5),
+                interpolate.LSQUnivariateSpline(x, y, knots, w=w),
+                interpolate.InterpolatedUnivariateSpline(x, y, w)]
+
     fitter = SplineFitter()
     for model, label, scipymodel in zip(models, labels, scipyfit):
         fitter(model, x, y, w)
         my_y = model(x)
         sci_y = scipymodel(x)
         assert np.allclose(my_y, sci_y, atol=1e-6)
-    
+
     if make_plot:
         import matplotlib.pyplot as plt
         fig, ax = plt.subplots()
-        ax.plot(x,y,'k.')
+        ax.plot(x, y, 'k.')
         ymin, ymax = np.min(y), np.max(y)
-        for i,(model, label) in enumerate(zip(models, labels)):
+        for i, (model, label) in enumerate(zip(models, labels)):
             l, = ax.plot(x, model(x), lw=1, label=label)
-            knots = model.get_knots()
+            knots = model.knots
             # Hack for now
-            if knots is None: knots = model._tck[0]
+            if knots is None:
+                knots = model._tck[0]
+
             print(knots)
             dy = (ymax-ymin)/10.
             dy /= i+1.
             ax.vlines(knots, ymin, ymin + dy, color=l.get_color(), lw=1)
         ax.legend()
         plt.show()
-
-if __name__=="__main__":
-    test_spline_fit()
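
For reference, a minimal usage sketch of the API this patch introduces. It assumes the modules land at specutils.fitting.spline and specutils.fitting.spline_continuum as diffed, and that SplineFitter reads the new degree/smoothing/knots properties; the synthetic data, units, and parameter values below are illustrative assumptions, not part of the patch.

import numpy as np
import astropy.units as u
from specutils import Spectrum1D
from specutils.fitting.spline import SplineModel, SplineFitter
from specutils.fitting.spline_continuum import fit_continuum_generic

# Synthetic data, similar in spirit to specutils/tests/test_spline.py
x = np.linspace(0.5, 9.5, 500)
y = np.sin(x) + 3.0 + np.random.normal(0., 0.1, x.shape)
w = np.full_like(x, 10.)  # inverse-sigma weights passed through to splrep

# Case (1): smoothing spline, FITPACK chooses the knots automatically
model = SplineModel(degree=3)
fitter = SplineFitter()
fitter(model, x, y, w=w)
continuum = model(x)  # evaluation; extrapolate_mode controls out-of-range behaviour

# Case (2): least-squares spline on user-supplied interior knots
lsq_model = SplineModel(knots=np.arange(1, 9))
fitter(lsq_model, x, y, w=w)

# Generic continuum fit on a Spectrum1D (defaults: Chebyshev1D(3) + sigma clipping)
spec = Spectrum1D(spectral_axis=x * u.AA, flux=y * u.Jy)
cont_model, good = fit_continuum_generic(spec, exclude_regions=[(2.0, 2.5)],
                                         full_output=True)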