Skip to content

Commit

Permalink
Merge pull request #1 from abulenok/fix/merge-shuffle-draw-random-int
Browse files Browse the repository at this point in the history
  • Loading branch information
zengraf authored Feb 23, 2023
2 parents 5c22557 + 833e068 commit 91b6c7a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 10 deletions.
7 changes: 6 additions & 1 deletion PySDM/backends/impl_numba/methods/index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from PySDM.backends.impl_numba import conf


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def draw_random_int(start: int, end: int, u01: float):
return min(int(start + u01 * (end - start + 1)), end)


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def fisher_yates_shuffle(idx, u01, start, end, random_offset=0):
for i in range(end - 1, start, -1):
j = int(start + u01[random_offset + i] * (i - start) + 0.5)
j = draw_random_int(start=start, end=i, u01=u01[random_offset + i])
idx[i], idx[j] = idx[j], idx[i]


Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/backends/test_index_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
import pytest

from PySDM.backends.impl_numba.methods.index_methods import draw_random_int


@pytest.mark.parametrize(
"a, b, u01, expected",
(
(0, 100, 0.0, 0),
(0, 100, 1.0, 100),
(0, 1, 0.5, 1),
(0, 1, 0.49, 0),
(0, 3, 0.49, 1),
(0, 3, 0.245, 0),
(0, 2, 0.332, 0),
(0, 2, 0.333, 0),
(0, 2, 0.334, 1),
(0, 2, 0.665, 1),
(0, 2, 0.666, 1),
(0, 2, 0.667, 2),
(0, 2, 0.999, 2),
),
)
def test_draw_random_int(a, b, u01, expected):
# act
actual = draw_random_int(a, b, u01)

# assert
assert actual == expected
38 changes: 29 additions & 9 deletions tests/unit_tests/impl/test_particle_attributes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
from collections import Counter

import numpy as np
import pytest
from matplotlib import pyplot
from scipy import stats

from PySDM import Formulae
from PySDM.backends import CPU, GPU, ThrustRTC
from PySDM.backends.impl_common.index import make_Index
from PySDM.backends.impl_common.indexed_storage import make_IndexedStorage
Expand Down Expand Up @@ -301,23 +301,24 @@ def test_permutation_local_repeatable(backend_class):
)

@staticmethod
@pytest.mark.parametrize("seed", (1, 2, 3))
# pylint: disable=redefined-outer-name
def test_permutation_global_uniform_distribution(backend_class):
if backend_class is ThrustRTC:
return # TODO #328

def test_permutation_global_uniform_distribution(
seed, backend_class=CPU, plot=False
):
n_sd = 4
possible_permutations_num = np.math.factorial(n_sd)
coverage = 1000

random_numbers = np.linspace(
0.0, 1.0, n_sd * possible_permutations_num * coverage
)
np.random.seed(1)
np.random.seed(seed)
np.random.shuffle(random_numbers)

# Arrange
particulator = DummyParticulator(CPU, n_sd=n_sd)
particulator = DummyParticulator(CPU, n_sd=n_sd, formulae=Formulae(seed=seed))

sut = ParticleAttributesFactory.empty_particles(particulator, n_sd)
idx_length = len(sut._ParticleAttributes__idx)
sut._ParticleAttributes__tmp_idx = make_indexed_storage(
Expand All @@ -338,6 +339,25 @@ def test_permutation_global_uniform_distribution(backend_class):
sut._ParticleAttributes__idx, idx_length
)

_, uniformity = stats.chisquare(list(Counter(permutation_ids).values()))
# Plot
counts, _ = np.histogram(permutation_ids, bins=possible_permutations_num)
_, uniformity = stats.chisquare(counts)

avg = np.mean(counts)
std = np.std(counts)

pyplot.plot(counts, marker=".")
pyplot.xlabel("permutation id")
pyplot.ylabel("occurrence count")
pyplot.xlim(0, possible_permutations_num)
pyplot.axhline(coverage, color="black", label="coverage")
pyplot.axhline(avg, color="green", label="mean +/- std")
for offset in (-std, +std):
pyplot.axhline(avg + offset, color="green", linestyle="--")
pyplot.legend()
if plot:
pyplot.show()

# Assert
assert abs(avg - coverage) / coverage < 1e-6
assert uniformity > 0.9

0 comments on commit 91b6c7a

Please sign in to comment.