diff --git a/src/echosms/__init__.py b/src/echosms/__init__.py index 9e0e83d..490d70f 100644 --- a/src/echosms/__init__.py +++ b/src/echosms/__init__.py @@ -1,5 +1,6 @@ """Setup the public API for echoSMs.""" -from .utils import wavenumber, eta, h1, prolate_swf, spherical_jnpp +from .utils import wavenumber, eta, h1, prolate_swf, spherical_jnpp, split_dict +from .utils import as_dataframe, as_dataarray from .scattermodelbase import ScatterModelBase from .benchmarkdata import BenchmarkData from .referencemodels import ReferenceModels @@ -14,4 +15,5 @@ __all__ = ['ScatterModelBase', 'BenchmarkData', 'ReferenceModels', 'MSSModel', 'PSMSModel', 'DCMModel', 'ESModel', 'PTDWBAModel', 'DWBAModel', 'SDWBAModel', - 'wavenumber', 'eta', 'h1', 'spherical_jnpp', 'prolate_swf'] + 'wavenumber', 'eta', 'h1', 'spherical_jnpp', 'prolate_swf', + 'as_dataframe', 'as_dataarray', 'split_dict'] diff --git a/src/echosms/scattermodelbase.py b/src/echosms/scattermodelbase.py index 8e4c89e..7fa66dd 100644 --- a/src/echosms/scattermodelbase.py +++ b/src/echosms/scattermodelbase.py @@ -1,10 +1,10 @@ """Base class for scatter model classes.""" import abc -from collections.abc import Iterable import pandas as pd import xarray as xr import numpy as np +from .utils import as_dataframe class ScatterModelBase(abc.ABC): @@ -106,7 +106,7 @@ def calculate_ts(self, data, expand=False, inplace=False, multiprocess=False): """ match data: case dict(): - data_df = self.as_dataframe(data) + data_df = as_dataframe(data, self.no_expand_parameters) case pd.DataFrame(): data_df = data case xr.DataArray(): @@ -121,7 +121,7 @@ def calculate_ts(self, data, expand=False, inplace=False, multiprocess=False): p = data_df.attrs['parameters'] if 'parameters' in data_df.attrs else {} # Note: the args argument in the apply call below requires a tuple. data_df.attrs is a - # dict and the usual behaviour is to make a tuple using the dict keys. The trailing comma + # dict and the default behaviour is to make a tuple using the dict keys. The trailing comma # and parenthesis instead causes the tuple to have one entry of the dict. if multiprocess: @@ -157,83 +157,3 @@ def __ts_helper(self, *args): @abc.abstractmethod def calculate_ts_single(self): """Calculate the TS for one parameter set.""" - - def __split_parameters(self, params): - """Split model parameters. - - Splits model parameters into a dict of expandable items and a dict of non-expandable items - - Parameters - ---------- - params : dict - Dict of model parameters. - - Returns - ------- - : tuple(dict, dict) - The input parameter dict split into those parameters that can be expanded (index 0) and - those that cannot (index 1). - """ - nexpand = {k: v for k, v in params.items() if k in self.no_expand_parameters} - expand = {k: v for k, v in params.items() if k not in self.no_expand_parameters} - return expand, nexpand - - def as_dataarray(self, params: dict) -> xr.DataArray: - """Convert model parameters from dict form to a Xarray DataArray. - - Parameters - ---------- - params : - A dictionary containing model parameters. - - Returns - ------- - : - Returns a multi-dimensional DataArray generated from the Cartesian product of all - expandable items in the input dict. Non-expandable items are added to the DataArray - attrs property. Expandable items are those that can be sensibly expandeded into - DataArray coordinates. Not all models have non-expandable items. - The array is named `ts`, the values are initialised to `nan`, the - dimension names are the dict keys, and the coordinate variables are the dict values. - - """ - expand, nexpand = self.__split_parameters(params) - - # Convert scalars to iterables so xarray is happy - for k, v in expand.items(): - if not isinstance(v, Iterable) or isinstance(v, str): - expand[k] = [v] - - sz = [len(v) for k, v in expand.items()] - return xr.DataArray(data=np.full(sz, np.nan), coords=expand, name='ts', - attrs={'units': 'dB', 'dB_reference': '1 m^2', - 'parameters': nexpand}) - - def as_dataframe(self, params: dict) -> pd.DataFrame: - """Convert model parameters from dict form to a Pandas DataFrame. - - Parameters - ---------- - params : - A dictionary containing model parameters. - - Returns - ------- - : - Returns a Pandas DataFrame generated from the Cartesian product of all expandable - items in the input dict. DataFrame column names are obtained from the dict keys. - Non-expandable items are added to the DataFrame attrs property. Expandable items are - those that can be sensibly expandeded into DataFrame columns. Not all models have - non-expandable items. - - """ - expand, nexpand = self.__split_parameters(params) - - # Use meshgrid to do the Cartesian product then create a Pandas DataFrame from that, having - # flattened the multidimensional arrays and using a dict to provide column names. - # This preserves the differing dtypes in each column compared to other ways of - # constructing the DataFrame). - df = pd.DataFrame({k: t.flatten() - for k, t in zip(expand.keys(), np.meshgrid(*tuple(expand.values())))}) - df.attrs = {'parameters': nexpand} - return df diff --git a/src/echosms/utils.py b/src/echosms/utils.py index df880f4..9a03cbd 100644 --- a/src/echosms/utils.py +++ b/src/echosms/utils.py @@ -1,6 +1,8 @@ """Miscellaneous utility functions.""" from collections.abc import Iterable import numpy as np +import xarray as xr +import pandas as pd from scipy.special import spherical_jn, spherical_yn from collections import namedtuple @@ -223,3 +225,96 @@ def prolate_swf(m: int, lnum: int, c: float, xi: float, eta: Iterable[float], S1d = p.s1dc * np.float_power(10.0, p.is1de) return R1, R2, R1d, R2d, S1, S1d, p.naccr, p.naccs + + +def split_dict(d: dict, s: list) -> tuple[dict, dict]: + """Split a dict based on a list of keys. + + Splits model parameters into a dict of expandable items and a dict of non-expandable items + + Parameters + ---------- + d : dict + Dict to be split. + + s: list + List of dict keys to use for splitting `d`. + + Returns + ------- + : tuple(dict, dict) + The `input` dict split into two dicts based on the keys in `s`. The first tuple item + contains the items that do not have keys in `s`. + """ + contains = {k: v for k, v in d.items() if k in s} + ncontains = {k: v for k, v in d.items() if k not in s} + return ncontains, contains + + +def as_dataarray(params: dict, no_expand: list = []) -> xr.DataArray: + """Convert model parameters from dict form to a Xarray DataArray. + + Parameters + ---------- + params : + The model parameters. + + no_expand : + Key values of the non-expandable model parameters in `params`. + + Returns + ------- + : + Returns a multi-dimensional DataArray generated from the Cartesian product of all + expandable items in the input dict. Non-expandable items are added to the DataArray + attrs property. Expandable items are those that can be sensibly expandeded into + DataArray coordinates. Not all models have non-expandable items. + The array is named `ts`, the values are initialised to `nan`, the + dimension names are the dict keys, and the coordinate variables are the dict values. + + """ + expand, nexpand = split_dict(params, no_expand) + + # Convert scalars to iterables so xarray is happy + for k, v in expand.items(): + if not isinstance(v, Iterable) or isinstance(v, str): + expand[k] = [v] + + sz = [len(v) for k, v in expand.items()] + return xr.DataArray(data=np.full(sz, np.nan), coords=expand, name='ts', + attrs={'units': 'dB', 'dB_reference': '1 m^2', + 'parameters': nexpand}) + + +def as_dataframe(params: dict, no_expand: list = []) -> pd.DataFrame: + """Convert model parameters from dict form to a Pandas DataFrame. + + Parameters + ---------- + params : + The model parameters. + + no_expand + --------- + Key values of the non-expandable model parameters in `params`. + + Returns + ------- + : + Returns a Pandas DataFrame generated from the Cartesian product of all expandable + items in the input dict. DataFrame column names are obtained from the dict keys. + Non-expandable items are added to the DataFrame attrs property. Expandable items are + those that can be sensibly expandeded into DataFrame columns. Not all models have + non-expandable items. + + """ + expand, nexpand = split_dict(params, no_expand) + + # Use meshgrid to do the Cartesian product then create a Pandas DataFrame from that, having + # flattened the multidimensional arrays and using a dict to provide column names. + # This preserves the differing dtypes in each column compared to other ways of + # constructing the DataFrame). + df = pd.DataFrame({k: t.flatten() + for k, t in zip(expand.keys(), np.meshgrid(*tuple(expand.values())))}) + df.attrs = {'parameters': nexpand} + return df diff --git a/src/example_code.py b/src/example_code.py index 495dda6..0c2318d 100644 --- a/src/example_code.py +++ b/src/example_code.py @@ -7,6 +7,7 @@ from echosms import MSSModel, PSMSModel, DCMModel, ESModel, PTDWBAModel from echosms import BenchmarkData from echosms import ReferenceModels +from echosms import as_dataframe, as_dataarray # Load the reference model defintiions rm = ReferenceModels() @@ -176,7 +177,7 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title): m['target_rho'] = np.arange(1020, 1030, 1) # [kg/m^3] m['theta'] = [0, 90.0, 180.0] # can convert this to a dataframe -models_df = mss.as_dataframe(m) +models_df = as_dataframe(m, mss.no_expand_parameters) # could also make a DataFrame of parameters that are not just the combination of all input # parameters. This offers a way to specify a more tailored set of model parameters. @@ -212,7 +213,7 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title): mss = MSSModel() # Instead of converting params to a dataframe, an xarray can be used. -params_xa = mss.as_dataarray(params) +params_xa = as_dataarray(params, mss.no_expand_parameters) # how many models runs would that be? print(f'Running {np.prod(params_xa.shape)} models!') @@ -242,6 +243,9 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title): m['rho'] = [m['medium_rho'], m['target_rho']] m['c'] = [m['medium_c'], m['target_c']] m['f'] = bmf['Frequency_kHz']*1e3 +# remove unneeded parameters +m = {k: v for k, v in m.items() + if k not in ['boundary_type', 'a', 'medium_rho', 'medium_c', 'target_rho', 'target_c']} pt = PTDWBAModel() dwba_ts = pt.calculate_ts(m) @@ -278,6 +282,9 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title): m['phi'] = 0 m['rho'] = [m['medium_rho'], m['target_rho']] m['c'] = [m['medium_c'], m['target_c']] +# remove unneeded parameters +m = {k: v for k, v in m.items() + if k not in ['boundary_type', 'a', 'b', 'medium_rho', 'medium_c', 'target_rho', 'target_c']} dwba_ts = [] for f in freqs: