diff --git a/xcdat/temporal.py b/xcdat/temporal.py index a5611e42..a0daf9a4 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -757,21 +757,17 @@ def _averager( # Preprocess the dataset based on method argument values. ds = self._preprocess_dataset(ds) - # Get the data variable and the required time axis metadata. - dv = _get_data_var(ds, data_var) - time_bounds = ds.bounds.get_bounds("T", var_key=dv.name) - if self._mode == "average": - dv = self._average(dv, time_bounds) + dv_avg = self._average(ds, data_var) elif self._mode in ["group_average", "climatology", "departures"]: - dv = self._group_average(dv, time_bounds) + dv_avg = self._group_average(ds, data_var) # The original time dimension is dropped from the dataset because # it becomes obsolete after the data variable is averaged. When the # averaged data variable is added to the dataset, the new time dimension # and its associated coordinates are also added. ds = ds.drop_dims(self.dim) # type: ignore - ds[dv.name] = dv + ds[dv_avg.name] = dv_avg if keep_weights: ds = self._keep_weights(ds) @@ -1075,28 +1071,28 @@ def _drop_leap_days(self, ds: xr.Dataset): ) return ds - def _average( - self, data_var: xr.DataArray, time_bounds: xr.DataArray - ) -> xr.DataArray: + def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: """Averages a data variable with the time dimension removed. Parameters ---------- - data_var : xr.DataArray - The data variable. - time_bounds : xr.DataArray - The time bounds. + ds : xr.Dataset + The dataset. + data_var : str + The key of the data variable. Returns ------- xr.DataArray - The averages for a data variable with the time dimension removed. + The data variable averaged with the time dimension removed. """ - dv = data_var.copy() + dv = _get_data_var(ds, data_var) with xr.set_options(keep_attrs=True): if self._weighted: + time_bounds = ds.bounds.get_bounds("T", var_key=data_var) self._weights = self._get_weights(time_bounds) + dv = dv.weighted(self._weights).mean(dim=self.dim) # type: ignore else: dv = dv.mean(dim=self.dim) # type: ignore @@ -1105,31 +1101,31 @@ def _average( return dv - def _group_average( - self, data_var: xr.DataArray, time_bounds: xr.DataArray - ) -> xr.DataArray: + def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: """Averages a data variable by time group. Parameters ---------- - data_var : xr.DataArray - The data variable. - time_bounds : xr.DataArray - The time bounds. + ds : xr.Dataset + The dataset. + data_var : str + The key of the data variable. Returns ------- xr.DataArray The data variable averaged by time group. """ - dv = data_var.copy() + dv = _get_data_var(ds, data_var) # Label the time coordinates for grouping weights and the data variable # values. self._labeled_time = self._label_time_coords(dv[self.dim]) if self._weighted: + time_bounds = ds.bounds.get_bounds("T", var_key=data_var) self._weights = self._get_weights(time_bounds) + # Weight the data variable. dv *= self._weights @@ -1145,8 +1141,9 @@ def _group_average( # included to take into account zero weight for missing data. with xr.set_options(keep_attrs=True): dv = self._group_data(dv).sum() / self._group_data(weights).sum() + # Restore the data variable's name. - dv.name = data_var.name + dv.name = data_var else: dv = self._group_data(dv).mean()