Skip to content

Commit

Permalink
Merge pull request #50 from WFP-VAM/patch
Browse files Browse the repository at this point in the history
fix & improve grouped spi
  • Loading branch information
valpesendorfer authored May 10, 2024
2 parents a80f263 + c75efc8 commit d91e6fd
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 71 deletions.
2 changes: 1 addition & 1 deletion hdc/algo/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version only in this file."""

__version__ = "0.4.0"
__version__ = "0.5.0"
94 changes: 53 additions & 41 deletions hdc/algo/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import dask.array as da
from dask.base import tokenize
import numpy as np
import pandas as pd
import xarray

from . import ops
from .dekad import Dekad
from .utils import get_calibration_indices, to_linspace

__all__ = [
"Anomalies",
Expand Down Expand Up @@ -463,20 +463,19 @@ class PixelAlgorithms(AccessorBase):

def spi(
self,
calibration_start: Optional[str] = None,
calibration_stop: Optional[str] = None,
calibration_begin: Optional[str] = None,
calibration_end: Optional[str] = None,
nodata: Optional[Union[float, int]] = None,
groups: Optional[Iterable[int]] = None,
groups: Optional[Iterable[Union[int, float, str]]] = None,
dtype="int16",
):
"""Calculate the SPI along the time dimension.
Calculates the Standardized Precipitation Index along the time dimension.
Optionally, a calibration start and / or stop date can be provided which
Optionally, a calibration begin and / or end date can be provided which
determine the part of the timeseries used to fit the gamma distribution.
`groups` can be supplied as list of group labels (attention, they are required
to be in format {0..n-1} where n is the number of unique groups.
`groups` can be supplied as list of group labels.
If `groups` is supplied, the SPI will be computed for each individual group.
This is intended to be used when SPI should be calculated for specific timesteps.
"""
Expand All @@ -497,33 +496,31 @@ def spi(

tix = self._obj.get_index("time")

calstart_ix = 0
if calibration_start is not None:
calstart = pd.Timestamp(calibration_start)
if calstart > tix[-1]:
raise ValueError(
"Calibration start cannot be greater than last timestamp!"
)
(calstart_ix,) = tix.get_indexer([calstart], method="bfill")
if calibration_begin is None:
calibration_begin = tix[0]

calstop_ix = tix.size
if calibration_stop is not None:
calstop = pd.Timestamp(calibration_stop)
if calstop < tix[0]:
raise ValueError(
"Calibration stop cannot be smaller than first timestamp!"
)
(calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1
if calibration_end is None:
calibration_end = tix[-1]

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")
if calibration_begin > tix[-1:]:
raise ValueError("Calibration begin cannot be greater than last timestamp!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)
if calibration_end < tix[:1]:
raise ValueError("Calibration end cannot be smaller than first timestamp!")

if groups is None:
calstart_ix, calstop_ix = get_calibration_indices(
tix, (calibration_begin, calibration_end)
)

if calstart_ix >= calstop_ix:
raise ValueError("calibration_begin < calibration_end!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
Expand All @@ -540,22 +537,37 @@ def spi(
)

else:
groups = np.array(groups) if not isinstance(groups, np.ndarray) else groups
num_groups = np.unique(groups).size

if not groups.dtype.name == "int16":
warn("Casting groups to int16!")
groups = groups.astype("int16")
groups, keys = to_linspace(np.array(groups, dtype="str"))

if len(groups) != len(self._obj.time):
raise ValueError("Need array of groups same length as time dimension!")

groups = groups.astype("int16")
num_groups = len(keys)

cal_indices = get_calibration_indices(
tix, (calibration_begin, calibration_end), groups, num_groups
)
# assert for mypy
assert isinstance(cal_indices, np.ndarray)

if np.any(cal_indices[:, 0] >= cal_indices[:, 1]):
raise ValueError("calibration_begin < calibration_end!")

if np.any(np.diff(cal_indices, axis=1) <= 1):
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_grp,
self._obj,
groups,
num_groups,
nodata,
calstart_ix,
calstop_ix,
input_core_dims=[["time"], ["grps"], [], [], [], []],
cal_indices,
input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
Expand All @@ -564,8 +576,8 @@ def spi(

res.attrs.update(
{
"spi_calibration_start": str(tix[calstart_ix].date()),
"spi_calibration_stop": str(tix[calstop_ix - 1].date()),
"spi_calibration_begin": str(tix[tix >= calibration_begin][0]),
"spi_calibration_end": str(tix[tix <= calibration_end][-1]),
}
)

Expand Down Expand Up @@ -806,7 +818,7 @@ def mean(

# set null values to nodata value
xx = xx.where(xx.notnull(), xx.nodata)

attrs = xx.attrs
num_zones = len(zone_ids)
dims = (xx.dims[0], dim_name, "stat")
coords = {
Expand Down Expand Up @@ -849,7 +861,7 @@ def mean(
)

return xarray.DataArray(
data=data, dims=dims, coords=coords, attrs={}, name=name
data=data, dims=dims, coords=coords, attrs=attrs, name=name
)


Expand Down
15 changes: 11 additions & 4 deletions hdc/algo/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,28 @@ def gammastd_yxt(
@lazycompile(
guvectorize(
[
"(int16[:], int16[:], float64, float64, float64, float64, int16[:])",
"(float32[:], int16[:], float64, float64, float64, float64, int16[:])",
"(int16[:], int16[:], float64, float64, int16[:, :], int16[:])",
"(float32[:], int16[:], float64, float64, int16[:, :], int16[:])",
],
"(n),(m),(),(),(),() -> (n)",
"(n),(m),(),(),(o, p) -> (n)",
)
)
def gammastd_grp(xx, groups, num_groups, nodata, cal_start, cal_stop, yy):
def gammastd_grp(xx, groups, num_groups, nodata, cal_indices, yy):
"""Calculate the gammastd for specific groups.
This calculates gammastd across xx for indivual groups
defined in `groups`. These need to be in ascending order from
0 to num_groups - 1.
`cal_indices` is an array of shape (num_groups, 2) where each row
contains the start and end index for the calibration period for each group.
"""
for grp in range(num_groups):
grp_ix = groups == grp

cal_start = cal_indices[grp, 0]
cal_stop = cal_indices[grp, 1]

pix = xx[grp_ix]
if (pix != nodata).sum() == 0:
yy[grp_ix] = nodata
Expand Down
52 changes: 50 additions & 2 deletions hdc/algo/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""hcd-algo utility functions."""

from typing import List, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
import pandas as pd

DateType = Union[str, pd.Timestamp, np.datetime64]

def to_linspace(x) -> Tuple[np.ndarray, List[int]]:

def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]:
"""Map input array to linear space.
Returns array with linear index (0 - n-1) and list of
Expand All @@ -21,3 +25,47 @@ def to_linspace(x) -> Tuple[np.ndarray, List[int]]:
new_pix = np.where(mask, values[idx], 0)

return new_pix, list(keys)


def get_calibration_indices(
time: pd.DatetimeIndex,
calibration_range: Tuple[DateType, DateType],
groups: Optional[Iterable[Union[int, float, str]]] = None,
num_groups: Optional[int] = None,
) -> Union[Tuple[int, int], np.ndarray]:
"""
Get the calibration indices for a given time range.
This function returns indices for a calibration period (e.g. used for SPI)
given an index of timestamps and a start & stop date.
If groups are provided, the indices are returned per group, as an
array of shape (num_groups, 2) where the first column is the start index and
the second column is the stop index.
Parameters:
time: The time index.
start: The start time of the calibration range.
stop: The stop time of the calibration range.
groups: Optional groups to consider for calibration.
num_groups: Optional number of groups to consider for calibration.
"""
begin, end = calibration_range

def _get_ix(x: NDArray[(np.datetime64,)], v: DateType, side: str):
return x.searchsorted(np.datetime64(v), side) # type: ignore

if groups is not None:
if num_groups is None:
num_groups = len(np.unique(np.array(groups)))
return np.array(
[
[
_get_ix(time[groups == ix].values, begin, "left"),
_get_ix(time[groups == ix].values, end, "right"),
]
for ix in range(num_groups)
],
dtype="int16",
)

return _get_ix(time.values, begin, "left"), _get_ix(time.values, end, "right")
Loading

0 comments on commit d91e6fd

Please sign in to comment.