Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shift lons and lats in pre-generated weights #185

Merged
merged 2 commits into the base branch from the head branch on
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions cmip6_downscaling/data/cmip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import xarray as xr

from . import cat
from .utils import lon_to_180, to_standard_calendar
from .utils import lon_to_180, to_standard_calendar as convert_to_standard_calendar


def postprocess(ds: xr.Dataset) -> xr.Dataset:
def postprocess(ds: xr.Dataset, to_standard_calendar: bool = True) -> xr.Dataset:
"""Post process input experiment

- Drops band variables (if present)
Expand All @@ -22,6 +22,8 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset:
----------
ds : xr.Dataset
Input dataset
to_standard_calendar : bool, optional
Whether to convert time to standard calendar, by default True

Returns
-------
Expand Down Expand Up @@ -52,18 +54,20 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset:
if ds.lat[0] > ds.lat[-1]:
ds = ds.reindex({"lat": ds.lat[::-1]})

# checks calendar
ds = to_standard_calendar(ds)
if to_standard_calendar:

# Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs
# ds.coords['time'] = ds['time'].resample(time='1D').first()
# checks calendar
ds = convert_to_standard_calendar(ds)

ds['time'] = pd.date_range(
start=ds['time'].data[0],
end=ds['time'].data[-1],
normalize=True,
freq="1D",
)
# Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs
# ds.coords['time'] = ds['time'].resample(time='1D').first()

ds['time'] = pd.date_range(
start=ds['time'].data[0],
end=ds['time'].data[-1],
normalize=True,
freq="1D",
)
return ds


Expand Down
28 changes: 22 additions & 6 deletions flows/gcm_pyramid_weights.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from __future__ import annotations

import traceback
from functools import partial

import dask
from prefect import Flow, task, unmapped
from prefect import Flow, Parameter, task, unmapped
from prefect.engine.signals import SKIP
from prefect.tasks.control_flow import merge
from prefect.tasks.control_flow.filter import FilterTask
from upath import UPath

from cmip6_downscaling import config
from cmip6_downscaling.runtimes import PangeoRuntime
from cmip6_downscaling.data.cmip import postprocess
from cmip6_downscaling.runtimes import CloudRuntime

config.set(
{
'runtime.cloud.extra_pip_packages': 'git+https://github.com/carbonplan/cmip6-downscaling.git@main git+https://github.com/intake/intake-esm.git'
}
)

folder = 'xesmf_weights/cmip6_pyramids'

scratch_dir = UPath(config.get('storage.static.uri')) / folder

runtime = PangeoRuntime()
runtime = CloudRuntime()

filter_results = FilterTask(
filter_func=lambda x: not isinstance(x, (BaseException, SKIP, type(None)))
Expand Down Expand Up @@ -56,7 +66,11 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict

try:
with dask.config.set({'scheduler': 'sync'}):
ds_in = xr.open_dataset(store['zstore'], engine='zarr', chunks={}).isel(time=0)
ds_in = (
xr.open_dataset(store['zstore'], engine='zarr', chunks={})
.pipe(partial(postprocess, to_standard_calendar=False))
.isel(time=0)
)
weights_pyramid = generate_weights_pyramid(ds_in, levels, method=method)
print(weights_pyramid)
weights_pyramid.to_zarr(target, mode='w')
Expand All @@ -70,7 +84,7 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict
}

except Exception as e:
raise SKIP(f"Failed to load {store['zstore']}") from e
raise SKIP(f"Failed to process {store['zstore']}\nError: {traceback.format_exc()}") from e


@task(log_stdout=True)
Expand All @@ -89,9 +103,11 @@ def catalog(vals):
run_config=runtime.run_config,
executor=runtime.executor,
) as flow:
levels = Parameter('levels', default=2)
method = Parameter('method', default='bilinear')
stores = get_stores()
attrs = filter_results(
generate_weights.map(stores, levels=unmapped(4), method=unmapped('bilinear'))
generate_weights.map(stores, levels=unmapped(levels), method=unmapped(method))
)
vals = merge(attrs)
_ = catalog(vals)