Skip to content

Commit

Permalink
Use seed for all rng in blending to make a test run completely determ…
Browse files Browse the repository at this point in the history
…inistic (#450)

* Use seed for all rng in blending to make a test run completely deterministic

* fix coverage

* Actually add a test that runs the previously uncovered lines

* Add randgen to docstring and add default value
  • Loading branch information
mats-knmi authored Jan 15, 2025
1 parent e332585 commit a7dae54
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 53 deletions.
41 changes: 28 additions & 13 deletions pysteps/blending/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,16 +734,19 @@ def forecast(
)

# 6. Initialize all the random generators and prepare for the forecast loop
randgen_prec, vps, generate_vel_noise = _init_random_generators(
velocity,
noise_method,
vel_pert_method,
vp_par,
vp_perp,
seed,
n_ens_members,
kmperpixel,
timestep,
randgen_prec, vps, generate_vel_noise, randgen_probmatching = (
_init_random_generators(
velocity,
noise_method,
probmatching_method,
vel_pert_method,
vp_par,
vp_perp,
seed,
n_ens_members,
kmperpixel,
timestep,
)
)
D, D_Yn, D_pb, R_f, R_m, mask_rim, struct, fft_objs = _prepare_forecast_loop(
precip_cascade,
Expand Down Expand Up @@ -1621,6 +1624,7 @@ def worker(j):
first_array=arr1,
second_array=arr2,
probability_first_array=weights_pm_normalized[0],
randgen=randgen_probmatching[j],
)
else:
R_pm_resampled = R_pm_blended.copy()
Expand Down Expand Up @@ -2290,6 +2294,7 @@ def _find_nwp_combination(
def _init_random_generators(
velocity,
noise_method,
probmatching_method,
vel_pert_method,
vp_par,
vp_perp,
Expand All @@ -2299,18 +2304,28 @@ def _init_random_generators(
timestep,
):
"""Initialize all the random generators."""
randgen_prec = None
if noise_method is not None:
randgen_prec = []
randgen_motion = []
for j in range(n_ens_members):
rs = np.random.RandomState(seed)
randgen_prec.append(rs)
seed = rs.randint(0, high=1e9)

randgen_probmatching = None
if probmatching_method is not None:
randgen_probmatching = []
for j in range(n_ens_members):
rs = np.random.RandomState(seed)
randgen_motion.append(rs)
randgen_probmatching.append(rs)
seed = rs.randint(0, high=1e9)

if vel_pert_method is not None:
randgen_motion = []
for j in range(n_ens_members):
rs = np.random.RandomState(seed)
randgen_motion.append(rs)
seed = rs.randint(0, high=1e9)
init_vel_noise, generate_vel_noise = noise.get_method(vel_pert_method)

# initialize the perturbation generators for the motion field
Expand All @@ -2326,7 +2341,7 @@ def _init_random_generators(
else:
vps, generate_vel_noise = None, None

return randgen_prec, vps, generate_vel_noise
return randgen_prec, vps, generate_vel_noise, randgen_probmatching


def _prepare_forecast_loop(
Expand Down
5 changes: 3 additions & 2 deletions pysteps/noise/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ def compute_noise_stddev_adjs(
randstates = []

for k in range(num_iter):
randstates.append(np.random.RandomState(seed=seed))
seed = np.random.randint(0, high=1e9)
rs = np.random.RandomState(seed=seed)
randstates.append(rs)
seed = rs.randint(0, high=1e9)

def worker(k):
# generate Gaussian white noise field, filter it using the chosen
Expand Down
9 changes: 7 additions & 2 deletions pysteps/postprocessing/probmatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def _get_error(scale):
return shift, scale, R.reshape(shape)


def resample_distributions(first_array, second_array, probability_first_array):
def resample_distributions(
first_array, second_array, probability_first_array, randgen=np.random
):
"""
Merges two distributions (e.g., from the extrapolation nowcast and NWP in the blending module)
to effectively combine two distributions for probability matching without losing extremes.
Expand All @@ -291,6 +293,9 @@ def resample_distributions(first_array, second_array, probability_first_array):
probability_first_array: float
The weight that `first_array` should get (a value between 0 and 1). This determines the
likelihood of selecting elements from `first_array` over `second_array`.
randgen: numpy.random or numpy.RandomState
The random number generator to be used for the binomial distribution. You can pass a seeded
random state here for reproducibility. Default is numpy.random.
Returns
-------
Expand Down Expand Up @@ -324,7 +329,7 @@ def resample_distributions(first_array, second_array, probability_first_array):
n = asort.shape[0]

# Resample the distributions
idxsamples = np.random.binomial(1, probability_first_array, n).astype(bool)
idxsamples = randgen.binomial(1, probability_first_array, n).astype(bool)
csort = np.where(idxsamples, asort, bsort)
csort = np.sort(csort)[::-1]

Expand Down
73 changes: 39 additions & 34 deletions pysteps/tests/test_blending_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,48 @@
import pysteps
from pysteps import blending, cascade

# fmt:off
steps_arg_values = [
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, "obs", None, False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, "incremental", None, False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True),
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, True),
(1, 3, 4, 9, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
(2, 3, 10, 8, "incremental", "cdf", False, "spn", True, 10, False, False, 0, False),
(5, 3, 5, 8, "incremental", "cdf", False, "spn", True, 5, False, False, 0, False),
(1, 10, 1, 8, "incremental", "cdf", False, "spn", True, 1, False, False, 0, False),
(2, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2, False, False, 0, False),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False),
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 8, "obs", None, False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 8, "incremental", None, False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True, None),
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False, None),
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False, None),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False, None),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False, None),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, True, None),
(1, 3, 4, 9, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False, None),
(2, 3, 10, 8, "incremental", "cdf", False, "spn", True, 10, False, False, 0, False, None),
(5, 3, 5, 8, "incremental", "cdf", False, "spn", True, 5, False, False, 0, False, None),
(1, 10, 1, 8, "incremental", "cdf", False, "spn", True, 1, False, False, 0, False, None),
(2, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2, False, False, 0, False, None),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False, None),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False, "bps"),
# Test the case where the radar image contains no rain.
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, False, 0, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, True),
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, False, 0, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, True, None),
# Test the case where the NWP fields contain no rain.
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 0, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, True, 0, True),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 0, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, True, 0, True, None),
# Test the case where both the radar image and the NWP fields contain no rain.
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, True, 0, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, True, 0, False),
(5, 3, 5, 6, "obs", "mean", True, "spn", True, 5, True, True, 0, False),
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, True, 0, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, True, 0, False, None),
(5, 3, 5, 6, "obs", "mean", True, "spn", True, 5, True, True, 0, False, None),
# Test for smooth radar mask
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 80, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, False, 80, False),
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, False, False, 80, False),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True),
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 80, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, False, 80, False, None),
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, False, False, 80, False, None),
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False, None),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True, None),
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
]
# fmt:on

steps_arg_names = (
"n_models",
Expand All @@ -63,6 +66,7 @@
"zero_nwp",
"smooth_radar_mask_range",
"resample_distribution",
"vel_pert_method",
)


Expand All @@ -82,6 +86,7 @@ def test_steps_blending(
zero_nwp,
smooth_radar_mask_range,
resample_distribution,
vel_pert_method,
):
pytest.importorskip("cv2")

Expand Down Expand Up @@ -275,7 +280,7 @@ def test_steps_blending(
noise_method="nonparametric",
noise_stddev_adj="auto",
ar_order=2,
vel_pert_method=None,
vel_pert_method=vel_pert_method,
weights_method=weights_method,
conditional=False,
probmatching_method=probmatching_method,
Expand Down
7 changes: 5 additions & 2 deletions pysteps/tests/test_postprocessing_probmatching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import pytest
from pysteps.postprocessing.probmatching import resample_distributions
from pysteps.postprocessing.probmatching import nonparam_match_empirical_cdf

from pysteps.postprocessing.probmatching import (
nonparam_match_empirical_cdf,
resample_distributions,
)


class TestResampleDistributions:
Expand Down

0 comments on commit a7dae54

Please sign in to comment.