Skip to content

Commit

Permalink
Merge branch 'basemodel' into modes
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 10, 2024
2 parents ac1d71b + 6951711 commit 9609cf1
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 154 deletions.
26 changes: 13 additions & 13 deletions specparam/objs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
load_json, load_jsonlines, get_files)
from specparam.core.modutils import copy_doc_func_to_method
from specparam.plts.event import plot_event_model
from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D
from specparam.objs.results import BaseResults, BaseResults2D, BaseResults2DT, BaseResults3D
from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D

###################################################################################################
Expand Down Expand Up @@ -117,15 +117,15 @@ def _add_from_dict(self, data):
setattr(self, key, data[key])


class BaseObject(CommonBase, BaseFit, BaseData):
class BaseObject(CommonBase, BaseResults, BaseData):
"""Define Base object for fitting models to 1D data."""

def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

CommonBase.__init__(self)
BaseData.__init__(self)
BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)
BaseResults.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
Expand Down Expand Up @@ -203,15 +203,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res
self._reset_results(clear_results)


class BaseObject2D(CommonBase, BaseFit2D, BaseData2D):
class BaseObject2D(CommonBase, BaseResults2D, BaseData2D):
"""Define Base object for fitting models to 2D data."""

def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

CommonBase.__init__(self)
BaseData2D.__init__(self)
BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)
BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True):
Expand Down Expand Up @@ -315,15 +315,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False,
self._reset_results(clear_results)


class BaseObject2DT(BaseObject2D, BaseFit2DT, BaseData2DT):
class BaseObject2DT(BaseObject2D, BaseResults2DT, BaseData2DT):
"""Define Base object for fitting models to 2D data - tranpose version."""

def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

BaseObject2D.__init__(self)
BaseData2DT.__init__(self)
BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)
BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def load(self, file_name, file_path=None, peak_org=None):
Expand All @@ -348,15 +348,15 @@ def load(self, file_name, file_path=None, peak_org=None):
self.convert_results(peak_org)


class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D):
class BaseObject3D(BaseObject2DT, BaseResults3D, BaseData3D):
"""Define Base object for fitting models to 3D data."""

def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

BaseObject2DT.__init__(self)
BaseData3D.__init__(self)
BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)
BaseResults3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True):
Expand Down
14 changes: 13 additions & 1 deletion specparam/objs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from specparam.core.errors import DataError, InconsistentDataError
from specparam.data import SpectrumMetaData
from specparam.plts.settings import PLT_COLORS
from specparam.plts.spectra import plot_spectra
from specparam.plts.spectra import plot_spectra, plot_spectrogram
from specparam.plts.utils import check_plot_kwargs

###################################################################################################
Expand Down Expand Up @@ -381,6 +381,12 @@ def add_data(self, freqs, spectrogram, freq_range=None):
super().add_data(freqs, spectrogram, freq_range)


def plot(self, **plt_kwargs):
"""Plot the spectrogram."""

plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs)


class BaseData3D(BaseData2DT):
"""Base object for managing data for spectral parameterization - for 3D data."""

Expand Down Expand Up @@ -440,3 +446,9 @@ def add_data(self, freqs, spectrograms, freq_range=None):
# Otherwise, pass through 2d array to underlying object method
else:
super().add_data(freqs, spectrograms, freq_range)


def plot(self, event_ind):
"""Plot a selected spectrogram."""

plot_spectrogram(self.freqs, self.spectrograms[event_ind, :, :], **plot_kwargs)
11 changes: 5 additions & 6 deletions specparam/objs/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from specparam.objs import SpectralModel
from specparam.objs.base import BaseObject3D
from specparam.objs.algorithm import SpectralFitAlgorithm
from specparam.objs.fit import _progress
from specparam.objs.results import _progress
from specparam.plts.event import plot_event_model
from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df
from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict
Expand Down Expand Up @@ -68,8 +68,8 @@ def __init__(self, *args, **kwargs):
BaseObject3D.__init__(self,
aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'),
periodic_mode=kwargs.pop('periodic_mode', 'gaussian'),
debug_mode=kwargs.pop('debug_mode', 'False'),
verbose=kwargs.pop('verbose', 'True'))
debug_mode=kwargs.pop('debug_mode', False),
verbose=kwargs.pop('verbose', True))

SpectralFitAlgorithm.__init__(self, *args, **kwargs)

Expand Down Expand Up @@ -222,8 +222,7 @@ def to_df(self, peak_org=None):
def _check_width_limits(self):
"""Check and warn about bandwidth limits / frequency resolution interaction."""

# Only check & warn on first spectrogram
# Only check & warn on first spectrum
# This is to avoid spamming standard output for every spectrogram in the set
if np.all(self.spectrograms[0] == self.spectrogram):
#if self.power_spectra[0, 0] == self.power_spectrum[0]:
if np.all(self.power_spectrum == self.spectrograms[0, :, 0]):
super()._check_width_limits()
4 changes: 2 additions & 2 deletions specparam/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __init__(self, *args, **kwargs):
BaseObject2D.__init__(self,
aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'),
periodic_mode=kwargs.pop('periodic_mode', 'gaussian'),
debug_mode=kwargs.pop('debug_mode', 'False'),
verbose=kwargs.pop('verbose', 'True'))
debug_mode=kwargs.pop('debug_mode', False),
verbose=kwargs.pop('verbose', True))

SpectralFitAlgorithm.__init__(self, *args, **kwargs)

Expand Down
25 changes: 14 additions & 11 deletions specparam/objs/fit.py → specparam/objs/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
###################################################################################################
###################################################################################################

class BaseFit():
"""Base object for managing fit procedures."""
class BaseResults():
"""Base object for managing results."""
# pylint: disable=attribute-defined-outside-init, arguments-differ

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False,
Expand Down Expand Up @@ -341,12 +341,13 @@ def _calc_error(self, metric=None):
raise ValueError(error_msg)


class BaseFit2D(BaseFit):
"""Base object for managing fit procedures - 2D version."""
class BaseResults2D(BaseResults):
"""Base object for managing results - 2D version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True)
BaseResults.__init__(self, aperiodic_mode, periodic_mode,
debug_mode=debug_mode, verbose=verbose)

self._reset_group_results()

Expand Down Expand Up @@ -617,12 +618,13 @@ def get_group(self, inds):
return group


class BaseFit2DT(BaseFit2D):
"""Base object for managing fit procedures - 2D transpose version."""
class BaseResults2DT(BaseResults2D):
"""Base object for managing results - 2D transpose version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True)
BaseResults2D.__init__(self, aperiodic_mode, periodic_mode,
debug_mode=debug_mode, verbose=verbose)

self._reset_time_results()

Expand Down Expand Up @@ -758,12 +760,13 @@ def convert_results(self, peak_org):
self.time_results = group_to_dict(self.group_results, peak_org)


class BaseFit3D(BaseFit2DT):
"""Base object for managing fit procedures - 3D version."""
class BaseResults3D(BaseResults2DT):
"""Base object for managing results - 3D version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True)
BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode,
debug_mode=debug_mode, verbose=verbose)

self._reset_event_results()

Expand Down
15 changes: 13 additions & 2 deletions specparam/objs/time.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Time model object and associated code for fitting the model to spectrograms."""

import numpy as np

from specparam.objs import SpectralModel
from specparam.objs.base import BaseObject2DT
from specparam.objs.algorithm import SpectralFitAlgorithm
Expand Down Expand Up @@ -60,8 +62,8 @@ def __init__(self, *args, **kwargs):
BaseObject2DT.__init__(self,
aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'),
periodic_mode=kwargs.pop('periodic_mode', 'gaussian'),
debug_mode=kwargs.pop('debug_mode', 'False'),
verbose=kwargs.pop('verbose', 'True'))
debug_mode=kwargs.pop('debug_mode', False),
verbose=kwargs.pop('verbose', True))

SpectralFitAlgorithm.__init__(self, *args, **kwargs)

Expand Down Expand Up @@ -156,3 +158,12 @@ def to_df(self, peak_org=None):
df = dict_to_df(self.get_results())

return df


def _check_width_limits(self):
"""Check and warn about bandwidth limits / frequency resolution interaction."""

# Only check & warn on first power spectrum
# This is to avoid spamming standard output for every spectrum in the group
if np.all(self.power_spectrum == self.spectrogram[:, 0]):
super()._check_width_limits()
6 changes: 3 additions & 3 deletions specparam/tests/objs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_base2d():
tobj2d = BaseObject2D()
assert isinstance(tobj2d, CommonBase)
assert isinstance(tobj2d, BaseObject2D)
assert isinstance(tobj2d, BaseFit2D)
assert isinstance(tobj2d, BaseResults2D)
assert isinstance(tobj2d, BaseObject2D)

## 2DT Base Object
Expand All @@ -58,7 +58,7 @@ def test_base2dt():
tobj2dt = BaseObject2DT()
assert isinstance(tobj2dt, CommonBase)
assert isinstance(tobj2dt, BaseObject2DT)
assert isinstance(tobj2dt, BaseFit2DT)
assert isinstance(tobj2dt, BaseResults2DT)
assert isinstance(tobj2dt, BaseObject2DT)

## 3D Base Object
Expand All @@ -68,6 +68,6 @@ def test_base3d():
tobj3d = BaseObject3D()
assert isinstance(tobj3d, CommonBase)
assert isinstance(tobj3d, BaseObject2DT)
assert isinstance(tobj3d, BaseFit2DT)
assert isinstance(tobj3d, BaseResults2DT)
assert isinstance(tobj3d, BaseObject2DT)
assert isinstance(tobj3d, BaseObject3D)
116 changes: 0 additions & 116 deletions specparam/tests/objs/test_fit.py

This file was deleted.

Loading

0 comments on commit 9609cf1

Please sign in to comment.