Skip to content

Commit

Permalink
Tidy up parameters and returned data from simulate_ts()
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinmacaulay committed Aug 25, 2024
1 parent b669d8d commit 86b3a6e
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions src/echosms/scattermodelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ class ScatterModelBase(abc.ABC):
A short version of the model's long name, typically an ancronym.
analytical_type : str
Whether the model implements an ``exact`` or an ``approximate`` model.
boundary_types : list of str
boundary_types : list[str]
The types of boundary conditions that the model provides, e.g., 'fixed rigid',
'pressure release', 'fluid filled'
shapes : list of str
shapes : list[str]
The target shapes that the model can represent.
max_ka : float
An approximate maximum ka value that will result in accurate target strength results. Note
Expand All @@ -42,14 +42,14 @@ def __init__(self):
self.shapes = []
self.max_ka = np.nan

def calculate_ts(self, data, multiprocess=False, result_type=None):
"""Calculate the TS for many parameter sets.
def calculate_ts(self, data, expand=None, inplace=False, multiprocess=False):
"""Calculate the TS for many parameters.
Parameters
----------
data : Pandas DataFrame, Xarray DataArray or dict
Requirements for the different input data types are:
- **DataFrame**: column names must match the function parameter names in
calculate_ts_single(). One TS value will be calculated for each row in the DataFrame.
- **DataArray**: dimension names must match the function parameter names in
Expand All @@ -61,21 +61,28 @@ def calculate_ts(self, data, multiprocess=False, result_type=None):
multiprocess : boolean
Split the ts calculation across CPU cores.
result_type : str or None
Only applicable if `data` is a dict:
expand : bool
Only applicable if `data` is a dict. The default is `False`. If `True`, will expand
the dict into a DataFrame containing the Cartesian product of all values in the dict
(with one column per key in the dict) and return that DataFrame, adding a column
named `ts` for the TS results.
- `None`: return a list of TS values. This is the default.
- `expand`: return a DataFrame with a column for each key in the dict. The TS values
will be in a column named `ts`.
inplace : bool
Only applicable if `data` is a DataFrame. Default is `False`. If `True`, the TS results
will be added to the input DataFrame in a column named `ts`. It a `ts` column exists,
it is overwritten.
Returns
-------
: list, Series, DataFrame, or DataArray
Returns the TS. Variable type is determined by the type of `data`:
: None, list[float], Series, DataFrame
The return type and value is determined by the type of the input variable `data` and
the `expand` and `inplace` input parameters:
- dict input returns a list or DataFrame (if `result_type` is `expand`).
- DataFrame input returns a Series
- DataArray input returns the given DataArray with the values set to the TS
- dict input and `expand=False` returns a list of floats.
- dict input and `expand=True` returns a DataFrame.
- DataFrame input and `inplace=False` returns a Series.
- DataFrame input and `inplace=True` modifies `data` and returns `None`.
- DataArray input always modifies `data` and returns `None`.
"""
match data:
Expand All @@ -100,19 +107,22 @@ def calculate_ts(self, data, multiprocess=False, result_type=None):
ts = data_df.apply(self.__ts_helper, axis=1)

match data:
case dict() if expand:
data_df['ts'] = ts
return data_df
case dict():
if result_type == 'expand':
data_df['ts'] = ts
return data_df
else:
return ts.to_list()
return ts.to_list()
case pd.DataFrame() if inplace:
data_df['ts'] = ts
return
case pd.DataFrame():
return ts.rename('ts', inplace=True)
case xr.DataArray():
data.values = ts.to_numpy().reshape(data.shape)
return data
return
case _:
return ts
raise AssertionError('This code should never be reached - unsupported input data '
f'type of {type(data)}.')

def __ts_helper(self, *args):
"""Convert function arguments and call calculate_ts_single()."""
Expand Down

0 comments on commit 86b3a6e

Please sign in to comment.