Skip to content

Commit

Permalink
move as_dataframe and as_dataarray back to utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinmacaulay committed Sep 10, 2024
1 parent 863b041 commit b6ee027
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 87 deletions.
6 changes: 4 additions & 2 deletions src/echosms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Setup the public API for echoSMs."""
from .utils import wavenumber, eta, h1, prolate_swf, spherical_jnpp
from .utils import wavenumber, eta, h1, prolate_swf, spherical_jnpp, split_dict
from .utils import as_dataframe, as_dataarray
from .scattermodelbase import ScatterModelBase
from .benchmarkdata import BenchmarkData
from .referencemodels import ReferenceModels
Expand All @@ -14,4 +15,5 @@
__all__ = ['ScatterModelBase', 'BenchmarkData', 'ReferenceModels',
'MSSModel', 'PSMSModel', 'DCMModel', 'ESModel', 'PTDWBAModel',
'DWBAModel', 'SDWBAModel',
'wavenumber', 'eta', 'h1', 'spherical_jnpp', 'prolate_swf']
'wavenumber', 'eta', 'h1', 'spherical_jnpp', 'prolate_swf',
'as_dataframe', 'as_dataarray', 'split_dict']
86 changes: 3 additions & 83 deletions src/echosms/scattermodelbase.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Base class for scatter model classes."""

import abc
from collections.abc import Iterable
import pandas as pd
import xarray as xr
import numpy as np
from .utils import as_dataframe


class ScatterModelBase(abc.ABC):
Expand Down Expand Up @@ -106,7 +106,7 @@ def calculate_ts(self, data, expand=False, inplace=False, multiprocess=False):
"""
match data:
case dict():
data_df = self.as_dataframe(data)
data_df = as_dataframe(data, self.no_expand_parameters)
case pd.DataFrame():
data_df = data
case xr.DataArray():
Expand All @@ -121,7 +121,7 @@ def calculate_ts(self, data, expand=False, inplace=False, multiprocess=False):
p = data_df.attrs['parameters'] if 'parameters' in data_df.attrs else {}

# Note: the args argument in the apply call below requires a tuple. data_df.attrs is a
# dict and the usual behaviour is to make a tuple using the dict keys. The trailing comma
# dict and the default behaviour is to make a tuple using the dict keys. The trailing comma
# and parenthesis instead causes the tuple to have one entry of the dict.

if multiprocess:
Expand Down Expand Up @@ -157,83 +157,3 @@ def __ts_helper(self, *args):
@abc.abstractmethod
def calculate_ts_single(self):
"""Calculate the TS for one parameter set."""

def __split_parameters(self, params):
"""Split model parameters.
Splits model parameters into a dict of expandable items and a dict of non-expandable items
Parameters
----------
params : dict
Dict of model parameters.
Returns
-------
: tuple(dict, dict)
The input parameter dict split into those parameters that can be expanded (index 0) and
those that cannot (index 1).
"""
nexpand = {k: v for k, v in params.items() if k in self.no_expand_parameters}
expand = {k: v for k, v in params.items() if k not in self.no_expand_parameters}
return expand, nexpand

def as_dataarray(self, params: dict) -> xr.DataArray:
"""Convert model parameters from dict form to a Xarray DataArray.
Parameters
----------
params :
A dictionary containing model parameters.
Returns
-------
:
Returns a multi-dimensional DataArray generated from the Cartesian product of all
expandable items in the input dict. Non-expandable items are added to the DataArray
attrs property. Expandable items are those that can be sensibly expandeded into
DataArray coordinates. Not all models have non-expandable items.
The array is named `ts`, the values are initialised to `nan`, the
dimension names are the dict keys, and the coordinate variables are the dict values.
"""
expand, nexpand = self.__split_parameters(params)

# Convert scalars to iterables so xarray is happy
for k, v in expand.items():
if not isinstance(v, Iterable) or isinstance(v, str):
expand[k] = [v]

sz = [len(v) for k, v in expand.items()]
return xr.DataArray(data=np.full(sz, np.nan), coords=expand, name='ts',
attrs={'units': 'dB', 'dB_reference': '1 m^2',
'parameters': nexpand})

def as_dataframe(self, params: dict) -> pd.DataFrame:
"""Convert model parameters from dict form to a Pandas DataFrame.
Parameters
----------
params :
A dictionary containing model parameters.
Returns
-------
:
Returns a Pandas DataFrame generated from the Cartesian product of all expandable
items in the input dict. DataFrame column names are obtained from the dict keys.
Non-expandable items are added to the DataFrame attrs property. Expandable items are
those that can be sensibly expandeded into DataFrame columns. Not all models have
non-expandable items.
"""
expand, nexpand = self.__split_parameters(params)

# Use meshgrid to do the Cartesian product then create a Pandas DataFrame from that, having
# flattened the multidimensional arrays and using a dict to provide column names.
# This preserves the differing dtypes in each column compared to other ways of
# constructing the DataFrame).
df = pd.DataFrame({k: t.flatten()
for k, t in zip(expand.keys(), np.meshgrid(*tuple(expand.values())))})
df.attrs = {'parameters': nexpand}
return df
95 changes: 95 additions & 0 deletions src/echosms/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Miscellaneous utility functions."""
from collections.abc import Iterable
import numpy as np
import xarray as xr
import pandas as pd
from scipy.special import spherical_jn, spherical_yn
from collections import namedtuple

Expand Down Expand Up @@ -223,3 +225,96 @@ def prolate_swf(m: int, lnum: int, c: float, xi: float, eta: Iterable[float],
S1d = p.s1dc * np.float_power(10.0, p.is1de)

return R1, R2, R1d, R2d, S1, S1d, p.naccr, p.naccs


def split_dict(d: dict, s: list) -> tuple[dict, dict]:
"""Split a dict based on a list of keys.
Splits model parameters into a dict of expandable items and a dict of non-expandable items
Parameters
----------
d : dict
Dict to be split.
s: list
List of dict keys to use for splitting `d`.
Returns
-------
: tuple(dict, dict)
The `input` dict split into two dicts based on the keys in `s`. The first tuple item
contains the items that do not have keys in `s`.
"""
contains = {k: v for k, v in d.items() if k in s}
ncontains = {k: v for k, v in d.items() if k not in s}
return ncontains, contains


def as_dataarray(params: dict, no_expand: list = []) -> xr.DataArray:
"""Convert model parameters from dict form to a Xarray DataArray.
Parameters
----------
params :
The model parameters.
no_expand :
Key values of the non-expandable model parameters in `params`.
Returns
-------
:
Returns a multi-dimensional DataArray generated from the Cartesian product of all
expandable items in the input dict. Non-expandable items are added to the DataArray
attrs property. Expandable items are those that can be sensibly expandeded into
DataArray coordinates. Not all models have non-expandable items.
The array is named `ts`, the values are initialised to `nan`, the
dimension names are the dict keys, and the coordinate variables are the dict values.
"""
expand, nexpand = split_dict(params, no_expand)

# Convert scalars to iterables so xarray is happy
for k, v in expand.items():
if not isinstance(v, Iterable) or isinstance(v, str):
expand[k] = [v]

sz = [len(v) for k, v in expand.items()]
return xr.DataArray(data=np.full(sz, np.nan), coords=expand, name='ts',
attrs={'units': 'dB', 'dB_reference': '1 m^2',
'parameters': nexpand})


def as_dataframe(params: dict, no_expand: list = []) -> pd.DataFrame:
"""Convert model parameters from dict form to a Pandas DataFrame.
Parameters
----------
params :
The model parameters.
no_expand
---------
Key values of the non-expandable model parameters in `params`.
Returns
-------
:
Returns a Pandas DataFrame generated from the Cartesian product of all expandable
items in the input dict. DataFrame column names are obtained from the dict keys.
Non-expandable items are added to the DataFrame attrs property. Expandable items are
those that can be sensibly expandeded into DataFrame columns. Not all models have
non-expandable items.
"""
expand, nexpand = split_dict(params, no_expand)

# Use meshgrid to do the Cartesian product then create a Pandas DataFrame from that, having
# flattened the multidimensional arrays and using a dict to provide column names.
# This preserves the differing dtypes in each column compared to other ways of
# constructing the DataFrame).
df = pd.DataFrame({k: t.flatten()
for k, t in zip(expand.keys(), np.meshgrid(*tuple(expand.values())))})
df.attrs = {'parameters': nexpand}
return df
11 changes: 9 additions & 2 deletions src/example_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from echosms import MSSModel, PSMSModel, DCMModel, ESModel, PTDWBAModel
from echosms import BenchmarkData
from echosms import ReferenceModels
from echosms import as_dataframe, as_dataarray

# Load the reference model defintiions
rm = ReferenceModels()
Expand Down Expand Up @@ -176,7 +177,7 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title):
m['target_rho'] = np.arange(1020, 1030, 1) # [kg/m^3]
m['theta'] = [0, 90.0, 180.0]
# can convert this to a dataframe
models_df = mss.as_dataframe(m)
models_df = as_dataframe(m, mss.no_expand_parameters)
# could also make a DataFrame of parameters that are not just the combination of all input
# parameters. This offers a way to specify a more tailored set of model parameters.

Expand Down Expand Up @@ -212,7 +213,7 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title):
mss = MSSModel()

# Instead of converting params to a dataframe, an xarray can be used.
params_xa = mss.as_dataarray(params)
params_xa = as_dataarray(params, mss.no_expand_parameters)

# how many models runs would that be?
print(f'Running {np.prod(params_xa.shape)} models!')
Expand Down Expand Up @@ -242,6 +243,9 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title):
m['rho'] = [m['medium_rho'], m['target_rho']]
m['c'] = [m['medium_c'], m['target_c']]
m['f'] = bmf['Frequency_kHz']*1e3
# remove unneeded parameters
m = {k: v for k, v in m.items()
if k not in ['boundary_type', 'a', 'medium_rho', 'medium_c', 'target_rho', 'target_c']}

pt = PTDWBAModel()
dwba_ts = pt.calculate_ts(m)
Expand Down Expand Up @@ -278,6 +282,9 @@ def plot_compare(f1, ts1, label1, f2, ts2, label2, title):
m['phi'] = 0
m['rho'] = [m['medium_rho'], m['target_rho']]
m['c'] = [m['medium_c'], m['target_c']]
# remove unneeded parameters
m = {k: v for k, v in m.items()
if k not in ['boundary_type', 'a', 'b', 'medium_rho', 'medium_c', 'target_rho', 'target_c']}

dwba_ts = []
for f in freqs:
Expand Down

0 comments on commit b6ee027

Please sign in to comment.