From f2648ffb1530a7f201f06739eca38aebf9871b90 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Wed, 20 Nov 2024 13:06:50 -0800 Subject: [PATCH] Clean up logic in various private methods - Methods include `_subset_coords_for_custom_seasons()` and `_shift_custom_season_years()` --- xcdat/temporal.py | 72 +++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 821cef18..3367e06a 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1099,11 +1099,12 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """Preprocess the dataset based on averaging settings. Operations include: - 1. Drop leap days for daily climatologies. - 2. Subset the dataset based on the reference period. - 3. Shift years for custom seasons spanning the calendar year. - 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons. - 5. Drop incomplete seasons if specified. + 1. Drop leap days for daily climatologies. + 2. Subset the dataset based on the reference period. + 3. Shift years for custom seasons spanning the calendar year. + 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons, + if specified. + 5. Drop incomplete seasons if specified. Parameters ---------- @@ -1141,6 +1142,9 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) + # The years for time coordinates needs to be shifted by 1 for months + # that span the calendar because Xarray groups seasons by months + # in the same year, rather than the previous year. ds = self._shift_custom_season_years(ds) if self._freq == "season" and self._season_config.get("dec_mode") == "DJF": @@ -1180,16 +1184,8 @@ def _subset_coords_for_custom_seasons( The dataset with time coordinate subsetted to months used in custom seasons. """ - month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) - - coords_by_month = ds[self.dim].groupby(f"{self.dim}.month").groups - month_to_time_idx = { - k: coords_by_month[k] for k in month_ints if k in coords_by_month - } - month_to_time_idx = sorted( - list(chain.from_iterable(month_to_time_idx.values())) # type: ignore - ) - ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) + month_ints = [MONTH_STR_TO_INT[month] for month in months] + ds_new = ds.sel({self.dim: ds[self.dim].dt.month.isin(month_ints)}) return ds_new @@ -1231,34 +1227,31 @@ def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: ds_new = ds.copy() custom_seasons = self._season_config["custom_seasons"] + # Identify months that span across years in custom seasons by getting + # the months before "Jan" if "Jan" is not the first month of the season. + # Note: Only one custom season can span the calendar year. span_months: List[int] = [] - - # Identify the months that span across years in custom seasons. - # This is done by checking if "Jan" is not the first month in the - # custom season and getting all months before "Jan". for months in custom_seasons.values(): # type: ignore - month_nums = [MONTH_STR_TO_INT[month] for month in months] - if 1 in month_nums: - jan_index = month_nums.index(1) + month_ints = [MONTH_STR_TO_INT[month] for month in months] - if jan_index != 0: - span_months.extend(month_nums[:jan_index]) + if 1 in month_ints and month_ints.index(1) != 0: + span_months.extend(month_ints[: month_ints.index(1)]) break if span_months: time_coords = ds_new[self.dim].copy() - idxs = np.where(time_coords.dt.month.isin(span_months))[0] + indexes = time_coords.dt.month.isin(span_months) if isinstance(time_coords.values[0], cftime.datetime): - for idx in idxs: - time_coords.values[idx] = time_coords.values[idx].replace( - year=time_coords.values[idx].year + 1 - ) + time_coords.values[indexes] = [ + time.replace(year=time.year + 1) + for time in time_coords.values[indexes] + ] else: - for idx in idxs: - time_coords.values[idx] = pd.Timestamp( - time_coords.values[idx] - ) + pd.DateOffset(years=1) + time_coords.values[indexes] = [ + pd.Timestamp(time) + pd.DateOffset(years=1) + for time in time_coords.values[indexes] + ] ds_new = ds_new.assign_coords({self.dim: time_coords}) @@ -1298,9 +1291,16 @@ def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset: time_coords = ds_new[self.dim].copy() dec_indexes = time_coords.dt.month == 12 - time_coords.values[dec_indexes] = [ - time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes] - ] + if isinstance(time_coords.values[0], cftime.datetime): + time_coords.values[dec_indexes] = [ + time.replace(year=time.year + 1) + for time in time_coords.values[dec_indexes] + ] + else: + time_coords.values[dec_indexes] = [ + pd.Timestamp(time) + pd.DateOffset(years=1) + for time in time_coords.values[dec_indexes] + ] ds_new = ds_new.assign_coords({self.dim: time_coords})