Skip to content

Commit

Permalink
Return from calculate_ts() now matches the input data type
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinmacaulay committed Aug 25, 2024
1 parent 0e1ecfd commit ba0194f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
78 changes: 54 additions & 24 deletions src/echosms/scattermodelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,47 +42,77 @@ def __init__(self):
self.shapes = []
self.max_ka = np.nan

def calculate_ts(self, data, multiprocess=False):
def calculate_ts(self, data, multiprocess=False, result_type=None):
"""Calculate the TS for many parameter sets.
Parameters
----------
data : Pandas DataFrame or Xarray DataArray or dictionary
If a DataFrame, must contain column names as per the function parameters in the
calculate_ts_single() function in this class. Each row in the DataFrame will generate
one TS output. If a DataArray, must contain coordinate names as per the function
parameters in calculate_ts_single(). The TS will be calculated for all combinations of
the coordinate variables. If dictionary, it will be converted to a DataFrame first.
data : Pandas DataFrame, Xarray DataArray or dict
Requirements for the different input data types are:
- **DataFrame**: column names must match the function parameter names in
calculate_ts_single(). One TS value will be calculated for each row in the DataFrame.
- **DataArray**: dimension names must match the function parameter names in
calculate_ts_single(). TS will be calculated for all combinations of the
coordinate variables.
- **dict**: keys must match the function parameters in calculate_ts_single().
TS will be calculated for all combinations of the dict values.
multiprocess : boolean
Split the ts calculation across CPU cores.
result_type : str or None
Only applicable if `data` is a dict:
- `None`: return a list of TS values. This is the default.
- `expand`: return a DataFrame with a column for each key in the dict. The TS values
will be in a column named `ts`.
Returns
-------
: Numpy array
Returns the target strength calculated for all input parameters.
: list, Series, DataFrame, or DataArray
Returns the TS. Variable type is determined by the type of `data`:
- dict input returns a list (or DataFrame if `result_type` is `expand`).
- DataFrame input returns a Series
- DataArray input returns the given DataArray with the values set to the TS
"""
if isinstance(data, dict):
data = as_dataframe(data)
elif isinstance(data, pd.DataFrame):
pass
elif isinstance(data, xr.DataArray):
data = data.to_dataframe().reset_index()
else:
raise ValueError(f'Data type of {type(data)} is not supported'
' (only dictionaries, Pandas DataFrames and Xarray DataArrays are).')
match data:
case dict():
data_df = as_dataframe(data)
case pd.DataFrame():
data_df = data
case xr.DataArray():
data_df = data.to_dataframe().reset_index()
case _:
raise ValueError(f'Data type of {type(data)} is not supported'
' (only dictionaries, Pandas DataFrames and'
' Xarray DataArrays are).')

if multiprocess:
# Using mapply:
# ts = mapply(data, self.__ts_helper, axis=1)
# ts = mapply(data_df, self.__ts_helper, axis=1)
# Using swifter
# ts = df.swifter.apply(self.__ts_helper, axis=1)
ts = data.apply(self.__ts_helper, axis=1)
# ts = data_df.swifter.apply(self.__ts_helper, axis=1)
ts = data_df.apply(self.__ts_helper, axis=1)
else: # this uses just one CPU
ts = data.apply(self.__ts_helper, axis=1)

return ts.to_numpy() # TODO - return data type that matches the input data type
ts = data_df.apply(self.__ts_helper, axis=1)

match data:
case dict():
if result_type == 'expand':
data_df['ts'] = ts
return data_df
else:
return ts.to_list()
case pd.DataFrame():
return ts.rename('ts', inplace=True)
case xr.DataArray():
data.values = ts.to_numpy().reshape(data.shape)
return data
case _:
return ts

def __ts_helper(self, *args):
"""Convert function arguments and call calculate_ts_single()."""
Expand Down
15 changes: 9 additions & 6 deletions src/example_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,13 @@
# parameters. This offers a way to specify a more tailored set of model parameters.

print(f'Running {len(models_df)} models')
# and run
# and run. This will return a Series
ts = mss.calculate_ts(models_df, multiprocess=True)

# And can then add the ts to the params dataframe for ease of selecting and plotting the results
models_df['ts'] = ts

# Alternatively, the ts results can be added to the dataframe that is passed in:
# ts = mss.calculate_ts(models_df, multiprocess=True, result_type='expand')

# plot some of the results
for rho in m['target_rho']:
r = models_df.query('target_rho == @rho and theta==90')
Expand Down Expand Up @@ -209,7 +210,9 @@

# and is called the same way as for the dataframe
if False: # cause it takes a long time to run (as multiprocess is not enabled internally)
ts = mss.calculate_ts(params_xa, multiprocess=True)
# When called with a dataarray, the values in that dataarray are overwritten with the ts, so
# it is not necessary to get the return value (i.e., there is an implicit inplace=True)
mss.calculate_ts(params_xa, multiprocess=True)

# and it can be inserted into params_xa
# TODO once the data is returned in an appropriate form
# Xarray selections and dimenions names can then be used
plt.plot(params_xa.f, params_xa.sel(theta=90, medium_rho=1000, medium_c=1600))

0 comments on commit ba0194f

Please sign in to comment.