From ba0194f4be242240d3e91f8d795037b7dc859f0a Mon Sep 17 00:00:00 2001 From: Gavin Macaulay Date: Sun, 25 Aug 2024 16:03:25 +1200 Subject: [PATCH] Return from calculate_ts() now matches the input data type --- src/echosms/scattermodelbase.py | 78 +++++++++++++++++++++++---------- src/example_code.py | 15 ++++--- 2 files changed, 63 insertions(+), 30 deletions(-) diff --git a/src/echosms/scattermodelbase.py b/src/echosms/scattermodelbase.py index 3f78503..cf54faa 100644 --- a/src/echosms/scattermodelbase.py +++ b/src/echosms/scattermodelbase.py @@ -42,47 +42,77 @@ def __init__(self): self.shapes = [] self.max_ka = np.nan - def calculate_ts(self, data, multiprocess=False): + def calculate_ts(self, data, multiprocess=False, result_type=None): """Calculate the TS for many parameter sets. Parameters ---------- - data : Pandas DataFrame or Xarray DataArray or dictionary - If a DataFrame, must contain column names as per the function parameters in the - calculate_ts_single() function in this class. Each row in the DataFrame will generate - one TS output. If a DataArray, must contain coordinate names as per the function - parameters in calculate_ts_single(). The TS will be calculated for all combinations of - the coordinate variables. If dictionary, it will be converted to a DataFrame first. + data : Pandas DataFrame, Xarray DataArray or dict + Requirements for the different input data types are: + + - **DataFrame**: column names must match the function parameter names in + calculate_ts_single(). One TS value will be calculated for each row in the DataFrame. + - **DataArray**: dimension names must match the function parameter names in + calculate_ts_single(). TS will be calculated for all combinations of the + coordinate variables. + - **dict**: keys must match the function parameters in calculate_ts_single(). + TS will be calculated for all combinations of the dict values. multiprocess : boolean Split the ts calculation across CPU cores. 
+ result_type : str or None + Only applicable if `data` is a dict: + + - `None`: return a list of TS values. This is the default. + - `expand`: return a DataFrame with a column for each key in the dict. The TS values + will be in a column named `ts`. + Returns ------- - : Numpy array - Returns the target strength calculated for all input parameters. + : list, Series, DataFrame, or DataArray + Returns the TS. Variable type is determined by the type of `data`: + + - dict input returns a list (or DataFrame if `result_type` is `expand`). + - DataFrame input returns a Series + - DataArray input returns the given DataArray with the values set to the TS """ - if isinstance(data, dict): - data = as_dataframe(data) - elif isinstance(data, pd.DataFrame): - pass - elif isinstance(data, xr.DataArray): - data = data.to_dataframe().reset_index() - else: - raise ValueError(f'Data type of {type(data)} is not supported' - ' (only dictionaries, Pandas DataFrames and Xarray DataArrays are).') + match data: + case dict(): + data_df = as_dataframe(data) + case pd.DataFrame(): + data_df = data + case xr.DataArray(): + data_df = data.to_dataframe().reset_index() + case _: + raise ValueError(f'Data type of {type(data)} is not supported' + ' (only dictionaries, Pandas DataFrames and' + ' Xarray DataArrays are).') if multiprocess: # Using mapply: - # ts = mapply(data, self.__ts_helper, axis=1) + # ts = mapply(data_df, self.__ts_helper, axis=1) # Using swifter - # ts = df.swifter.apply(self.__ts_helper, axis=1) - ts = data.apply(self.__ts_helper, axis=1) + # ts = data_df.swifter.apply(self.__ts_helper, axis=1) + ts = data_df.apply(self.__ts_helper, axis=1) else: # this uses just one CPU - ts = data.apply(self.__ts_helper, axis=1) - - return ts.to_numpy() # TODO - return data type that matches the input data type + ts = data_df.apply(self.__ts_helper, axis=1) + + match data: + case dict(): + if result_type == 'expand': + data_df['ts'] = ts + return data_df + else: + return ts.to_list() + case 
pd.DataFrame(): + return ts.rename('ts', inplace=True) + case xr.DataArray(): + data.values = ts.to_numpy().reshape(data.shape) + return data + case _: + return ts def __ts_helper(self, *args): """Convert function arguments and call calculate_ts_single().""" diff --git a/src/example_code.py b/src/example_code.py index 6201b25..bb6e651 100644 --- a/src/example_code.py +++ b/src/example_code.py @@ -174,12 +174,13 @@ # parameters. This offers a way to specify a more tailored set of model parameters. print(f'Running {len(models_df)} models') -# and run +# and run. This will return a Series ts = mss.calculate_ts(models_df, multiprocess=True) - -# And can then add the ts to the params dataframe for ease of selecting and plotting the results models_df['ts'] = ts +# Alternatively, the ts results can be added to the dataframe that is passed in: +# ts = mss.calculate_ts(models_df, multiprocess=True, result_type='expand') + # plot some of the results for rho in m['target_rho']: r = models_df.query('target_rho == @rho and theta==90') @@ -209,7 +210,9 @@ # and is called the same way as for the dataframe if False: # cause it takes a long time to run (as multiprocess is not enabled internally) - ts = mss.calculate_ts(params_xa, multiprocess=True) + # When called with a dataarray, the values in that dataarray are overwritten with the ts, so + # it is not necessary to get the return value (i.e., there is an implicit inplace=True) + mss.calculate_ts(params_xa, multiprocess=True) -# and it can be inserted into params_xa -# TODO once the data is returned in an appropriate form +# Xarray selections and dimension names can then be used +plt.plot(params_xa.f, params_xa.sel(theta=90, medium_rho=1000, medium_c=1600))