diff --git a/cmip6_downscaling/data/cmip.py b/cmip6_downscaling/data/cmip.py index 56be055b..5e9aa190 100644 --- a/cmip6_downscaling/data/cmip.py +++ b/cmip6_downscaling/data/cmip.py @@ -5,10 +5,10 @@ import xarray as xr from . import cat -from .utils import lon_to_180, to_standard_calendar +from .utils import lon_to_180, to_standard_calendar as convert_to_standard_calendar -def postprocess(ds: xr.Dataset) -> xr.Dataset: +def postprocess(ds: xr.Dataset, to_standard_calendar: bool = True) -> xr.Dataset: """Post process input experiment - Drops band variables (if present) @@ -22,6 +22,8 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset: ---------- ds : xr.Dataset Input dataset + to_standard_calendar : bool, optional + Whether to convert time to standard calendar, by default True Returns ------- @@ -52,18 +54,20 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset: if ds.lat[0] > ds.lat[-1]: ds = ds.reindex({"lat": ds.lat[::-1]}) - # checks calendar - ds = to_standard_calendar(ds) + if to_standard_calendar: - # Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs - # ds.coords['time'] = ds['time'].resample(time='1D').first() + # checks calendar + ds = convert_to_standard_calendar(ds) - ds['time'] = pd.date_range( - start=ds['time'].data[0], - end=ds['time'].data[-1], - normalize=True, - freq="1D", - ) + # Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs + # ds.coords['time'] = ds['time'].resample(time='1D').first() + + ds['time'] = pd.date_range( + start=ds['time'].data[0], + end=ds['time'].data[-1], + normalize=True, + freq="1D", + ) return ds diff --git a/flows/gcm_pyramid_weights.py b/flows/gcm_pyramid_weights.py index fc32b188..2aa7c5ba 100644 --- a/flows/gcm_pyramid_weights.py +++ b/flows/gcm_pyramid_weights.py @@ -1,20 +1,30 @@ from __future__ import annotations +import traceback +from functools import partial + import dask -from prefect import Flow, task, unmapped +from prefect import Flow, Parameter, task, unmapped from prefect.engine.signals import SKIP from prefect.tasks.control_flow import merge from prefect.tasks.control_flow.filter import FilterTask from upath import UPath from cmip6_downscaling import config -from cmip6_downscaling.runtimes import PangeoRuntime +from cmip6_downscaling.data.cmip import postprocess +from cmip6_downscaling.runtimes import CloudRuntime + +config.set( + { + 'runtime.cloud.extra_pip_packages': 'git+https://github.com/carbonplan/cmip6-downscaling.git@main git+https://github.com/intake/intake-esm.git' + } +) folder = 'xesmf_weights/cmip6_pyramids' scratch_dir = UPath(config.get('storage.static.uri')) / folder -runtime = PangeoRuntime() +runtime = CloudRuntime() filter_results = FilterTask( filter_func=lambda x: not isinstance(x, (BaseException, SKIP, type(None))) @@ -56,7 +66,11 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict try: with dask.config.set({'scheduler': 'sync'}): - ds_in = xr.open_dataset(store['zstore'], engine='zarr', chunks={}).isel(time=0) + ds_in = ( + xr.open_dataset(store['zstore'], engine='zarr', chunks={}) + .pipe(partial(postprocess, to_standard_calendar=False)) + .isel(time=0) + ) weights_pyramid = generate_weights_pyramid(ds_in, levels, method=method) print(weights_pyramid) weights_pyramid.to_zarr(target, mode='w') @@ -70,7 +84,7 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict } except Exception as e: - raise SKIP(f"Failed to load {store['zstore']}") from e + raise SKIP(f"Failed to process {store['zstore']}\nError: {traceback.format_exc()}") from e @task(log_stdout=True) @@ -89,9 +103,11 @@ def catalog(vals): run_config=runtime.run_config, executor=runtime.executor, ) as flow: + levels = Parameter('levels', default=2) + method = Parameter('method', default='bilinear') stores = get_stores() attrs = filter_results( - generate_weights.map(stores, levels=unmapped(4), method=unmapped('bilinear')) + generate_weights.map(stores, levels=unmapped(levels), method=unmapped(method)) ) vals = merge(attrs) _ = catalog(vals)