Skip to content

Commit

Permalink
Shift lons and lats in pre-generated weights (#185)
Browse files · Browse the repository at this point in the history
  • Loading branch information
andersy005 authored May 12, 2022
1 parent 84b81ec commit 6deea20
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
28 changes: 16 additions & 12 deletions cmip6_downscaling/data/cmip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import xarray as xr

from . import cat
from .utils import lon_to_180, to_standard_calendar
from .utils import lon_to_180, to_standard_calendar as convert_to_standard_calendar


def postprocess(ds: xr.Dataset) -> xr.Dataset:
def postprocess(ds: xr.Dataset, to_standard_calendar: bool = True) -> xr.Dataset:
"""Post process input experiment
- Drops band variables (if present)
Expand All @@ -22,6 +22,8 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset:
----------
ds : xr.Dataset
Input dataset
to_standard_calendar : bool, optional
Whether to convert time to standard calendar, by default True
Returns
-------
Expand Down Expand Up @@ -52,18 +54,20 @@ def postprocess(ds: xr.Dataset) -> xr.Dataset:
if ds.lat[0] > ds.lat[-1]:
ds = ds.reindex({"lat": ds.lat[::-1]})

# checks calendar
ds = to_standard_calendar(ds)
if to_standard_calendar:

# Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs
# ds.coords['time'] = ds['time'].resample(time='1D').first()
# checks calendar
ds = convert_to_standard_calendar(ds)

ds['time'] = pd.date_range(
start=ds['time'].data[0],
end=ds['time'].data[-1],
normalize=True,
freq="1D",
)
# Shifts time from Noon (12:00) start to Midnight (00:00) start to match with Obs
# ds.coords['time'] = ds['time'].resample(time='1D').first()

ds['time'] = pd.date_range(
start=ds['time'].data[0],
end=ds['time'].data[-1],
normalize=True,
freq="1D",
)
return ds


Expand Down
28 changes: 22 additions & 6 deletions flows/gcm_pyramid_weights.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from __future__ import annotations

import traceback
from functools import partial

import dask
from prefect import Flow, task, unmapped
from prefect import Flow, Parameter, task, unmapped
from prefect.engine.signals import SKIP
from prefect.tasks.control_flow import merge
from prefect.tasks.control_flow.filter import FilterTask
from upath import UPath

from cmip6_downscaling import config
from cmip6_downscaling.runtimes import PangeoRuntime
from cmip6_downscaling.data.cmip import postprocess
from cmip6_downscaling.runtimes import CloudRuntime

config.set(
{
'runtime.cloud.extra_pip_packages': 'git+https://github.com/carbonplan/cmip6-downscaling.git@main git+https://github.com/intake/intake-esm.git'
}
)

folder = 'xesmf_weights/cmip6_pyramids'

scratch_dir = UPath(config.get('storage.static.uri')) / folder

runtime = PangeoRuntime()
runtime = CloudRuntime()

filter_results = FilterTask(
filter_func=lambda x: not isinstance(x, (BaseException, SKIP, type(None)))
Expand Down Expand Up @@ -56,7 +66,11 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict

try:
with dask.config.set({'scheduler': 'sync'}):
ds_in = xr.open_dataset(store['zstore'], engine='zarr', chunks={}).isel(time=0)
ds_in = (
xr.open_dataset(store['zstore'], engine='zarr', chunks={})
.pipe(partial(postprocess, to_standard_calendar=False))
.isel(time=0)
)
weights_pyramid = generate_weights_pyramid(ds_in, levels, method=method)
print(weights_pyramid)
weights_pyramid.to_zarr(target, mode='w')
Expand All @@ -70,7 +84,7 @@ def generate_weights(store: dict, levels: int, method: str = 'bilinear') -> dict
}

except Exception as e:
raise SKIP(f"Failed to load {store['zstore']}") from e
raise SKIP(f"Failed to process {store['zstore']}\nError: {traceback.format_exc()}") from e


@task(log_stdout=True)
Expand All @@ -89,9 +103,11 @@ def catalog(vals):
run_config=runtime.run_config,
executor=runtime.executor,
) as flow:
levels = Parameter('levels', default=2)
method = Parameter('method', default='bilinear')
stores = get_stores()
attrs = filter_results(
generate_weights.map(stores, levels=unmapped(4), method=unmapped('bilinear'))
generate_weights.map(stores, levels=unmapped(levels), method=unmapped(method))
)
vals = merge(attrs)
_ = catalog(vals)

1 comment on commit 6deea20

@vercel
Copy link

@vercel vercel bot commented on 6deea20 May 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.