From facd87e1dbdafee2daab185c0cec38708a21a6b5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 17:54:51 -0400 Subject: [PATCH 1/3] add plot methods to data objects --- specparam/objs/data.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 7823542f..4b40ab32 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -10,7 +10,7 @@ from specparam.core.errors import DataError, InconsistentDataError from specparam.data import SpectrumMetaData from specparam.plts.settings import PLT_COLORS -from specparam.plts.spectra import plot_spectra +from specparam.plts.spectra import plot_spectra, plot_spectrogram from specparam.plts.utils import check_plot_kwargs ################################################################################################### @@ -381,6 +381,12 @@ def add_data(self, freqs, spectrogram, freq_range=None): super().add_data(freqs, spectrogram, freq_range) + def plot(self, **plt_kwargs): + """Plot the spectrogram.""" + + plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs) + + class BaseData3D(BaseData2DT): """Base object for managing data for spectral parameterization - for 3D data.""" @@ -440,3 +446,9 @@ def add_data(self, freqs, spectrograms, freq_range=None): # Otherwise, pass through 2d array to underlying object method else: super().add_data(freqs, spectrograms, freq_range) + + + def plot(self, event_ind): + """Plot a selected spectrogram.""" + + plot_spectrogram(self.freqs, self.spectrograms[event_ind, :, :], **plot_kwargs) From 4383b1a6c9beb9abf37b84d0b9a36521891ccb11 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 00:37:10 -0400 Subject: [PATCH 2/3] FitObjects -> ResultsObjects --- specparam/objs/base.py | 26 +++--- specparam/objs/event.py | 2 +- specparam/objs/{fit.py => results.py} | 22 ++--- specparam/tests/objs/test_base.py | 6 +- specparam/tests/objs/test_fit.py | 116 -------------------------- specparam/tests/objs/test_results.py | 116 ++++++++++++++++++++++++++ 6 files changed, 144 insertions(+), 144 deletions(-) rename specparam/objs/{fit.py => results.py} (98%) delete mode 100644 specparam/tests/objs/test_fit.py create mode 100644 specparam/tests/objs/test_results.py diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 49d932bf..49bdb384 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -12,7 +12,7 @@ load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method from specparam.plts.event import plot_event_model -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D +from specparam.objs.results import BaseResults, BaseResults2D, BaseResults2DT, BaseResults3D from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D ################################################################################################### @@ -117,15 +117,15 @@ def _add_from_dict(self, data): setattr(self, key, data[key]) -class BaseObject(CommonBase, BaseFit, BaseData): +class BaseObject(CommonBase, BaseResults, BaseData): """Define Base object for fitting models to 1D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): CommonBase.__init__(self) BaseData.__init__(self) - BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): @@ -203,15 +203,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res self._reset_results(clear_results) -class BaseObject2D(CommonBase, BaseFit2D, BaseData2D): +class BaseObject2D(CommonBase, BaseResults2D, BaseData2D): """Define Base object for fitting models to 2D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): CommonBase.__init__(self) BaseData2D.__init__(self) - BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): @@ -315,15 +315,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, self._reset_results(clear_results) -class BaseObject2DT(BaseObject2D, BaseFit2DT, BaseData2DT): +class BaseObject2DT(BaseObject2D, BaseResults2DT, BaseData2DT): """Define Base object for fitting models to 2D data - tranpose version.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): BaseObject2D.__init__(self) BaseData2DT.__init__(self) - BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def load(self, file_name, file_path=None, peak_org=None): @@ -348,15 +348,15 @@ def load(self, file_name, file_path=None, peak_org=None): self.convert_results(peak_org) -class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D): +class BaseObject3D(BaseObject2DT, BaseResults3D, BaseData3D): """Define Base object for fitting models to 3D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): BaseObject2DT.__init__(self) BaseData3D.__init__(self) - BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 884524a1..7710c5c1 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -9,7 +9,7 @@ from specparam.objs import SpectralModel from specparam.objs.base import BaseObject3D from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.objs.fit import _progress +from specparam.objs.results import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict diff --git a/specparam/objs/fit.py b/specparam/objs/results.py similarity index 98% rename from specparam/objs/fit.py rename to specparam/objs/results.py index c08d8062..4189acf0 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/results.py @@ -19,8 +19,8 @@ ################################################################################################### ################################################################################################### -class BaseFit(): - """Base object for managing fit procedures.""" +class BaseResults(): + """Base object for managing results.""" # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, @@ -334,12 +334,12 @@ def _calc_error(self, metric=None): raise ValueError(error_msg) -class BaseFit2D(BaseFit): - """Base object for managing fit procedures - 2D version.""" +class BaseResults2D(BaseResults): + """Base object for managing results - 2D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_group_results() @@ -610,12 +610,12 @@ def get_group(self, inds): return group -class BaseFit2DT(BaseFit2D): - """Base object for managing fit procedures - 2D transpose version.""" +class BaseResults2DT(BaseResults2D): + """Base object for managing results - 2D transpose version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_time_results() @@ -751,12 +751,12 @@ def convert_results(self, peak_org): self.time_results = group_to_dict(self.group_results, peak_org) -class BaseFit3D(BaseFit2DT): - """Base object for managing fit procedures - 3D version.""" +class BaseResults3D(BaseResults2DT): + """Base object for managing results - 3D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_event_results() diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index a6ae7ccb..108064f2 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -48,7 +48,7 @@ def test_base2d(): tobj2d = BaseObject2D() assert isinstance(tobj2d, CommonBase) assert isinstance(tobj2d, BaseObject2D) - assert isinstance(tobj2d, BaseFit2D) + assert isinstance(tobj2d, BaseResults2D) assert isinstance(tobj2d, BaseObject2D) ## 2DT Base Object @@ -58,7 +58,7 @@ def test_base2dt(): tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) - assert isinstance(tobj2dt, BaseFit2DT) + assert isinstance(tobj2dt, BaseResults2DT) assert isinstance(tobj2dt, BaseObject2DT) ## 3D Base Object @@ -68,6 +68,6 @@ def test_base3d(): tobj3d = BaseObject3D() assert isinstance(tobj3d, CommonBase) assert isinstance(tobj3d, BaseObject2DT) - assert isinstance(tobj3d, BaseFit2DT) + assert isinstance(tobj3d, BaseResults2DT) assert isinstance(tobj3d, BaseObject2DT) assert isinstance(tobj3d, BaseObject3D) diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py deleted file mode 100644 index 2890c090..00000000 --- a/specparam/tests/objs/test_fit.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Tests for specparam.objs.fit, including the data object and it's methods.""" - -from specparam.core.items import OBJ_DESC -from specparam.data import ModelSettings - -from specparam.objs.fit import * - -################################################################################################### -################################################################################################### - -## 1D fit object - -def test_base_fit(): - - tfit1 = BaseFit(None, None) - assert isinstance(tfit1, BaseFit) - - tfit2 = BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2, BaseFit) - -def test_base_fit_settings(): - - tfit = BaseFit(None, None) - - settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') - tfit.add_settings(settings) - for setting in OBJ_DESC['settings']: - assert getattr(tfit, setting) == getattr(settings, setting) - - settings_out = tfit.get_settings() - assert isinstance(settings, ModelSettings) - assert settings_out == settings - -def test_base_fit_results(tresults): - - tfit = BaseFit(None, None) - - tfit.add_results(tresults) - assert tfit.has_model - for result in OBJ_DESC['results']: - assert np.array_equal(getattr(tfit, result), getattr(tresults, result.strip('_'))) - - results_out = tfit.get_results() - assert isinstance(tresults, FitResults) - assert results_out == tresults - -## 2D fit object - -def test_base_fit2d(): - - tfit2d1 = BaseFit2D(None, None) - assert isinstance(tfit2d1, BaseFit) - assert isinstance(tfit2d1, BaseFit2D) - - tfit2d2 = BaseFit2D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2d2, BaseFit2D) - -def test_base_fit2d_results(tresults): - - tfit2d = BaseFit2D(None, None) - - results = [tresults, tresults] - tfit2d.add_results(results) - assert tfit2d.has_model - results_out = tfit2d.get_results() - assert isinstance(results_out, list) - assert results_out == results - -## 2DT fit object - -def test_base_fit2dt(): - - tfit2dt1 = BaseFit2DT(None, None) - assert isinstance(tfit2dt1, BaseFit) - assert isinstance(tfit2dt1, BaseFit2D) - assert isinstance(tfit2dt1, BaseFit2DT) - - tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2dt2, BaseFit2DT) - -def test_base_fit2d_results(tresults): - - tfit2dt = BaseFit2DT(None, None) - - results = [tresults, tresults] - tfit2dt.add_results(results) - tfit2dt.convert_results(None) - - assert tfit2dt.has_model - results_out = tfit2dt.get_results() - assert isinstance(results_out, dict) - -## 3D fit object - -def test_base_fit3d(): - - tfit3d1 = BaseFit3D(None, None) - assert isinstance(tfit3d1, BaseFit) - assert isinstance(tfit3d1, BaseFit2D) - assert isinstance(tfit3d1, BaseFit2DT) - assert isinstance(tfit3d1, BaseFit3D) - - tfit3d2 = BaseFit3D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit3d2, BaseFit3D) - -def test_base_fit3d_results(tresults): - - tfit3d = BaseFit3D(None, None) - - eresults = [[tresults, tresults], [tresults, tresults]] - tfit3d.add_results(eresults) - tfit3d.convert_results(None) - - assert tfit3d.has_model - results_out = tfit3d.get_results() - assert isinstance(results_out, dict) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py new file mode 100644 index 00000000..37dff9a6 --- /dev/null +++ b/specparam/tests/objs/test_results.py @@ -0,0 +1,116 @@ +"""Tests for specparam.objs.results, including the data object and it's methods.""" + +from specparam.core.items import OBJ_DESC +from specparam.data import ModelSettings + +from specparam.objs.results import * + +################################################################################################### +################################################################################################### + +## 1D results object + +def test_base_results(): + + tres1 = BaseResults(None, None) + assert isinstance(tres1, BaseResults) + + tres2 = BaseResults(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2, BaseResults) + +def test_base_results_settings(): + + tres = BaseResults(None, None) + + settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') + tres.add_settings(settings) + for setting in OBJ_DESC['settings']: + assert getattr(tres, setting) == getattr(settings, setting) + + settings_out = tres.get_settings() + assert isinstance(settings, ModelSettings) + assert settings_out == settings + +def test_base_results_results(tresults): + + tres = BaseResults(None, None) + + tres.add_results(tresults) + assert tres.has_model + for result in OBJ_DESC['results']: + assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + + results_out = tres.get_results() + assert isinstance(tresults, FitResults) + assert results_out == tresults + +## 2D results object + +def test_base_results2d(): + + tres2d1 = BaseResults2D(None, None) + assert isinstance(tres2d1, BaseResults) + assert isinstance(tres2d1, BaseResults2D) + + tres2d2 = BaseResults2D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2d2, BaseResults2D) + +def test_base_results2d_results(tresults): + + tres2d = BaseResults2D(None, None) + + results = [tresults, tresults] + tres2d.add_results(results) + assert tres2d.has_model + results_out = tres2d.get_results() + assert isinstance(results_out, list) + assert results_out == results + +## 2DT results object + +def test_base_results2dt(): + + tres2dt1 = BaseResults2DT(None, None) + assert isinstance(tres2dt1, BaseResults) + assert isinstance(tres2dt1, BaseResults2D) + assert isinstance(tres2dt1, BaseResults2DT) + + tres2dt2 = BaseResults2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2dt2, BaseResults2DT) + +def test_base_results2d_results(tresults): + + tres2dt = BaseResults2DT(None, None) + + results = [tresults, tresults] + tres2dt.add_results(results) + tres2dt.convert_results(None) + + assert tres2dt.has_model + results_out = tres2dt.get_results() + assert isinstance(results_out, dict) + +## 3D results object + +def test_base_results3d(): + + tres3d1 = BaseResults3D(None, None) + assert isinstance(tres3d1, BaseResults) + assert isinstance(tres3d1, BaseResults2D) + assert isinstance(tres3d1, BaseResults2DT) + assert isinstance(tres3d1, BaseResults3D) + + tres3d2 = BaseResults3D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres3d2, BaseResults3D) + +def test_base_results3d_results(tresults): + + tres3d = BaseResults3D(None, None) + + eresults = [[tresults, tresults], [tresults, tresults]] + tres3d.add_results(eresults) + tres3d.convert_results(None) + + assert tres3d.has_model + results_out = tres3d.get_results() + assert isinstance(results_out, dict) From 69517112595e78ed9e5e742a29c15e5881607135 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 00:57:18 -0400 Subject: [PATCH 3/3] fix up verboseness & warnings --- specparam/objs/event.py | 9 ++++----- specparam/objs/group.py | 4 ++-- specparam/objs/results.py | 9 ++++++--- specparam/objs/time.py | 15 +++++++++++++-- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 7710c5c1..19c85168 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -68,8 +68,8 @@ def __init__(self, *args, **kwargs): BaseObject3D.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) @@ -222,8 +222,7 @@ def to_df(self, peak_org=None): def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" - # Only check & warn on first spectrogram + # Only check & warn on first spectrum # This is to avoid spamming standard output for every spectrogram in the set - if np.all(self.spectrograms[0] == self.spectrogram): - #if self.power_spectra[0, 0] == self.power_spectrum[0]: + if np.all(self.power_spectrum == self.spectrograms[0, :, 0]): super()._check_width_limits() diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 834024ad..4b633dd1 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -74,8 +74,8 @@ def __init__(self, *args, **kwargs): BaseObject2D.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 4189acf0..f94507c4 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -339,7 +339,8 @@ class BaseResults2D(BaseResults): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_group_results() @@ -615,7 +616,8 @@ class BaseResults2DT(BaseResults2D): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_time_results() @@ -756,7 +758,8 @@ class BaseResults3D(BaseResults2DT): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_event_results() diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 125ac578..4ad99fae 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,5 +1,7 @@ """Time model object and associated code for fitting the model to spectrograms.""" +import numpy as np + from specparam.objs import SpectralModel from specparam.objs.base import BaseObject2DT from specparam.objs.algorithm import SpectralFitAlgorithm @@ -60,8 +62,8 @@ def __init__(self, *args, **kwargs): BaseObject2DT.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) @@ -156,3 +158,12 @@ def to_df(self, peak_org=None): df = dict_to_df(self.get_results()) return df + + + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first power spectrum + # This is to avoid spamming standard output for every spectrum in the group + if np.all(self.power_spectrum == self.spectrogram[:, 0]): + super()._check_width_limits()