Skip to content

Commit

Permalink
Add initial prototype for group average bounds code
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Dec 6, 2024
1 parent 6d257fc commit ff2030d
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def _averager(
if self._mode == "average":
dv_avg = self._average(ds, data_var)
elif self._mode in ["group_average", "climatology", "departures"]:
dv_avg = self._group_average(ds, data_var)
dv_avg, time_bnds = self._group_average(ds, data_var)

# The original time dimension is dropped from the dataset because
# it becomes obsolete after the data variable is averaged. When the
Expand All @@ -885,8 +885,10 @@ def _averager(
ds = ds.drop_dims(self.dim)
ds[dv_avg.name] = dv_avg

if self._mode == "group_average":
ds = ds.bounds.add_missing_bounds(axes="T")
if self._mode in ["group_average", "climatology", "departures"]:
ds[time_bnds.name] = time_bnds
# FIXME: This is not working when time bounds are datetime and
# time is cftime.
ds = center_times(ds)

if keep_weights:
Expand Down Expand Up @@ -1479,7 +1481,9 @@ def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

return dv

def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
def _group_average(
self, ds: xr.Dataset, data_var: str
) -> Tuple[xr.DataArray, xr.DataArray]:
"""Averages a data variable by time group.
Parameters
Expand All @@ -1491,7 +1495,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
Returns
-------
xr.DataArray
Tuple[xr.DataArray, xr.DataArray]
The data variable averaged by time group.
"""
dv = _get_data_var(ds, data_var)
Expand All @@ -1500,9 +1504,9 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
# values.
self._labeled_time = self._label_time_coords(dv[self.dim])
dv = dv.assign_coords({self.dim: self._labeled_time})
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)

if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

# Weight the data variable.
Expand All @@ -1526,6 +1530,25 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
else:
dv = self._group_data(dv).mean()

"""I think we'll need to collect the bounds for each group, (e.g., group_bounds_array = [("2000-01-01 00:00", "2000-01-02 00:00"), ("2000-01-02 00:00", "2000-01-03 00:00"), ..., ("2000-01-31 00:00", "2000-02-01 00:00")] and then take the min of the lower bound and the max of the upper bound (i.e., group_bnd = [np.min(groups_bound_array[:, 0]), np.max(group_bounds_array[:, 1])].
"""
# Create time bounds for each group
time_bounds_grouped = self._group_data(time_bounds)
group_bounds = []

for _, group_data in time_bounds_grouped:
group_times = group_data.values
group_bnds = (np.min(group_times[:, 0]), np.max(group_times[:, 1]))
group_bounds.append(group_bnds)

# Convert group bounds to DataArray
da_bnds = xr.DataArray(
data=np.array(group_bounds),
dims=[self.dim, "bnds"],
coords={self.dim: dv[self.dim].values},
name=f"{self.dim}_bnds",
)

# After grouping and aggregating, the grouped time dimension's
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
# attributes for data variables and not their coordinates, so the
Expand All @@ -1535,7 +1558,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

dv = self._add_operation_attrs(dv)

return dv
return dv, da_bnds

def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
"""Calculates weights for a data variable using time bounds.
Expand Down

0 comments on commit ff2030d

Please sign in to comment.