Skip to content

Commit

Permalink
rework time obj to use new obj org - move methods
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 8, 2024
1 parent 65ec2b8 commit fa6298a
Showing 1 changed file with 11 additions and 214 deletions.
225 changes: 11 additions & 214 deletions specparam/objs/time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Time model object and associated code for fitting the model to spectrograms."""

from functools import wraps

import numpy as np

from specparam.objs import SpectralModel, SpectralGroupModel
Expand All @@ -14,29 +12,15 @@
replace_docstring_sections)
from specparam.core.strings import gen_time_results_str

from specparam.objs.base import BaseObject2DT
from specparam.objs.algorithm import SpectralFitAlgorithm

###################################################################################################
###################################################################################################

def transpose_arg1(func):
"""Decorator function to transpose the 1th argument input to a function."""

@wraps(func)
def decorated(*args, **kwargs):

if len(args) >= 2:
args = list(args)
args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2]
if 'spectrogram' in kwargs:
kwargs['spectrogram'] = kwargs['spectrogram'].T

return func(*args, **kwargs)

return decorated


@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'),
docs_get_section(SpectralModel.__doc__, 'Notes')])
class SpectralTimeModel(SpectralGroupModel):
class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT):
"""Model a spectrogram as a combination of aperiodic and periodic components.
WARNING: frequency and power values inputs must be in linear space.
Expand Down Expand Up @@ -78,67 +62,15 @@ class SpectralTimeModel(SpectralGroupModel):
def __init__(self, *args, **kwargs):
"""Initialize object with desired settings."""

SpectralGroupModel.__init__(self, *args, **kwargs)

self._reset_time_results()


def __getitem__(self, ind):
"""Allow for indexing into the object to select fit results for a specific time window."""

return get_results_by_ind(self.time_results, ind)


@property
def n_peaks_(self):
"""How many peaks were fit for each model."""

return [res.peak_params.shape[0] for res in self.group_results] \
if self.has_model else None


@property
def n_time_windows(self):
"""How many time windows are included in the model object."""

return self.spectrogram.shape[1] if self.has_data else 0


def _reset_time_results(self):
"""Set, or reset, time results to be empty."""

self.time_results = {}


@property
def spectrogram(self):
"""Data attribute view on the power spectra, transposed to spectrogram orientation."""

return self.power_spectra.T


@transpose_arg1
def add_data(self, freqs, spectrogram, freq_range=None):
"""Add data (frequencies and spectrogram values) to the current object.
Parameters
----------
freqs : 1d array
Frequency values for the spectrogram, in linear space.
spectrogram : 2d array, shape=[n_freqs, n_time_windows]
Matrix of power values, in linear space.
freq_range : list of [float, float], optional
Frequency range to restrict spectrogram to. If not provided, keeps the entire range.
BaseObject2DT.__init__(self,
aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'),
periodic_mode=kwargs.pop('periodic_mode', 'gaussian'),
debug_mode=kwargs.pop('debug_mode', 'False'),
verbose=kwargs.pop('verbose', 'True'))

Notes
-----
If called on an object with existing data and/or results
these will be cleared by this method call.
"""
SpectralFitAlgorithm.__init__(self, *args, **kwargs)

if np.any(self.freqs):
self._reset_time_results()
super().add_data(freqs, spectrogram, freq_range)
self._reset_time_results()


def report(self, freqs=None, spectrogram=None, freq_range=None,
Expand Down Expand Up @@ -173,105 +105,6 @@ def report(self, freqs=None, spectrogram=None, freq_range=None,
self.print_results(report_type)


def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None,
n_jobs=1, progress=None):
"""Fit a spectrogram.
Parameters
----------
freqs : 1d array, optional
Frequency values for the spectrogram, in linear space.
spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional
Spectrogram of power spectrum values, in linear space.
freq_range : list of [float, float], optional
Frequency range to fit the model to. If not provided, fits the entire given range.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
n_jobs : int, optional, default: 1
Number of jobs to run in parallel.
1 is no parallelization. -1 uses all available cores.
progress : {None, 'tqdm', 'tqdm.notebook'}, optional
Which kind of progress bar to use. If None, no progress bar is used.
Notes
-----
Data is optional, if data has already been added to the object.
"""

super().fit(freqs, spectrogram, freq_range, n_jobs, progress)
if peak_org is not False:
self.convert_results(peak_org)


def drop(self, inds):
"""Drop one or more model fit results from the object.
Parameters
----------
inds : int or array_like of int or array_like of bool
Indices to drop model fit results for.
Notes
-----
This method sets the model fits as null, and preserves the shape of the model fits.
"""

super().drop(inds)
for key in self.time_results.keys():
self.time_results[key][inds] = np.nan


def get_results(self):
"""Return the results run across a spectrogram."""

return self.time_results


def get_group(self, inds, output_type='time'):
"""Get a new model object with the specified sub-selection of model fits.
Parameters
----------
inds : array_like of int or array_like of bool
Indices to extract from the object.
output_type : {'time', 'group'}, optional
Type of model object to extract:
'time' : SpectralTimeObject
'group' : SpectralGroupObject
Returns
-------
output : SpectralTimeModel or SpectralGroupModel
The requested selection of results data loaded into a new model object.
"""

if output_type == 'time':

# Initialize a new model object, with same settings as current object
output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose)
output.add_meta_data(self.get_meta_data())

if inds is not None:

# Check and convert indices encoding to list of int
inds = check_inds(inds)

# Add data for specified power spectra, if available
if self.has_data:
output.power_spectra = self.power_spectra[inds, :]

# Add results for specified power spectra
output.group_results = [self.group_results[ind] for ind in inds]
output.time_results = get_results_by_ind(self.time_results, inds)

if output_type == 'group':
output = super().get_group(inds)

return output


def print_results(self, print_type='time', concise=False):
"""Print out SpectralTimeModel results.
Expand Down Expand Up @@ -305,28 +138,6 @@ def save_report(self, file_name, file_path=None, add_settings=True):
save_time_report(self, file_name, file_path, add_settings)


def load(self, file_name, file_path=None, peak_org=None):
"""Load time data from file.
Parameters
----------
file_name : str
File to load data from.
file_path : str, optional
Path to directory to load from. If None, loads from current directory.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

# Clear results so as not to have possible prior results interfere
self._reset_time_results()
super().load(file_name, file_path=file_path)
if peak_org is not False and self.group_results:
self.convert_results(peak_org)


def to_df(self, peak_org=None):
"""Convert and extract the model results as a pandas object.
Expand All @@ -350,17 +161,3 @@ def to_df(self, peak_org=None):
df = dict_to_df(self.get_results())

return df


def convert_results(self, peak_org):
"""Convert the model results to be organized across time windows.
Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

self.time_results = group_to_dict(self.group_results, peak_org)

0 comments on commit fa6298a

Please sign in to comment.