Skip to content

Commit

Permalink
Merge pull request #43 from meom-group/jax_vmap
Browse files Browse the repository at this point in the history
remove np.* artifacts. change signatures to handle masks without the …
  • Loading branch information
vadmbertr authored Feb 9, 2024
2 parents fa4649e + 75aad62 commit a17887e
Show file tree
Hide file tree
Showing 51 changed files with 1,459 additions and 1,260 deletions.
27 changes: 20 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,37 @@ By default, **jaxparrow** will install a CPU-only version of JAX if no other ver

Two functions are directly available from `jaxparrow`:

- `geostrophy`: computes the geostrophic velocity field (returns two `numpy 2darray`) from a SSH `2darray`, two `2darray` of spatial steps, and two `2darray` of Coriolis factors.
- `cyclogeostrophy`: computes the cyclogeostrophic velocity field (returns two `2darray`) from two `2darray` of geostrophic velocities, four `2darray` of spatial steps, and two `2darray` of Coriolis factors.
- `geostrophy` computes the geostrophic velocity field (returns two `2darray`) from:
- a SSH field (a `2darray`),
- its latitude and longitude grids (two `2darray`),
- the latitude grids at the U and V points (two `2darray`),
- and the optional mask grids at the T, U and V points (three `2darray`).
- `cyclogeostrophy` computes the cyclogeostrophic velocity field (returns two `2darray`) from:
- a geostrophic velocity fields (two `2darray`),
- its latitude and longitude grids at U and V points (four `2darray`),
- and the optional mask grids at the U and V points (two `2darray`).

*Because **jaxparrow** uses [C-grids](https://xgcm.readthedocs.io/en/latest/grids.html) the velocity fields are represented on two grids, and the SSH on one grid.*

In a Python script, assuming that the input grids have already been initialised / imported, it would simply resort to:
In a Python script, assuming that the input grids have already been initialised / imported, it would resort to:

```python
from jaxparrow import cyclogeostrophy, geostrophy

u_geos, v_geos = geostrophy(ssh=ssh,
dx_ssh=dx_ssh, dy_ssh=dy_ssh,
coriolis_factor_u=coriolis_factor_u, coriolis_factor_v=coriolis_factor_v)
lat=lat, lon=lon,
lat_u=lat_u, lat_v=lat_v,
mask_t=mask_t, mask_u=mask_u, mask_v=mask_v)
u_cyclo, v_cyclo = cyclogeostrophy(u_geos=u_geos, v_geos=v_geos,
dx_u=dx_u, dx_v=dx_v, dy_u=dy_u, dy_v=dy_v,
coriolis_factor_u=coriolis_factor_u, coriolis_factor_v=coriolis_factor_v)
lat_u=lat_u, lon_u=lon_u,
lat_v=lat_v, lon_v=lon_v,
mask_u=mask_u, mask_v=mask_v)
```

To vectorise the application of the `geostrophy` and `cyclogeostrophy` functions across an added time dimension, one aims to utilize `vmap`.
However, this necessitates avoiding the use of `np.ma.masked_array`.
Hence, our functions accommodate mask `array` as parameters to effectively consider masked regions.

By default, the `cyclogeostrophy` function relies on our variational method.
Its `method` argument provides the ability to use an iterative method instead, either the one described by [Penven *et al.*](https://doi.org/10.1016/j.dsr2.2013.10.015), or the one by [Ioannou *et al.*](https://doi.org/10.1029/2019JC015031).
Additional arguments also give a finer control over the three approaches hyperparameters. \
Expand Down
133 changes: 73 additions & 60 deletions jaxparrow/__main__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import argparse
from datetime import datetime
from typing import Union
import yaml

import numpy as np
import numpy.ma as ma
import xarray as xr

from .version import __version__
from .tools import compute_coriolis_factor, compute_spatial_step
from .cyclogeostrophy import cyclogeostrophy
from .geostrophy import geostrophy


def _read_data(conf_path: str) -> list:
def _read_data(
conf_path: str
) -> list:
with open(conf_path) as f:
conf = yaml.safe_load(f) # parse conf file

Expand Down Expand Up @@ -54,46 +54,39 @@ def _read_data(conf_path: str) -> list:
return values


def _apply_mask(mask_ssh: Union[np.ndarray, None], mask_u: Union[np.ndarray, None], mask_v: Union[np.ndarray, None],
ssh: np.ndarray, lon_ssh: np.ndarray, lat_ssh: np.ndarray,
lon_u: np.ndarray, lat_u: np.ndarray, lon_v: np.ndarray, lat_v: np.ndarray) -> tuple:
def __do_apply(arr: np.ndarray, mask: Union[np.ndarray, None]) -> np.ndarray:
if mask is None:
mask = np.ones_like(arr)
mask = 1 - mask # don't forget to invert the masks (for ma.MaskedArray, True means invalid)
return ma.masked_array(arr, mask)

ssh = __do_apply(ssh, mask_ssh)
lon_ssh = __do_apply(lon_ssh, mask_ssh)
lat_ssh = __do_apply(lat_ssh, mask_ssh)

lon_u = __do_apply(lon_u, mask_u)
lat_u = __do_apply(lat_u, mask_u)

lon_v = __do_apply(lon_v, mask_v)
lat_v = __do_apply(lat_v, mask_v)

return ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v


def _compute_spatial_step(lon_ssh: ma.MaskedArray, lat_ssh: ma.MaskedArray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray,
lon_v: ma.MaskedArray, lat_v: ma.MaskedArray) -> tuple:
dx_ssh, dy_ssh = compute_spatial_step(lat_ssh, lon_ssh)
dx_u, dy_u = compute_spatial_step(lat_u, lon_u)
dx_v, dy_v = compute_spatial_step(lat_v, lon_v)

return dx_ssh, dy_ssh, dx_u, dy_u, dx_v, dy_v


def _compute_coriolis_factor(lat_u: ma.MaskedArray, lat_v: ma.MaskedArray) -> tuple:
coriolis_factor_u = compute_coriolis_factor(lat_u)
coriolis_factor_v = compute_coriolis_factor(lat_v)

return coriolis_factor_u, coriolis_factor_v


def _create_attrs(conf_path: str, out_attrs: dict, run_datetime: str) -> dict:
def _reverse_masks(
mask_ssh: np.ndarray,
mask_u: np.ndarray,
mask_v: np.ndarray
) -> [np.ndarray, np.ndarray, np.ndarray]:
def do_reverse(mask):
if mask is not None:
return 1 - mask
return do_reverse(mask_ssh), do_reverse(mask_u), do_reverse(mask_v)


def _apply_masks(
u_geos: np.ndarray,
v_geos: np.ndarray,
u_cyclo: np.ndarray,
v_cyclo: np.ndarray,
mask_u: np.ndarray,
mask_v: np.ndarray
) -> [np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def do_apply_mask(arr, mask):
if mask is not None:
return ma.masked_array(arr, mask)
else:
return arr
return (do_apply_mask(u_geos, mask_u), do_apply_mask(v_geos, mask_v),
do_apply_mask(u_cyclo, mask_u), do_apply_mask(v_cyclo, mask_v))


def _create_attrs(
conf_path: str,
out_attrs: dict,
run_datetime: str
) -> dict:
with open(conf_path) as f:
raw_conf = f.read()

Expand All @@ -113,9 +106,19 @@ def _create_attrs(conf_path: str, out_attrs: dict, run_datetime: str) -> dict:
return attrs


def _to_dataset(u_geos: np.ndarray, v_geos: np.ndarray, u_cyclo: np.ndarray, v_cyclo: np.ndarray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray, lon_v: ma.MaskedArray, lat_v: ma.MaskedArray,
conf_path: str, out_attrs: dict, run_datetime: str) -> xr.Dataset:
def _to_dataset(
u_geos: ma.MaskedArray,
v_geos: ma.MaskedArray,
u_cyclo: ma.MaskedArray,
v_cyclo: ma.MaskedArray,
lat_u: ma.MaskedArray,
lon_u: ma.MaskedArray,
lat_v: ma.MaskedArray,
lon_v: ma.MaskedArray,
conf_path: str,
out_attrs: dict,
run_datetime: str
) -> xr.Dataset:
ds = xr.Dataset({
"u_geos": (["y", "x"], u_geos),
"v_geos": (["y", "x"], v_geos),
Expand All @@ -130,31 +133,41 @@ def _to_dataset(u_geos: np.ndarray, v_geos: np.ndarray, u_cyclo: np.ndarray, v_c
return ds


def _write_data(u_geos: np.ndarray, v_geos: np.ndarray, u_cyclo: np.ndarray, v_cyclo: np.ndarray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray, lon_v: ma.MaskedArray, lat_v: ma.MaskedArray,
conf_path: str, out_attrs: dict, run_datetime: str, out_path: str):
ds = _to_dataset(u_geos, v_geos, u_cyclo, v_cyclo, lon_u, lat_u, lon_v, lat_v, conf_path, out_attrs, run_datetime)
def _write_data(
u_geos: ma.MaskedArray,
v_geos: ma.MaskedArray,
u_cyclo: ma.MaskedArray,
v_cyclo: ma.MaskedArray,
lat_u: ma.MaskedArray,
lon_u: ma.MaskedArray,
lat_v: ma.MaskedArray,
lon_v: ma.MaskedArray,
conf_path: str,
out_attrs: dict,
run_datetime: str,
out_path: str
):
ds = _to_dataset(u_geos, v_geos, u_cyclo, v_cyclo, lat_u, lon_u, lat_v, lon_v, conf_path, out_attrs, run_datetime)
ds.to_netcdf(out_path)


def _main(conf_path: str):
def _main(
conf_path: str
):
run_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

mask_ssh, mask_u, mask_v, ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v, cyclo_kwargs, out_attrs, out_path = (
_read_data(conf_path))

ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v = _apply_mask(mask_ssh, mask_u, mask_v, ssh, lon_ssh, lat_ssh,
lon_u, lat_u, lon_v, lat_v)
mask_ssh, mask_u, mask_v = _reverse_masks(mask_ssh, mask_u, mask_v)

dx_ssh, dy_ssh, dx_u, dy_u, dx_v, dy_v = _compute_spatial_step(lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v)

coriolis_factor_u, coriolis_factor_v = _compute_coriolis_factor(lat_u, lat_v)

u_geos, v_geos = geostrophy(ssh, dx_ssh, dy_ssh, coriolis_factor_u, coriolis_factor_v)
u_cyclo, v_cyclo = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
u_geos, v_geos = geostrophy(ssh, lat_ssh, lon_ssh, lat_u, lat_v, mask_ssh, mask_u, mask_v)
u_cyclo, v_cyclo = cyclogeostrophy(u_geos, v_geos, lat_u, lon_u, lat_v, lon_v, mask_u, mask_v,
**cyclo_kwargs)

_write_data(u_geos, v_geos, u_cyclo, v_cyclo, lon_u, lat_u, lon_v, lat_v, conf_path, out_attrs, run_datetime,
u_geos, v_geos, u_cyclo, v_cyclo = _apply_masks(u_geos, v_geos, u_cyclo, v_cyclo, mask_u, mask_v)

_write_data(u_geos, v_geos, u_cyclo, v_cyclo, lat_u, lon_u, lat_v, lon_v, conf_path, out_attrs, run_datetime,
out_path)


Expand Down
Loading

0 comments on commit a17887e

Please sign in to comment.