diff --git a/PySDM/backends/impl_numba/methods/index_methods.py b/PySDM/backends/impl_numba/methods/index_methods.py index 3836b5c8c..9c3904885 100644 --- a/PySDM/backends/impl_numba/methods/index_methods.py +++ b/PySDM/backends/impl_numba/methods/index_methods.py @@ -8,10 +8,15 @@ from PySDM.backends.impl_numba import conf +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def draw_random_int(start: int, end: int, u01: float): + return min(int(start + u01 * (end - start + 1)), end) + + @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) def fisher_yates_shuffle(idx, u01, start, end, random_offset=0): for i in range(end - 1, start, -1): - j = int(start + u01[random_offset + i] * (i - start) + 0.5) + j = draw_random_int(start=start, end=i, u01=u01[random_offset + i]) idx[i], idx[j] = idx[j], idx[i] diff --git a/tests/unit_tests/backends/test_index_methods.py b/tests/unit_tests/backends/test_index_methods.py new file mode 100644 index 000000000..09f0256cf --- /dev/null +++ b/tests/unit_tests/backends/test_index_methods.py @@ -0,0 +1,30 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +import pytest + +from PySDM.backends.impl_numba.methods.index_methods import draw_random_int + + +@pytest.mark.parametrize( + "a, b, u01, expected", + ( + (0, 100, 0.0, 0), + (0, 100, 1.0, 100), + (0, 1, 0.5, 1), + (0, 1, 0.49, 0), + (0, 3, 0.49, 1), + (0, 3, 0.245, 0), + (0, 2, 0.332, 0), + (0, 2, 0.333, 0), + (0, 2, 0.334, 1), + (0, 2, 0.665, 1), + (0, 2, 0.666, 1), + (0, 2, 0.667, 2), + (0, 2, 0.999, 2), + ), +) +def test_draw_random_int(a, b, u01, expected): + # act + actual = draw_random_int(a, b, u01) + + # assert + assert actual == expected diff --git a/tests/unit_tests/impl/test_particle_attributes.py b/tests/unit_tests/impl/test_particle_attributes.py index 05ca5ffc1..a3e9d4411 100644 --- a/tests/unit_tests/impl/test_particle_attributes.py +++ b/tests/unit_tests/impl/test_particle_attributes.py @@ -1,10 +1,10 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring -from collections import Counter - import numpy as np import pytest +from matplotlib import pyplot from scipy import stats +from PySDM import Formulae from PySDM.backends import CPU, GPU, ThrustRTC from PySDM.backends.impl_common.index import make_Index from PySDM.backends.impl_common.indexed_storage import make_IndexedStorage @@ -301,11 +301,11 @@ def test_permutation_local_repeatable(backend_class): ) @staticmethod + @pytest.mark.parametrize("seed", (1, 2, 3)) # pylint: disable=redefined-outer-name - def test_permutation_global_uniform_distribution(backend_class): - if backend_class is ThrustRTC: - return # TODO #328 - + def test_permutation_global_uniform_distribution( + seed, backend_class=CPU, plot=False + ): n_sd = 4 possible_permutations_num = np.math.factorial(n_sd) coverage = 1000 @@ -313,11 +313,12 @@ def test_permutation_global_uniform_distribution(backend_class): random_numbers = np.linspace( 0.0, 1.0, n_sd * possible_permutations_num * coverage ) - np.random.seed(1) + np.random.seed(seed) np.random.shuffle(random_numbers) # Arrange - particulator = DummyParticulator(CPU, n_sd=n_sd) + particulator = DummyParticulator(CPU, n_sd=n_sd, formulae=Formulae(seed=seed)) + sut = ParticleAttributesFactory.empty_particles(particulator, n_sd) idx_length = len(sut._ParticleAttributes__idx) sut._ParticleAttributes__tmp_idx = make_indexed_storage( @@ -338,6 +339,25 @@ def test_permutation_global_uniform_distribution(backend_class): sut._ParticleAttributes__idx, idx_length ) - _, uniformity = stats.chisquare(list(Counter(permutation_ids).values())) + # Plot + counts, _ = np.histogram(permutation_ids, bins=possible_permutations_num) + _, uniformity = stats.chisquare(counts) + + avg = np.mean(counts) + std = np.std(counts) + + pyplot.plot(counts, marker=".") + pyplot.xlabel("permutation id") + pyplot.ylabel("occurrence count") + pyplot.xlim(0, possible_permutations_num) + pyplot.axhline(coverage, color="black", label="coverage") + pyplot.axhline(avg, color="green", label="mean +/- std") + for offset in (-std, +std): + pyplot.axhline(avg + offset, color="green", linestyle="--") + pyplot.legend() + if plot: + pyplot.show() + # Assert + assert abs(avg - coverage) / coverage < 1e-6 assert uniformity > 0.9