Skip to content

Commit

Permalink
Clean up logic in various private methods
Browse files Browse the repository at this point in the history
- Methods include `_subset_coords_for_custom_seasons()` and `_shift_custom_season_years()`
  • Loading branch information
tomvothecoder committed Nov 20, 2024
1 parent 8d156c2 commit f2648ff
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,11 +1099,12 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
"""Preprocess the dataset based on averaging settings.
Operations include:
1. Drop leap days for daily climatologies.
2. Subset the dataset based on the reference period.
3. Shift years for custom seasons spanning the calendar year.
4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons.
5. Drop incomplete seasons if specified.
1. Drop leap days for daily climatologies.
2. Subset the dataset based on the reference period.
3. Shift years for custom seasons spanning the calendar year.
4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons,
if specified.
5. Drop incomplete seasons if specified.
Parameters
----------
Expand Down Expand Up @@ -1141,6 +1142,9 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
if len(months) != 12:
ds = self._subset_coords_for_custom_seasons(ds, months)

# The years for time coordinates needs to be shifted by 1 for months
# that span the calendar because Xarray groups seasons by months
# in the same year, rather than the previous year.
ds = self._shift_custom_season_years(ds)

if self._freq == "season" and self._season_config.get("dec_mode") == "DJF":
Expand Down Expand Up @@ -1180,16 +1184,8 @@ def _subset_coords_for_custom_seasons(
The dataset with time coordinate subsetted to months used in
custom seasons.
"""
month_ints = sorted([MONTH_STR_TO_INT[month] for month in months])

coords_by_month = ds[self.dim].groupby(f"{self.dim}.month").groups
month_to_time_idx = {
k: coords_by_month[k] for k in month_ints if k in coords_by_month
}
month_to_time_idx = sorted(
list(chain.from_iterable(month_to_time_idx.values())) # type: ignore
)
ds_new = ds.isel({f"{self.dim}": month_to_time_idx})
month_ints = [MONTH_STR_TO_INT[month] for month in months]
ds_new = ds.sel({self.dim: ds[self.dim].dt.month.isin(month_ints)})

return ds_new

Expand Down Expand Up @@ -1231,34 +1227,31 @@ def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset:
ds_new = ds.copy()
custom_seasons = self._season_config["custom_seasons"]

# Identify months that span across years in custom seasons by getting
# the months before "Jan" if "Jan" is not the first month of the season.
# Note: Only one custom season can span the calendar year.
span_months: List[int] = []

# Identify the months that span across years in custom seasons.
# This is done by checking if "Jan" is not the first month in the
# custom season and getting all months before "Jan".
for months in custom_seasons.values(): # type: ignore
month_nums = [MONTH_STR_TO_INT[month] for month in months]
if 1 in month_nums:
jan_index = month_nums.index(1)
month_ints = [MONTH_STR_TO_INT[month] for month in months]

if jan_index != 0:
span_months.extend(month_nums[:jan_index])
if 1 in month_ints and month_ints.index(1) != 0:
span_months.extend(month_ints[: month_ints.index(1)])
break

if span_months:
time_coords = ds_new[self.dim].copy()
idxs = np.where(time_coords.dt.month.isin(span_months))[0]
indexes = time_coords.dt.month.isin(span_months)

if isinstance(time_coords.values[0], cftime.datetime):
for idx in idxs:
time_coords.values[idx] = time_coords.values[idx].replace(
year=time_coords.values[idx].year + 1
)
time_coords.values[indexes] = [
time.replace(year=time.year + 1)
for time in time_coords.values[indexes]
]
else:
for idx in idxs:
time_coords.values[idx] = pd.Timestamp(
time_coords.values[idx]
) + pd.DateOffset(years=1)
time_coords.values[indexes] = [
pd.Timestamp(time) + pd.DateOffset(years=1)
for time in time_coords.values[indexes]
]

ds_new = ds_new.assign_coords({self.dim: time_coords})

Expand Down Expand Up @@ -1298,9 +1291,16 @@ def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset:
time_coords = ds_new[self.dim].copy()
dec_indexes = time_coords.dt.month == 12

time_coords.values[dec_indexes] = [
time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes]
]
if isinstance(time_coords.values[0], cftime.datetime):
time_coords.values[dec_indexes] = [
time.replace(year=time.year + 1)
for time in time_coords.values[dec_indexes]
]
else:
time_coords.values[dec_indexes] = [
pd.Timestamp(time) + pd.DateOffset(years=1)
for time in time_coords.values[dec_indexes]
]

ds_new = ds_new.assign_coords({self.dim: time_coords})

Expand Down

0 comments on commit f2648ff

Please sign in to comment.