diff --git a/.github/workflows/tests+artifacts+pypi.yml b/.github/workflows/tests+artifacts+pypi.yml index 0977f3b21..2a9a98002 100644 --- a/.github/workflows/tests+artifacts+pypi.yml +++ b/.github/workflows/tests+artifacts+pypi.yml @@ -131,7 +131,7 @@ jobs: python-version: "3.8" fail-fast: false runs-on: ${{ matrix.platform }} - timeout-minutes: 45 + timeout-minutes: ${{ startsWith(matrix.platform, 'windows-') && 45 || 30 }} steps: - uses: actions/checkout@v4.1.6 with: @@ -216,7 +216,7 @@ jobs: test-suite: [ "chemistry_freezing_isotopes", "condensation_a", "condensation_b", "coagulation", "breakup", "multi-process_a", "multi-process_b"] fail-fast: false runs-on: ${{ matrix.platform }} - timeout-minutes: 50 + timeout-minutes: ${{ startsWith(matrix.platform, 'windows-') && 65 || 50 }} steps: - uses: actions/checkout@v4.1.6 with: diff --git a/PySDM/backends/impl_common/backend_methods.py b/PySDM/backends/impl_common/backend_methods.py index 6ab0824f3..dc8917b1d 100644 --- a/PySDM/backends/impl_common/backend_methods.py +++ b/PySDM/backends/impl_common/backend_methods.py @@ -9,5 +9,9 @@ class BackendMethods: def __init__(self): if not hasattr(self, "formulae"): self.formulae = None + if not hasattr(self, "formulae_flattened"): + self.formulae_flattened = None if not hasattr(self, "Storage"): self.Storage = None + if not hasattr(self, "default_jit_flags"): + self.default_jit_flags = {} diff --git a/PySDM/backends/impl_numba/methods/__init__.py b/PySDM/backends/impl_numba/methods/__init__.py index 813231724..ca35c0e65 100644 --- a/PySDM/backends/impl_numba/methods/__init__.py +++ b/PySDM/backends/impl_numba/methods/__init__.py @@ -1 +1,14 @@ """ method classes of the CPU backend """ + +from .chemistry_methods import ChemistryMethods +from .collisions_methods import CollisionsMethods +from .condensation_methods import CondensationMethods +from .displacement_methods import DisplacementMethods +from .fragmentation_methods import FragmentationMethods +from .freezing_methods import FreezingMethods +from .index_methods import IndexMethods +from .isotope_methods import IsotopeMethods +from .moments_methods import MomentsMethods +from .pair_methods import PairMethods +from .physics_methods import PhysicsMethods +from .terminal_velocity_methods import TerminalVelocityMethods diff --git a/PySDM/backends/impl_numba/methods/collisions_methods.py b/PySDM/backends/impl_numba/methods/collisions_methods.py index e6bf5c41d..06573bfc6 100644 --- a/PySDM/backends/impl_numba/methods/collisions_methods.py +++ b/PySDM/backends/impl_numba/methods/collisions_methods.py @@ -2,7 +2,7 @@ CPU implementation of backend methods for particle collisions """ -# pylint: disable=too-many-lines +from functools import cached_property import numba import numpy as np @@ -243,78 +243,13 @@ def break_up_while( warn("overflow", __file__) -@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) -def straub_Nr( # pylint: disable=too-many-arguments,unused-argument - i, - Nr1, - Nr2, - Nr3, - Nr4, - Nrt, - CW, - gam, -): # pylint: disable=too-many-branches` - if gam[i] * CW[i] >= 7.0: - Nr1[i] = 0.088 * (gam[i] * CW[i] - 7.0) - if CW[i] >= 21.0: - Nr2[i] = 0.22 * (CW[i] - 21.0) - if CW[i] <= 46.0: - Nr3[i] = 0.04 * (46.0 - CW[i]) - else: - Nr3[i] = 1.0 - Nr4[i] = 1.0 - Nrt[i] = Nr1[i] + Nr2[i] + Nr3[i] + Nr4[i] - - -@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) -def straub_mass_remainder( # pylint: disable=too-many-arguments,unused-argument - i, vl, ds, mu1, sigma1, mu2, sigma2, mu3, sigma3, d34, Nr1, Nr2, Nr3, Nr4 -): - # 
pylint: disable=too-many-arguments, too-many-locals - Nr1[i] = Nr1[i] * np.exp(3 * mu1 + 9 * np.power(sigma1, 2) / 2) - Nr2[i] = Nr2[i] * (mu2**3 + 3 * mu2 * sigma2**2) - Nr3[i] = Nr3[i] * (mu3**3 + 3 * mu3 * sigma3**2) - Nr4[i] = vl[i] * 6 / np.pi + ds[i] ** 3 - Nr1[i] - Nr2[i] - Nr3[i] - if Nr4[i] <= 0.0: - d34[i] = 0 - Nr4[i] = 0 - else: - d34[i] = np.exp(np.log(Nr4[i]) / 3) - - -@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) -def ll82_Nr( # pylint: disable=too-many-arguments,unused-argument - i, - Rf, - Rs, - Rd, - CKE, - W, - W2, -): # pylint: disable=too-many-branches` - if CKE[i] >= 0.893e-6: - Rf[i] = 1.11e-4 * CKE[i] ** (-0.654) - else: - Rf[i] = 1.0 - if W[i] >= 0.86: - Rs[i] = 0.685 * (1 - np.exp(-1.63 * (W2[i] - 0.86))) - else: - Rs[i] = 0.0 - if (Rs[i] + Rf[i]) > 1.0: - Rd[i] = 0.0 - else: - Rd[i] = 1.0 - Rs[i] - Rf[i] - - class CollisionsMethods(BackendMethods): - def __init__(self): # pylint: disable=too-many-statements,too-many-locals - BackendMethods.__init__(self) - + @cached_property + def _collision_coalescence_breakup_body(self): _break_up = break_up_while if self.formulae.handle_all_breakups else break_up - const = self.formulae.constants - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def __collision_coalescence_breakup_body( + @numba.njit(**self.default_jit_flags) + def body( *, multiplicity, idx, @@ -373,252 +308,66 @@ def __collision_coalescence_breakup_body( ) flag_zero_multiplicity(j, k, multiplicity, healthy) - self.__collision_coalescence_breakup_body = __collision_coalescence_breakup_body - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def __ll82_coalescence_check_body(*, Ec, dl): - for i in numba.prange(len(Ec)): # pylint: disable=not-an-iterable - if dl[i] < 0.4e-3: - Ec[i] = 1.0 - - self.__ll82_coalescence_check_body = __ll82_coalescence_check_body - - if self.formulae.fragmentation_function.__name__ == "Straub2010Nf": - straub_sigma1 = self.formulae.fragmentation_function.params_sigma1 - straub_mu1 = self.formulae.fragmentation_function.params_mu1 - straub_sigma2 = self.formulae.fragmentation_function.params_sigma2 - straub_mu2 = self.formulae.fragmentation_function.params_mu2 - straub_sigma3 = self.formulae.fragmentation_function.params_sigma3 - straub_mu3 = self.formulae.fragmentation_function.params_mu3 - straub_erfinv = self.formulae.trivia.erfinv_approx - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def __straub_fragmentation_body( - *, CW, gam, ds, v_max, frag_volume, rand, Nr1, Nr2, Nr3, Nr4, Nrt, d34 - ): # pylint: disable=too-many-arguments,too-many-locals - for i in numba.prange( # pylint: disable=not-an-iterable - len(frag_volume) - ): - straub_Nr(i, Nr1, Nr2, Nr3, Nr4, Nrt, CW, gam) - sigma1 = straub_sigma1(CW[i]) - mu1 = straub_mu1(sigma1) - sigma2 = straub_sigma2(CW[i]) - mu2 = straub_mu2(ds[i]) - sigma3 = straub_sigma3(CW[i]) - mu3 = straub_mu3(ds[i]) - straub_mass_remainder( - i, - v_max, - ds, - mu1, - sigma1, - mu2, - sigma2, - mu3, - sigma3, - d34, - Nr1, - Nr2, - Nr3, - Nr4, - ) - Nrt[i] = Nr1[i] + Nr2[i] + Nr3[i] + Nr4[i] + return body - if Nrt[i] == 0.0: - diameter = 0.0 - else: - if rand[i] < Nr1[i] / Nrt[i]: - X = rand[i] * Nrt[i] / Nr1[i] - lnarg = mu1 + np.sqrt(2) * sigma1 * straub_erfinv(X) - diameter = np.exp(lnarg) - elif rand[i] < (Nr2[i] + Nr1[i]) / Nrt[i]: - X = (rand[i] * Nrt[i] - Nr1[i]) / Nr2[i] - diameter = mu2 + np.sqrt(2) * sigma2 * straub_erfinv(X) - elif rand[i] < (Nr3[i] + Nr2[i] + Nr1[i]) / Nrt[i]: - X = (rand[i] * Nrt[i] 
- Nr1[i] - Nr2[i]) / Nr3[i] - diameter = mu3 + np.sqrt(2) * sigma3 * straub_erfinv(X) - else: - diameter = d34[i] - - frag_volume[i] = diameter**3 * const.PI / 6 - - self.__straub_fragmentation_body = __straub_fragmentation_body - elif self.formulae.fragmentation_function.__name__ == "LowList1982Nf": - ll82_params_f1 = self.formulae.fragmentation_function.params_f1 - ll82_params_f2 = self.formulae.fragmentation_function.params_f2 - ll82_params_f3 = self.formulae.fragmentation_function.params_f3 - ll82_params_s1 = self.formulae.fragmentation_function.params_s1 - ll82_params_s2 = self.formulae.fragmentation_function.params_s2 - ll82_params_d1 = self.formulae.fragmentation_function.params_d1 - ll82_params_d2 = self.formulae.fragmentation_function.params_d2 - ll82_erfinv = self.formulae.fragmentation_function.erfinv - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def __ll82_fragmentation_body( - *, CKE, W, W2, St, ds, dl, dcoal, frag_volume, rand, Rf, Rs, Rd, tol - ): # pylint: disable=too-many-branches,too-many-locals,too-many-statements - for i in numba.prange( # pylint: disable=not-an-iterable - len(frag_volume) - ): - if dl[i] <= 0.4e-3: - frag_volume[i] = dcoal[i] ** 3 * const.PI / 6 - elif ds[i] == 0.0 or dl[i] == 0.0: - frag_volume[i] = 1e-18 - else: - ll82_Nr(i, Rf, Rs, Rd, CKE, W, W2) - if rand[i] <= Rf[i]: # filament breakup - (H1, mu1, sigma1) = ll82_params_f1(dl[i], dcoal[i]) - (H2, mu2, sigma2) = ll82_params_f2(ds[i]) - (H3, mu3, sigma3) = ll82_params_f3(ds[i], dl[i]) - H1 = H1 * mu1 - H2 = H2 * mu2 - H3 = H3 * np.exp(mu3) - Hsum = H1 + H2 + H3 - rand[i] = rand[i] / Rf[i] - if rand[i] <= H1 / Hsum: - X = max(rand[i] * Hsum / H1, tol) - frag_volume[i] = mu1 + np.sqrt( - 2 - ) * sigma1 * ll82_erfinv(2 * X - 1) - elif rand[i] <= (H1 + H2) / Hsum: - X = (rand[i] * Hsum - H1) / H2 - frag_volume[i] = mu2 + np.sqrt( - 2 - ) * sigma2 * ll82_erfinv(2 * X - 1) - else: - X = min((rand[i] * Hsum - H1 - H2) / H3, 1.0 - tol) - lnarg = mu3 + np.sqrt(2) * sigma3 * ll82_erfinv( - 2 * X - 1 - ) - frag_volume[i] = np.exp(lnarg) - - elif rand[i] <= Rf[i] + Rs[i]: # sheet breakup - (H1, mu1, sigma1) = ll82_params_s1(dl[i], ds[i], dcoal[i]) - (H2, mu2, sigma2) = ll82_params_s2(dl[i], ds[i], St[i]) - H1 = H1 * mu1 - H2 = H2 * np.exp(mu2) - Hsum = H1 + H2 - rand[i] = (rand[i] - Rf[i]) / (Rs[i]) - if rand[i] <= H1 / Hsum: - X = max(rand[i] * Hsum / H1, tol) - frag_volume[i] = mu1 + np.sqrt( - 2 - ) * sigma1 * ll82_erfinv(2 * X - 1) - else: - X = min((rand[i] * Hsum - H1) / H2, 1.0 - tol) - lnarg = mu2 + np.sqrt(2) * sigma2 * ll82_erfinv( - 2 * X - 1 - ) - frag_volume[i] = np.exp(lnarg) - - else: # disk breakup - (H1, mu1, sigma1) = ll82_params_d1( - W[i], dl[i], dcoal[i], CKE[i] - ) - (H2, mu2, sigma2) = ll82_params_d2(ds[i], dl[i], CKE[i]) - H1 = H1 * mu1 - Hsum = H1 + H2 - rand[i] = (rand[i] - Rf[i] - Rs[i]) / Rd[i] - if rand[i] <= H1 / Hsum: - X = max(rand[i] * Hsum / H1, tol) - frag_volume[i] = mu1 + np.sqrt( - 2 - ) * sigma1 * ll82_erfinv(2 * X - 1) - else: - X = min((rand[i] * Hsum - H1) / H2, 1 - tol) - lnarg = mu2 + np.sqrt(2) * sigma2 * ll82_erfinv( - 2 * X - 1 - ) - frag_volume[i] = np.exp(lnarg) - - frag_volume[i] = ( - frag_volume[i] * 0.01 - ) # diameter in cm; convert to m - frag_volume[i] = frag_volume[i] ** 3 * const.PI / 6 - - self.__ll82_fragmentation_body = __ll82_fragmentation_body - elif self.formulae.fragmentation_function.__name__ == "Gaussian": - erfinv_approx = self.formulae.trivia.erfinv_approx - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": 
self.formulae.fastmath}) - def __gauss_fragmentation_body( - *, mu, sigma, frag_volume, rand - ): # pylint: disable=too-many-arguments - for i in numba.prange( # pylint: disable=not-an-iterable - len(frag_volume) - ): - frag_volume[i] = mu + sigma * erfinv_approx(rand[i]) - - self.__gauss_fragmentation_body = __gauss_fragmentation_body - elif self.formulae.fragmentation_function.__name__ == "Feingold1988": - feingold1988_frag_volume = self.formulae.fragmentation_function.frag_volume - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - # pylint: disable=too-many-arguments - def __feingold1988_fragmentation_body( - *, scale, frag_volume, x_plus_y, rand, fragtol - ): - for i in numba.prange( # pylint: disable=not-an-iterable - len(frag_volume) - ): - frag_volume[i] = feingold1988_frag_volume( - scale, rand[i], x_plus_y[i], fragtol - ) - - self.__feingold1988_fragmentation_body = __feingold1988_fragmentation_body + @cached_property + def _adaptive_sdm_end_body(self): + @numba.njit(**{**self.default_jit_flags, "parallel": False}) + def body(dt_left, n_cell, cell_start): + end = 0 + for i in range(n_cell - 1, -1, -1): + if dt_left[i] == 0: + continue + end = cell_start[i + 1] + break + return end - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - def __adaptive_sdm_end_body(dt_left, n_cell, cell_start): - end = 0 - for i in range(n_cell - 1, -1, -1): - if dt_left[i] == 0: - continue - end = cell_start[i + 1] - break - return end + return body def adaptive_sdm_end(self, dt_left, cell_start): - return self.__adaptive_sdm_end_body(dt_left.data, len(dt_left), cell_start.data) + return self._adaptive_sdm_end_body(dt_left.data, len(dt_left), cell_start.data) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - # pylint: disable=too-many-arguments,too-many-locals - def __scale_prob_for_adaptive_sdm_gamma_body( - prob, - idx, - length, - multiplicity, - cell_id, - dt_left, - dt, - dt_range, - is_first_in_pair, - stats_n_substep, - stats_dt_min, - ): - dt_todo = np.empty_like(dt_left) - for cid in numba.prange(len(dt_todo)): # pylint: disable=not-an-iterable - dt_todo[cid] = min(dt_left[cid], dt_range[1]) - for i in range(length // 2): # TODO #571 - j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, prob) - if skip_pair: - continue - prop = multiplicity[j] // multiplicity[k] - dt_optimal = dt * prop / prob[i] - cid = cell_id[j] - dt_optimal = max(dt_optimal, dt_range[0]) - dt_todo[cid] = min(dt_todo[cid], dt_optimal) - stats_dt_min[cid] = min(stats_dt_min[cid], dt_optimal) - for i in numba.prange(length // 2): # pylint: disable=not-an-iterable - j, _, skip_pair = pair_indices(i, idx, is_first_in_pair, prob) - if skip_pair: - continue - prob[i] *= dt_todo[cell_id[j]] / dt - for cid in numba.prange(len(dt_todo)): # pylint: disable=not-an-iterable - dt_left[cid] -= dt_todo[cid] - if dt_todo[cid] > 0: - stats_n_substep[cid] += 1 + @cached_property + def _scale_prob_for_adaptive_sdm_gamma_body(self): + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments,too-many-locals + def body( + prob, + idx, + length, + multiplicity, + cell_id, + dt_left, + dt, + dt_range, + is_first_in_pair, + stats_n_substep, + stats_dt_min, + ): + dt_todo = np.empty_like(dt_left) + for cid in numba.prange(len(dt_todo)): # pylint: disable=not-an-iterable + dt_todo[cid] = min(dt_left[cid], dt_range[1]) + for i in range(length // 2): # TODO #571 + j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, prob) + if skip_pair: + continue + prop = multiplicity[j] // 
multiplicity[k] + dt_optimal = dt * prop / prob[i] + cid = cell_id[j] + dt_optimal = max(dt_optimal, dt_range[0]) + dt_todo[cid] = min(dt_todo[cid], dt_optimal) + stats_dt_min[cid] = min(stats_dt_min[cid], dt_optimal) + for i in numba.prange(length // 2): # pylint: disable=not-an-iterable + j, _, skip_pair = pair_indices(i, idx, is_first_in_pair, prob) + if skip_pair: + continue + prob[i] *= dt_todo[cell_id[j]] / dt + for cid in numba.prange(len(dt_todo)): # pylint: disable=not-an-iterable + dt_left[cid] -= dt_todo[cid] + if dt_todo[cid] > 0: + stats_n_substep[cid] += 1 + + return body def scale_prob_for_adaptive_sdm_gamma( self, @@ -633,7 +382,7 @@ def scale_prob_for_adaptive_sdm_gamma( stats_n_substep, stats_dt_min, ): - return self.__scale_prob_for_adaptive_sdm_gamma_body( + return self._scale_prob_for_adaptive_sdm_gamma_body( prob.data, multiplicity.idx.data, len(multiplicity), @@ -647,38 +396,53 @@ def scale_prob_for_adaptive_sdm_gamma( stats_dt_min.data, ) - @staticmethod - # @numba.njit(**conf.JIT_FLAGS) # note: as of Numba 0.51, np.dot() does not support ints - def __cell_id_body(cell_id, cell_origin, strides): - cell_id[:] = np.dot(strides, cell_origin) + @cached_property + def _cell_id_body(self): + # @numba.njit(**conf.JIT_FLAGS) # note: as of Numba 0.51, np.dot() does not support ints + def body(cell_id, cell_origin, strides): + cell_id[:] = np.dot(strides, cell_origin) + + return body def cell_id(self, cell_id, cell_origin, strides): - return self.__cell_id_body(cell_id.data, cell_origin.data, strides.data) + return self._cell_id_body(cell_id.data, cell_origin.data, strides.data) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def __collision_coalescence_body( - *, - multiplicity, - idx, - length, - attributes, - gamma, - healthy, - cell_id, - coalescence_rate, - is_first_in_pair, - ): - for i in numba.prange( # pylint: disable=not-an-iterable,too-many-nested-blocks - length // 2 + @cached_property + def _collision_coalescence_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + multiplicity, + idx, + length, + attributes, + gamma, + healthy, + cell_id, + coalescence_rate, + is_first_in_pair, ): - j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, gamma) - if skip_pair: - continue - coalesce( - i, j, k, cell_id[j], multiplicity, gamma, attributes, coalescence_rate - ) - flag_zero_multiplicity(j, k, multiplicity, healthy) + for ( + i + ) in numba.prange( # pylint: disable=not-an-iterable,too-many-nested-blocks + length // 2 + ): + j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, gamma) + if skip_pair: + continue + coalesce( + i, + j, + k, + cell_id[j], + multiplicity, + gamma, + attributes, + coalescence_rate, + ) + flag_zero_multiplicity(j, k, multiplicity, healthy) + + return body def collision_coalescence( self, @@ -692,7 +456,7 @@ def collision_coalescence( coalescence_rate, is_first_in_pair, ): - self.__collision_coalescence_body( + self._collision_coalescence_body( multiplicity=multiplicity.data, idx=idx.data, length=len(idx), @@ -726,7 +490,7 @@ def collision_coalescence_breakup( max_multiplicity, ): # pylint: disable=too-many-locals - self.__collision_coalescence_breakup_body( + self._collision_coalescence_breakup_body( multiplicity=multiplicity.data, idx=idx.data, length=len(idx), @@ -747,269 +511,45 @@ def collision_coalescence_breakup( particle_mass=particle_mass.data, ) - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS}) - # pylint: disable=too-many-arguments - def __fragmentation_limiters(n_fragment, frag_volume, vmin, nfmax, 
x_plus_y): - for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable - if x_plus_y[i] == 0.0: - frag_volume[i] = 0.0 - n_fragment[i] = 1.0 - else: - if np.isnan(frag_volume[i]) or frag_volume[i] == 0.0: - frag_volume[i] = x_plus_y[i] - frag_volume[i] = min(frag_volume[i], x_plus_y[i]) - if nfmax is not None and x_plus_y[i] / frag_volume[i] > nfmax: - frag_volume[i] = x_plus_y[i] / nfmax - elif frag_volume[i] < vmin: - frag_volume[i] = x_plus_y[i] - n_fragment[i] = x_plus_y[i] / frag_volume[i] - - def fragmentation_limiters(self, *, n_fragment, frag_volume, vmin, nfmax, x_plus_y): - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - vmin=vmin, - nfmax=nfmax, - x_plus_y=x_plus_y.data, - ) - - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS}) - def __slams_fragmentation_body(n_fragment, frag_volume, x_plus_y, probs, rand): - for i in numba.prange(len(n_fragment)): # pylint: disable=not-an-iterable - probs[i] = 0.0 - n_fragment[i] = 1 - for n in range(22): - probs[i] += 0.91 * (n + 2) ** (-1.56) - if rand[i] < probs[i]: - n_fragment[i] = n + 2 - break - frag_volume[i] = x_plus_y[i] / n_fragment[i] - - def slams_fragmentation( - self, n_fragment, frag_volume, x_plus_y, probs, rand, vmin, nfmax - ): # pylint: disable=too-many-arguments - self.__slams_fragmentation_body( - n_fragment.data, frag_volume.data, x_plus_y.data, probs.data, rand.data - ) - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - vmin=vmin, - nfmax=nfmax, - x_plus_y=x_plus_y.data, - ) - - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS}) - # pylint: disable=too-many-arguments - def __exp_fragmentation_body(*, scale, frag_volume, rand, tol=1e-5): - """ - Exponential PDF - """ - for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable - frag_volume[i] = -scale * np.log(max(1 - rand[i], tol)) - - def exp_fragmentation( - self, - *, - n_fragment, - scale, - frag_volume, - x_plus_y, - rand, - vmin, - nfmax, - tol=1e-5, - ): - self.__exp_fragmentation_body( - scale=scale, - frag_volume=frag_volume.data, - rand=rand.data, - tol=tol, - ) - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - vmin=vmin, - nfmax=nfmax, - ) - - def feingold1988_fragmentation( - self, - *, - n_fragment, - scale, - frag_volume, - x_plus_y, - rand, - fragtol, - vmin, - nfmax, - ): - self.__feingold1988_fragmentation_body( - scale=scale, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - rand=rand.data, - fragtol=fragtol, - ) - - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - vmin=vmin, - nfmax=nfmax, - ) - - def gauss_fragmentation( - self, *, n_fragment, mu, sigma, frag_volume, x_plus_y, rand, vmin, nfmax - ): - self.__gauss_fragmentation_body( - mu=mu, - sigma=sigma, - frag_volume=frag_volume.data, - rand=rand.data, - ) - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - vmin=vmin, - nfmax=nfmax, - ) - - def straub_fragmentation( - # pylint: disable=too-many-arguments,too-many-locals - self, - *, - n_fragment, - CW, - gam, - ds, - frag_volume, - v_max, - x_plus_y, - rand, - vmin, - nfmax, - Nr1, - Nr2, - Nr3, - Nr4, - Nrt, - d34, - ): - self.__straub_fragmentation_body( - CW=CW.data, - gam=gam.data, - ds=ds.data, - frag_volume=frag_volume.data, - v_max=v_max.data, - rand=rand.data, - Nr1=Nr1.data, - Nr2=Nr2.data, - 
Nr3=Nr3.data, - Nr4=Nr4.data, - Nrt=Nrt.data, - d34=d34.data, - ) - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - vmin=vmin, - nfmax=nfmax, - ) - - def ll82_fragmentation( + @cached_property + def _compute_gamma_body(self): + @numba.njit(**self.default_jit_flags) # pylint: disable=too-many-arguments,too-many-locals - self, - *, - n_fragment, - CKE, - W, - W2, - St, - ds, - dl, - dcoal, - frag_volume, - x_plus_y, - rand, - vmin, - nfmax, - Rf, - Rs, - Rd, - tol=1e-8, - ): - self.__ll82_fragmentation_body( - CKE=CKE.data, - W=W.data, - W2=W2.data, - St=St.data, - ds=ds.data, - dl=dl.data, - dcoal=dcoal.data, - frag_volume=frag_volume.data, - rand=rand.data, - Rf=Rf.data, - Rs=Rs.data, - Rd=Rd.data, - tol=tol, - ) - self.__fragmentation_limiters( - n_fragment=n_fragment.data, - frag_volume=frag_volume.data, - x_plus_y=x_plus_y.data, - vmin=vmin, - nfmax=nfmax, - ) - - def ll82_coalescence_check(self, *, Ec, dl): - self.__ll82_coalescence_check_body( - Ec=Ec.data, - dl=dl.data, - ) + def body( + prob, + rand, + idx, + length, + multiplicity, + cell_id, + collision_rate_deficit, + collision_rate, + is_first_in_pair, + out, + ): + """ + return in "out" array gamma (see: http://doi.org/10.1002/qj.441, section 5) + formula: + gamma = floor(prob) + 1 if rand < prob - floor(prob) + = floor(prob) if rand >= prob - floor(prob) + + out may point to the same array as prob + """ + for i in numba.prange(length // 2): # pylint: disable=not-an-iterable + out[i] = np.ceil(prob[i] - rand[i]) + j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, out) + if skip_pair: + continue + prop = multiplicity[j] // multiplicity[k] + g = min(int(out[i]), prop) + cid = cell_id[j] + atomic_add(collision_rate, cid, g * multiplicity[k]) + atomic_add( + collision_rate_deficit, cid, (int(out[i]) - g) * multiplicity[k] + ) + out[i] = g - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - # pylint: disable=too-many-arguments,too-many-locals - def __compute_gamma_body( - prob, - rand, - idx, - length, - multiplicity, - cell_id, - collision_rate_deficit, - collision_rate, - is_first_in_pair, - out, - ): - """ - return in "out" array gamma (see: http://doi.org/10.1002/qj.441, section 5) - formula: - gamma = floor(prob) + 1 if rand < prob - floor(prob) - = floor(prob) if rand >= prob - floor(prob) - - out may point to the same array as prob - """ - for i in numba.prange(length // 2): # pylint: disable=not-an-iterable - out[i] = np.ceil(prob[i] - rand[i]) - j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, out) - if skip_pair: - continue - prop = multiplicity[j] // multiplicity[k] - g = min(int(out[i]), prop) - cid = cell_id[j] - atomic_add(collision_rate, cid, g * multiplicity[k]) - atomic_add(collision_rate_deficit, cid, (int(out[i]) - g) * multiplicity[k]) - out[i] = g + return body def compute_gamma( self, @@ -1023,7 +563,7 @@ def compute_gamma( is_first_in_pair, out, ): - return self.__compute_gamma_body( + return self._compute_gamma_body( prob.data, rand.data, multiplicity.idx.data, @@ -1082,27 +622,28 @@ def __call__(self, cell_id, cell_idx, cell_start, idx): return CellCaretaker(idx_shape, idx_dtype, cell_start_len, scheme) - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - # pylint: disable=too-many-arguments - def __normalize_body( - prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv - ): - n_cell = cell_start.shape[0] - 1 - for i in range(n_cell): - sd_num = cell_start[i + 1] - cell_start[i] - if sd_num 
< 2: - norm_factor[i] = 0 - else: - norm_factor[i] = ( - timestep / dv * sd_num * (sd_num - 1) / 2 / (sd_num // 2) - ) - for d in numba.prange(prob.shape[0]): # pylint: disable=not-an-iterable - prob[d] *= norm_factor[cell_idx[cell_id[d]]] + @cached_property + def _normalize_body(self): + @numba.njit(**{**self.default_jit_flags, **{"parallel": False}}) + # pylint: disable=too-many-arguments + def body(prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv): + n_cell = cell_start.shape[0] - 1 + for i in range(n_cell): + sd_num = cell_start[i + 1] - cell_start[i] + if sd_num < 2: + norm_factor[i] = 0 + else: + norm_factor[i] = ( + timestep / dv * sd_num * (sd_num - 1) / 2 / (sd_num // 2) + ) + for d in numba.prange(prob.shape[0]): # pylint: disable=not-an-iterable + prob[d] *= norm_factor[cell_idx[cell_id[d]]] + + return body # pylint: disable=too-many-arguments def normalize(self, prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv): - return self.__normalize_body( + return self._normalize_body( prob.data, cell_id.data, cell_idx.data, @@ -1112,20 +653,23 @@ def normalize(self, prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv, ) - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - def remove_zero_n_or_flagged(multiplicity, idx, length) -> int: - flag = len(idx) - new_length = length - i = 0 - while i < new_length: - if idx[i] == flag or multiplicity[idx[i]] == 0: - new_length -= 1 - idx[i] = idx[new_length] - idx[new_length] = flag - else: - i += 1 - return new_length + @cached_property + def remove_zero_n_or_flagged(self): + @numba.njit(**{**self.default_jit_flags, **{"parallel": False}}) + def body(multiplicity, idx, length) -> int: + flag = len(idx) + new_length = length + i = 0 + while i < new_length: + if idx[i] == flag or multiplicity[idx[i]] == 0: + new_length -= 1 + idx[i] = idx[new_length] + idx[new_length] = flag + else: + i += 1 + return new_length + + return body @staticmethod @numba.njit(**conf.JIT_FLAGS) @@ -1188,37 +732,38 @@ def _parallel_counting_sort_by_cell_id_and_update_cell_start( cell_start[:] = cell_end_thread[0, :] - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - # pylint: disable=too-many-arguments,too-many-locals - def linear_collection_efficiency_body( - params, output, radii, is_first_in_pair, idx, length, unit - ): - A, B, D1, D2, E1, E2, F1, F2, G1, G2, G3, Mf, Mg = params - output[:] = 0 - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - if radii[idx[i]] > radii[idx[i + 1]]: - r = radii[idx[i]] / unit - r_s = radii[idx[i + 1]] / unit - else: - r = radii[idx[i + 1]] / unit - r_s = radii[idx[i]] / unit - p = r_s / r - if p not in (0, 1): - G = (G1 / r) ** Mg + G2 + G3 * r - Gp = (1 - p) ** G - if Gp != 0: - D = D1 / r**D2 - E = E1 / r**E2 - F = (F1 / r) ** Mf + F2 - output[i // 2] = A + B * p + D / p**F + E / Gp - output[i // 2] = max(0, output[i // 2]) + @cached_property + def _linear_collection_efficiency_body(self): + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments,too-many-locals + def body(params, output, radii, is_first_in_pair, idx, length, unit): + A, B, D1, D2, E1, E2, F1, F2, G1, G2, G3, Mf, Mg = params + output[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + if radii[idx[i]] > radii[idx[i + 1]]: + r = radii[idx[i]] / unit + r_s = radii[idx[i + 1]] / unit + else: + r = radii[idx[i + 1]] / unit + r_s = radii[idx[i]] / unit + p = r_s / r + if p not in (0, 1): + G = 
(G1 / r) ** Mg + G2 + G3 * r + Gp = (1 - p) ** G + if Gp != 0: + D = D1 / r**D2 + E = E1 / r**E2 + F = (F1 / r) ** Mf + F2 + output[i // 2] = A + B * p + D / p**F + E / Gp + output[i // 2] = max(0, output[i // 2]) + + return body def linear_collection_efficiency( self, *, params, output, radii, is_first_in_pair, unit ): - return self.linear_collection_efficiency_body( + return self._linear_collection_efficiency_body( params, output.data, radii.data, diff --git a/PySDM/backends/impl_numba/methods/condensation_methods.py b/PySDM/backends/impl_numba/methods/condensation_methods.py index a8ff22b8a..59d3440bb 100644 --- a/PySDM/backends/impl_numba/methods/condensation_methods.py +++ b/PySDM/backends/impl_numba/methods/condensation_methods.py @@ -597,7 +597,7 @@ def make_condensation_solver( max_iters, ): return CondensationMethods.make_condensation_solver_impl( - formulae=self.formulae.flatten, + formulae=self.formulae_flattened, timestep=timestep, dt_range=dt_range, adaptive=adaptive, diff --git a/PySDM/backends/impl_numba/methods/displacement_methods.py b/PySDM/backends/impl_numba/methods/displacement_methods.py index 454e61785..681c45c2b 100644 --- a/PySDM/backends/impl_numba/methods/displacement_methods.py +++ b/PySDM/backends/impl_numba/methods/displacement_methods.py @@ -2,6 +2,8 @@ CPU implementation of backend methods for particle displacement (advection and sedimentation) """ +from functools import cached_property + import numba from PySDM.backends.impl_numba import conf @@ -143,59 +145,65 @@ def calculate_displacement( else: raise NotImplementedError() - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - # pylint: disable=too-many-arguments - def flag_precipitated_body( - cell_origin, - position_in_cell, - volume, - multiplicity, - idx, - length, - healthy, - precipitation_counting_level_index, - displacement, - ): - rainfall = 0.0 - flag = len(idx) - for i in range(length): - position_within_column = ( - cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]] - ) - if ( - # falling - displacement[-1, idx[i]] < 0 - and - # and crossed precip-counting level - position_within_column < precipitation_counting_level_index - ): - rainfall += volume[idx[i]] * multiplicity[idx[i]] # TODO #599 - idx[i] = flag - healthy[0] = 0 - return rainfall + @cached_property + def _flag_precipitated_body(self): + @numba.njit(**{**self.default_jit_flags, "parallel": False}) + # pylint: disable=too-many-arguments + def body( + cell_origin, + position_in_cell, + volume, + multiplicity, + idx, + length, + healthy, + precipitation_counting_level_index, + displacement, + ): + rainfall = 0.0 + flag = len(idx) + for i in range(length): + position_within_column = ( + cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]] + ) + if ( + # falling + displacement[-1, idx[i]] < 0 + and + # and crossed precip-counting level + position_within_column < precipitation_counting_level_index + ): + rainfall += volume[idx[i]] * multiplicity[idx[i]] # TODO #599 + idx[i] = flag + healthy[0] = 0 + return rainfall - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - # pylint: disable=too-many-arguments - def flag_out_of_column_body( - cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index - ): - flag = len(idx) - for i in range(length): - position_within_column = ( - cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]] - ) - if ( - position_within_column < 0 - or position_within_column > domain_top_level_index - ): - idx[i] = flag - healthy[0] = 0 + return body + + 
@cached_property + def _flag_out_of_column_body(self): + @numba.njit(**{**self.default_jit_flags, "parallel": False}) + # pylint: disable=too-many-arguments + def body( + cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index + ): + flag = len(idx) + for i in range(length): + position_within_column = ( + cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]] + ) + if ( + position_within_column < 0 + or position_within_column > domain_top_level_index + ): + idx[i] = flag + healthy[0] = 0 + + return body - @staticmethod # pylint: disable=too-many-arguments def flag_precipitated( + self, cell_origin, position_in_cell, volume, @@ -206,7 +214,7 @@ def flag_precipitated( precipitation_counting_level_index, displacement, ) -> float: - return DisplacementMethods.flag_precipitated_body( + return self._flag_precipitated_body( cell_origin.data, position_in_cell.data, volume.data, @@ -218,12 +226,17 @@ def flag_precipitated( displacement.data, ) - @staticmethod # pylint: disable=too-many-arguments def flag_out_of_column( - cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index - ) -> float: - return DisplacementMethods.flag_out_of_column_body( + self, + cell_origin, + position_in_cell, + idx, + length, + healthy, + domain_top_level_index, + ): + self._flag_out_of_column_body( cell_origin.data, position_in_cell.data, idx.data, diff --git a/PySDM/backends/impl_numba/methods/fragmentation_methods.py b/PySDM/backends/impl_numba/methods/fragmentation_methods.py new file mode 100644 index 000000000..ef9789df6 --- /dev/null +++ b/PySDM/backends/impl_numba/methods/fragmentation_methods.py @@ -0,0 +1,499 @@ +""" +CPU implementation of backend methods supporting fragmentation functions +""" + +from functools import cached_property +import numba +import numpy as np +from PySDM.backends.impl_numba import conf +from PySDM.backends.impl_common.backend_methods import BackendMethods + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def straub_Nr( # pylint: disable=too-many-arguments,unused-argument + i, + Nr1, + Nr2, + Nr3, + Nr4, + Nrt, + CW, + gam, +): # pylint: disable=too-many-branches` + if gam[i] * CW[i] >= 7.0: + Nr1[i] = 0.088 * (gam[i] * CW[i] - 7.0) + if CW[i] >= 21.0: + Nr2[i] = 0.22 * (CW[i] - 21.0) + if CW[i] <= 46.0: + Nr3[i] = 0.04 * (46.0 - CW[i]) + else: + Nr3[i] = 1.0 + Nr4[i] = 1.0 + Nrt[i] = Nr1[i] + Nr2[i] + Nr3[i] + Nr4[i] + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def straub_mass_remainder( # pylint: disable=too-many-arguments,unused-argument + i, vl, ds, mu1, sigma1, mu2, sigma2, mu3, sigma3, d34, Nr1, Nr2, Nr3, Nr4 +): + # pylint: disable=too-many-arguments, too-many-locals + Nr1[i] = Nr1[i] * np.exp(3 * mu1 + 9 * np.power(sigma1, 2) / 2) + Nr2[i] = Nr2[i] * (mu2**3 + 3 * mu2 * sigma2**2) + Nr3[i] = Nr3[i] * (mu3**3 + 3 * mu3 * sigma3**2) + Nr4[i] = vl[i] * 6 / np.pi + ds[i] ** 3 - Nr1[i] - Nr2[i] - Nr3[i] + if Nr4[i] <= 0.0: + d34[i] = 0 + Nr4[i] = 0 + else: + d34[i] = np.exp(np.log(Nr4[i]) / 3) + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def ll82_Nr( # pylint: disable=too-many-arguments,unused-argument + i, + Rf, + Rs, + Rd, + CKE, + W, + W2, +): # pylint: disable=too-many-branches` + if CKE[i] >= 0.893e-6: + Rf[i] = 1.11e-4 * CKE[i] ** (-0.654) + else: + Rf[i] = 1.0 + if W[i] >= 0.86: + Rs[i] = 0.685 * (1 - np.exp(-1.63 * (W2[i] - 0.86))) + else: + Rs[i] = 0.0 + if (Rs[i] + Rf[i]) > 1.0: + Rd[i] = 0.0 + else: + Rd[i] = 1.0 - Rs[i] - Rf[i] + + +class FragmentationMethods(BackendMethods): + 
@cached_property + def _fragmentation_limiters_body(self): + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments + def body(n_fragment, frag_volume, vmin, nfmax, x_plus_y): + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + if x_plus_y[i] == 0.0: + frag_volume[i] = 0.0 + n_fragment[i] = 1.0 + else: + if np.isnan(frag_volume[i]) or frag_volume[i] == 0.0: + frag_volume[i] = x_plus_y[i] + frag_volume[i] = min(frag_volume[i], x_plus_y[i]) + if nfmax is not None and x_plus_y[i] / frag_volume[i] > nfmax: + frag_volume[i] = x_plus_y[i] / nfmax + elif frag_volume[i] < vmin: + frag_volume[i] = x_plus_y[i] + n_fragment[i] = x_plus_y[i] / frag_volume[i] + + return body + + def fragmentation_limiters(self, *, n_fragment, frag_volume, vmin, nfmax, x_plus_y): + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + vmin=vmin, + nfmax=nfmax, + x_plus_y=x_plus_y.data, + ) + + @cached_property + def _slams_fragmentation_body(self): + @numba.njit(**self.default_jit_flags) + def body(n_fragment, frag_volume, x_plus_y, probs, rand): + for i in numba.prange(len(n_fragment)): # pylint: disable=not-an-iterable + probs[i] = 0.0 + n_fragment[i] = 1 + for n in range(22): + probs[i] += 0.91 * (n + 2) ** (-1.56) + if rand[i] < probs[i]: + n_fragment[i] = n + 2 + break + frag_volume[i] = x_plus_y[i] / n_fragment[i] + + return body + + def slams_fragmentation( + self, n_fragment, frag_volume, x_plus_y, probs, rand, vmin, nfmax + ): # pylint: disable=too-many-arguments + self._slams_fragmentation_body( + n_fragment.data, frag_volume.data, x_plus_y.data, probs.data, rand.data + ) + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + vmin=vmin, + nfmax=nfmax, + x_plus_y=x_plus_y.data, + ) + + @cached_property + def _exp_fragmentation_body(self): + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments + def body(*, scale, frag_volume, rand, tol=1e-5): + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + frag_volume[i] = -scale * np.log(max(1 - rand[i], tol)) + + return body + + def exp_fragmentation( + self, + *, + n_fragment, + scale, + frag_volume, + x_plus_y, + rand, + vmin, + nfmax, + tol=1e-5, + ): + self._exp_fragmentation_body( + scale=scale, + frag_volume=frag_volume.data, + rand=rand.data, + tol=tol, + ) + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + vmin=vmin, + nfmax=nfmax, + ) + + def feingold1988_fragmentation( + self, + *, + n_fragment, + scale, + frag_volume, + x_plus_y, + rand, + fragtol, + vmin, + nfmax, + ): + self._feingold1988_fragmentation_body( + scale=scale, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + rand=rand.data, + fragtol=fragtol, + ) + + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + vmin=vmin, + nfmax=nfmax, + ) + + def gauss_fragmentation( + self, *, n_fragment, mu, sigma, frag_volume, x_plus_y, rand, vmin, nfmax + ): + self._gauss_fragmentation_body( + mu=mu, + sigma=sigma, + frag_volume=frag_volume.data, + rand=rand.data, + ) + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + vmin=vmin, + nfmax=nfmax, + ) + + def straub_fragmentation( + # pylint: disable=too-many-arguments,too-many-locals + self, + *, + n_fragment, + CW, + gam, + ds, + frag_volume, + 
v_max, + x_plus_y, + rand, + vmin, + nfmax, + Nr1, + Nr2, + Nr3, + Nr4, + Nrt, + d34, + ): + self._straub_fragmentation_body( + CW=CW.data, + gam=gam.data, + ds=ds.data, + frag_volume=frag_volume.data, + v_max=v_max.data, + rand=rand.data, + Nr1=Nr1.data, + Nr2=Nr2.data, + Nr3=Nr3.data, + Nr4=Nr4.data, + Nrt=Nrt.data, + d34=d34.data, + ) + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + vmin=vmin, + nfmax=nfmax, + ) + + def ll82_fragmentation( + # pylint: disable=too-many-arguments,too-many-locals + self, + *, + n_fragment, + CKE, + W, + W2, + St, + ds, + dl, + dcoal, + frag_volume, + x_plus_y, + rand, + vmin, + nfmax, + Rf, + Rs, + Rd, + tol=1e-8, + ): + self._ll82_fragmentation_body( + CKE=CKE.data, + W=W.data, + W2=W2.data, + St=St.data, + ds=ds.data, + dl=dl.data, + dcoal=dcoal.data, + frag_volume=frag_volume.data, + rand=rand.data, + Rf=Rf.data, + Rs=Rs.data, + Rd=Rd.data, + tol=tol, + ) + self._fragmentation_limiters_body( + n_fragment=n_fragment.data, + frag_volume=frag_volume.data, + x_plus_y=x_plus_y.data, + vmin=vmin, + nfmax=nfmax, + ) + + @cached_property + def _ll82_coalescence_check_body(self): + @numba.njit(**self.default_jit_flags) + def body(*, Ec, dl): + for i in numba.prange(len(Ec)): # pylint: disable=not-an-iterable + if dl[i] < 0.4e-3: + Ec[i] = 1.0 + + return body + + def ll82_coalescence_check(self, *, Ec, dl): + self._ll82_coalescence_check_body( + Ec=Ec.data, + dl=dl.data, + ) + + @cached_property + def _straub_fragmentation_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body( + *, CW, gam, ds, v_max, frag_volume, rand, Nr1, Nr2, Nr3, Nr4, Nrt, d34 + ): # pylint: disable=too-many-arguments,too-many-locals + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + straub_Nr(i, Nr1, Nr2, Nr3, Nr4, Nrt, CW, gam) + sigma1 = ff.fragmentation_function__params_sigma1(CW[i]) + mu1 = ff.fragmentation_function__params_mu1(sigma1) + sigma2 = ff.fragmentation_function__params_sigma2(CW[i]) + mu2 = ff.fragmentation_function__params_mu2(ds[i]) + sigma3 = ff.fragmentation_function__params_sigma3(CW[i]) + mu3 = ff.fragmentation_function__params_mu3(ds[i]) + straub_mass_remainder( + i, + v_max, + ds, + mu1, + sigma1, + mu2, + sigma2, + mu3, + sigma3, + d34, + Nr1, + Nr2, + Nr3, + Nr4, + ) + Nrt[i] = Nr1[i] + Nr2[i] + Nr3[i] + Nr4[i] + + if Nrt[i] == 0.0: + diameter = 0.0 + else: + if rand[i] < Nr1[i] / Nrt[i]: + X = rand[i] * Nrt[i] / Nr1[i] + lnarg = mu1 + np.sqrt(2) * sigma1 * ff.trivia__erfinv_approx(X) + diameter = np.exp(lnarg) + elif rand[i] < (Nr2[i] + Nr1[i]) / Nrt[i]: + X = (rand[i] * Nrt[i] - Nr1[i]) / Nr2[i] + diameter = mu2 + np.sqrt(2) * sigma2 * ff.trivia__erfinv_approx( + X + ) + elif rand[i] < (Nr3[i] + Nr2[i] + Nr1[i]) / Nrt[i]: + X = (rand[i] * Nrt[i] - Nr1[i] - Nr2[i]) / Nr3[i] + diameter = mu3 + np.sqrt(2) * sigma3 * ff.trivia__erfinv_approx( + X + ) + else: + diameter = d34[i] + + frag_volume[i] = diameter**3 * ff.constants.PI / 6 + + return body + + @cached_property + def _ll82_fragmentation_body(self): # pylint: disable=too-many-statements + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body( + *, CKE, W, W2, St, ds, dl, dcoal, frag_volume, rand, Rf, Rs, Rd, tol + ): # pylint: disable=too-many-branches,too-many-locals,too-many-statements + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + if dl[i] <= 0.4e-3: + frag_volume[i] = dcoal[i] ** 3 * ff.constants.PI / 6 
+ elif ds[i] == 0.0 or dl[i] == 0.0: + frag_volume[i] = 1e-18 + else: + ll82_Nr(i, Rf, Rs, Rd, CKE, W, W2) + if rand[i] <= Rf[i]: # filament breakup + (H1, mu1, sigma1) = ff.fragmentation_function__params_f1( + dl[i], dcoal[i] + ) + (H2, mu2, sigma2) = ff.fragmentation_function__params_f2(ds[i]) + (H3, mu3, sigma3) = ff.fragmentation_function__params_f3( + ds[i], dl[i] + ) + H1 = H1 * mu1 + H2 = H2 * mu2 + H3 = H3 * np.exp(mu3) + Hsum = H1 + H2 + H3 + rand[i] = rand[i] / Rf[i] + if rand[i] <= H1 / Hsum: + X = max(rand[i] * Hsum / H1, tol) + frag_volume[i] = mu1 + np.sqrt( + 2 + ) * sigma1 * ff.trivia__erfinv_approx(2 * X - 1) + elif rand[i] <= (H1 + H2) / Hsum: + X = (rand[i] * Hsum - H1) / H2 + frag_volume[i] = mu2 + np.sqrt( + 2 + ) * sigma2 * ff.trivia__erfinv_approx(2 * X - 1) + else: + X = min((rand[i] * Hsum - H1 - H2) / H3, 1.0 - tol) + lnarg = mu3 + np.sqrt( + 2 + ) * sigma3 * ff.trivia__erfinv_approx(2 * X - 1) + frag_volume[i] = np.exp(lnarg) + + elif rand[i] <= Rf[i] + Rs[i]: # sheet breakup + (H1, mu1, sigma1) = ff.fragmentation_function__params_s1( + dl[i], ds[i], dcoal[i] + ) + (H2, mu2, sigma2) = ff.fragmentation_function__params_s2( + dl[i], ds[i], St[i] + ) + H1 = H1 * mu1 + H2 = H2 * np.exp(mu2) + Hsum = H1 + H2 + rand[i] = (rand[i] - Rf[i]) / (Rs[i]) + if rand[i] <= H1 / Hsum: + X = max(rand[i] * Hsum / H1, tol) + frag_volume[i] = mu1 + np.sqrt( + 2 + ) * sigma1 * ff.trivia__erfinv_approx(2 * X - 1) + else: + X = min((rand[i] * Hsum - H1) / H2, 1.0 - tol) + lnarg = mu2 + np.sqrt( + 2 + ) * sigma2 * ff.trivia__erfinv_approx(2 * X - 1) + frag_volume[i] = np.exp(lnarg) + + else: # disk breakup + (H1, mu1, sigma1) = ff.fragmentation_function__params_d1( + W[i], dl[i], dcoal[i], CKE[i] + ) + (H2, mu2, sigma2) = ff.fragmentation_function__params_d2( + ds[i], dl[i], CKE[i] + ) + H1 = H1 * mu1 + Hsum = H1 + H2 + rand[i] = (rand[i] - Rf[i] - Rs[i]) / Rd[i] + if rand[i] <= H1 / Hsum: + X = max(rand[i] * Hsum / H1, tol) + frag_volume[i] = mu1 + np.sqrt( + 2 + ) * sigma1 * ff.trivia__erfinv_approx(2 * X - 1) + else: + X = min((rand[i] * Hsum - H1) / H2, 1 - tol) + lnarg = mu2 + np.sqrt( + 2 + ) * sigma2 * ff.trivia__erfinv_approx(2 * X - 1) + frag_volume[i] = np.exp(lnarg) + + frag_volume[i] = ( + frag_volume[i] * 0.01 + ) # diameter in cm; convert to m + frag_volume[i] = frag_volume[i] ** 3 * ff.constants.PI / 6 + + return body + + @cached_property + def _gauss_fragmentation_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(*, mu, sigma, frag_volume, rand): # pylint: disable=too-many-arguments + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + frag_volume[i] = mu + sigma * ff.trivia__erfinv_approx(rand[i]) + + return body + + @cached_property + def _feingold1988_fragmentation_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments + def body(*, scale, frag_volume, x_plus_y, rand, fragtol): + for i in numba.prange(len(frag_volume)): # pylint: disable=not-an-iterable + frag_volume[i] = ff.fragmentation_function__frag_volume( + scale, rand[i], x_plus_y[i], fragtol + ) + + return body diff --git a/PySDM/backends/impl_numba/methods/freezing_methods.py b/PySDM/backends/impl_numba/methods/freezing_methods.py index d3037ff7b..3a973d7bb 100644 --- a/PySDM/backends/impl_numba/methods/freezing_methods.py +++ b/PySDM/backends/impl_numba/methods/freezing_methods.py @@ -11,7 +11,6 @@ SingularAttributes, TimeDependentAttributes, ) -from ...impl_numba import 
conf class FreezingMethods(BackendMethods): @@ -22,21 +21,17 @@ def __init__(self): self.formulae.trivia.frozen_and_above_freezing_point ) - @numba.njit( - **{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath, "parallel": False} - ) + @numba.njit(**{**self.default_jit_flags, "parallel": False}) def _freeze(water_mass, i): water_mass[i] = -1 * water_mass[i] # TODO #599: change thd (latent heat)! - @numba.njit( - **{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath, "parallel": False} - ) + @numba.njit(**{**self.default_jit_flags, "parallel": False}) def _thaw(water_mass, i): water_mass[i] = -1 * water_mass[i] # TODO #599: change thd (latent heat)! - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) + @numba.njit(**self.default_jit_flags) def freeze_singular_body( attributes, temperature, relative_humidity, cell, thaw ): @@ -60,7 +55,7 @@ def freeze_singular_body( j_het = self.formulae.heterogeneous_ice_nucleation_rate.j_het - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) + @numba.njit(**self.default_jit_flags) def freeze_time_dependent_body( # pylint: disable=unused-argument,too-many-arguments rand, attributes, diff --git a/PySDM/backends/impl_numba/methods/index_methods.py b/PySDM/backends/impl_numba/methods/index_methods.py index 119be41a4..caa92108a 100644 --- a/PySDM/backends/impl_numba/methods/index_methods.py +++ b/PySDM/backends/impl_numba/methods/index_methods.py @@ -2,34 +2,47 @@ CPU implementation of shuffling and sorting backend methods """ +from functools import cached_property + import numba from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf class IndexMethods(BackendMethods): - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def identity_index(idx): - for i in numba.prange(len(idx)): # pylint: disable=not-an-iterable - idx[i] = i + @cached_property + def identity_index(self): + @numba.njit(**self.default_jit_flags) + def body(idx): + for i in numba.prange(len(idx)): # pylint: disable=not-an-iterable + idx[i] = i - @staticmethod - @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) - def shuffle_global(idx, length, u01): - for i in range(length - 1, 0, -1): - j = int(u01[i] * (i + 1)) - idx[i], idx[j] = idx[j], idx[i] + return body - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def shuffle_local(idx, u01, cell_start): - for c in numba.prange(len(cell_start) - 1): # pylint: disable=not-an-iterable - for i in range(cell_start[c + 1] - 1, cell_start[c], -1): - j = int(cell_start[c] + u01[i] * (cell_start[c + 1] - cell_start[c])) + @cached_property + def shuffle_global(self): + @numba.njit(**{**self.default_jit_flags, "parallel": False}) + def body(idx, length, u01): + for i in range(length - 1, 0, -1): + j = int(u01[i] * (i + 1)) idx[i], idx[j] = idx[j], idx[i] + return body + + @cached_property + def shuffle_local(self): + @numba.njit(**self.default_jit_flags) + def body(idx, u01, cell_start): + # pylint: disable=not-an-iterable + for c in numba.prange(len(cell_start) - 1): + for i in range(cell_start[c + 1] - 1, cell_start[c], -1): + j = int( + cell_start[c] + u01[i] * (cell_start[c + 1] - cell_start[c]) + ) + idx[i], idx[j] = idx[j], idx[i] + + return body + @staticmethod def sort_by_key(idx, attr): idx.data[:] = attr.data.argsort(kind="stable")[::-1] diff --git a/PySDM/backends/impl_numba/methods/isotope_methods.py b/PySDM/backends/impl_numba/methods/isotope_methods.py index eed4039da..f33e6761d 100644 --- 
a/PySDM/backends/impl_numba/methods/isotope_methods.py +++ b/PySDM/backends/impl_numba/methods/isotope_methods.py @@ -7,23 +7,22 @@ import numba from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf class IsotopeMethods(BackendMethods): @cached_property - def __isotopic_delta_body(self): - phys_isotopic_delta = self.formulae.trivia.isotopic_ratio_2_delta + def _isotopic_delta_body(self): + ff = self.formulae_flattened - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def isotopic_delta(output, ratio, reference_ratio): + @numba.njit(**self.default_jit_flags) + def body(output, ratio, reference_ratio): for i in numba.prange(output.shape[0]): # pylint: disable=not-an-iterable - output[i] = phys_isotopic_delta(ratio[i], reference_ratio) + output[i] = ff.trivia__isotopic_ratio_2_delta(ratio[i], reference_ratio) - return isotopic_delta + return body def isotopic_delta(self, output, ratio, reference_ratio): - self.__isotopic_delta_body(output.data, ratio.data, reference_ratio) + self._isotopic_delta_body(output.data, ratio.data, reference_ratio) def isotopic_fractionation(self): pass diff --git a/PySDM/backends/impl_numba/methods/moments_methods.py b/PySDM/backends/impl_numba/methods/moments_methods.py index cb3f00072..29cfba19a 100644 --- a/PySDM/backends/impl_numba/methods/moments_methods.py +++ b/PySDM/backends/impl_numba/methods/moments_methods.py @@ -2,63 +2,69 @@ CPU implementation of moment calculation backend methods """ +from functools import cached_property + import numba from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf from PySDM.backends.impl_numba.atomic_operations import atomic_add class MomentsMethods(BackendMethods): - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def moments_body( - *, - moment_0, - moments, - multiplicity, - attr_data, - cell_id, - idx, - length, - ranks, - min_x, - max_x, - x_attr, - weighting_attribute, - weighting_rank, - skip_division_by_m0, - ): - # pylint: disable=too-many-locals - moment_0[:] = 0 - moments[:, :] = 0 - for idx_i in numba.prange(length): # pylint: disable=not-an-iterable - i = idx[idx_i] - if min_x <= x_attr[i] < max_x: - atomic_add( - moment_0, - cell_id[i], - multiplicity[i] * weighting_attribute[i] ** weighting_rank, - ) - for k in range(ranks.shape[0]): + @cached_property + def _moments_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + ranks, + min_x, + max_x, + x_attr, + weighting_attribute, + weighting_rank, + skip_division_by_m0, + ): + # pylint: disable=too-many-locals + moment_0[:] = 0 + moments[:, :] = 0 + for idx_i in numba.prange(length): # pylint: disable=not-an-iterable + i = idx[idx_i] + if min_x <= x_attr[i] < max_x: atomic_add( - moments, - (k, cell_id[i]), - ( - multiplicity[i] - * weighting_attribute[i] ** weighting_rank - * attr_data[i] ** ranks[k] - ), - ) - if not skip_division_by_m0: - for c_id in range(moment_0.shape[0]): - for k in range(ranks.shape[0]): - moments[k, c_id] = ( - moments[k, c_id] / moment_0[c_id] if moment_0[c_id] != 0 else 0 + moment_0, + cell_id[i], + multiplicity[i] * weighting_attribute[i] ** weighting_rank, ) + for k in range(ranks.shape[0]): + atomic_add( + moments, + (k, cell_id[i]), + ( + multiplicity[i] + * weighting_attribute[i] ** weighting_rank + * attr_data[i] ** ranks[k] + ), + ) + if not skip_division_by_m0: + for c_id in 
range(moment_0.shape[0]): + for k in range(ranks.shape[0]): + moments[k, c_id] = ( + moments[k, c_id] / moment_0[c_id] + if moment_0[c_id] != 0 + else 0 + ) + + return body - @staticmethod def moments( + self, *, moment_0, moments, @@ -75,7 +81,7 @@ def moments( weighting_rank, skip_division_by_m0, ): - return MomentsMethods.moments_body( + return self._moments_body( moment_0=moment_0.data, moments=moments.data, multiplicity=multiplicity.data, @@ -92,55 +98,58 @@ def moments( skip_division_by_m0=skip_division_by_m0, ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def spectrum_moments_body( - *, - moment_0, - moments, - multiplicity, - attr_data, - cell_id, - idx, - length, - rank, - x_bins, - x_attr, - weighting_attribute, - weighting_rank, - ): - # pylint: disable=too-many-locals - moment_0[:, :] = 0 - moments[:, :] = 0 - for idx_i in numba.prange(length): # pylint: disable=not-an-iterable - i = idx[idx_i] - for k in range(x_bins.shape[0] - 1): - if x_bins[k] <= x_attr[i] < x_bins[k + 1]: - atomic_add( - moment_0, - (k, cell_id[i]), - multiplicity[i] * weighting_attribute[i] ** weighting_rank, - ) - atomic_add( - moments, - (k, cell_id[i]), - ( - multiplicity[i] - * weighting_attribute[i] ** weighting_rank - * attr_data[i] ** rank - ), + @cached_property + def _spectrum_moments_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + rank, + x_bins, + x_attr, + weighting_attribute, + weighting_rank, + ): + # pylint: disable=too-many-locals + moment_0[:, :] = 0 + moments[:, :] = 0 + for idx_i in numba.prange(length): # pylint: disable=not-an-iterable + i = idx[idx_i] + for k in range(x_bins.shape[0] - 1): + if x_bins[k] <= x_attr[i] < x_bins[k + 1]: + atomic_add( + moment_0, + (k, cell_id[i]), + multiplicity[i] * weighting_attribute[i] ** weighting_rank, + ) + atomic_add( + moments, + (k, cell_id[i]), + ( + multiplicity[i] + * weighting_attribute[i] ** weighting_rank + * attr_data[i] ** rank + ), + ) + break + for c_id in range(moment_0.shape[1]): + for k in range(x_bins.shape[0] - 1): + moments[k, c_id] = ( + moments[k, c_id] / moment_0[k, c_id] + if moment_0[k, c_id] != 0 + else 0 ) - break - for c_id in range(moment_0.shape[1]): - for k in range(x_bins.shape[0] - 1): - moments[k, c_id] = ( - moments[k, c_id] / moment_0[k, c_id] - if moment_0[k, c_id] != 0 - else 0 - ) - @staticmethod + return body + def spectrum_moments( + self, *, moment_0, moments, @@ -157,7 +166,7 @@ def spectrum_moments( ): assert moments.shape[0] == x_bins.shape[0] - 1 assert moment_0.shape == moments.shape - return MomentsMethods.spectrum_moments_body( + return self._spectrum_moments_body( moment_0=moment_0.data, moments=moments.data, multiplicity=multiplicity.data, diff --git a/PySDM/backends/impl_numba/methods/pair_methods.py b/PySDM/backends/impl_numba/methods/pair_methods.py index d726ed53d..48449873d 100644 --- a/PySDM/backends/impl_numba/methods/pair_methods.py +++ b/PySDM/backends/impl_numba/methods/pair_methods.py @@ -2,25 +2,28 @@ CPU implementation of pairwise operations backend methods """ +from functools import cached_property + import numba import numpy as np from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf class PairMethods(BackendMethods): - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def distance_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length - 1): # pylint: 
disable=not-an-iterable - if is_first_in_pair[i]: - data_out[i // 2] = np.abs(data_in[idx[i]] - data_in[idx[i + 1]]) - - @staticmethod - def distance_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.distance_pair_body( + @cached_property + def _distance_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = np.abs(data_in[idx[i]] - data_in[idx[i + 1]]) + + return body + + def distance_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._distance_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, @@ -28,20 +31,21 @@ def distance_pair(data_out, data_in, is_first_in_pair, idx): len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def find_pairs_body( - *, cell_start, is_first_in_pair, cell_id, cell_idx, idx, length - ): - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - is_in_same_cell = cell_id[idx[i]] == cell_id[idx[i + 1]] - is_even_index = (i - cell_start[cell_idx[cell_id[idx[i]]]]) % 2 == 0 - is_first_in_pair[i] = is_in_same_cell and is_even_index - is_first_in_pair[length - 1] = False - - @staticmethod - def find_pairs(cell_start, is_first_in_pair, cell_id, cell_idx, idx): - return PairMethods.find_pairs_body( + @cached_property + def _find_pairs_body(self): + @numba.njit(**self.default_jit_flags) + def body(*, cell_start, is_first_in_pair, cell_id, cell_idx, idx, length): + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + is_in_same_cell = cell_id[idx[i]] == cell_id[idx[i + 1]] + is_even_index = (i - cell_start[cell_idx[cell_id[idx[i]]]]) % 2 == 0 + is_first_in_pair[i] = is_in_same_cell and is_even_index + is_first_in_pair[length - 1] = False + + return body + + # pylint: disable=too-many-arguments + def find_pairs(self, cell_start, is_first_in_pair, cell_id, cell_idx, idx): + return self._find_pairs_body( cell_start=cell_start.data, is_first_in_pair=is_first_in_pair.indicator.data, cell_id=cell_id.data, @@ -50,17 +54,19 @@ def find_pairs(cell_start, is_first_in_pair, cell_id, cell_idx, idx): length=len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def max_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - data_out[i // 2] = max(data_in[idx[i]], data_in[idx[i + 1]]) - - @staticmethod - def max_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.max_pair_body( + @cached_property + def _max_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = max(data_in[idx[i]], data_in[idx[i + 1]]) + + return body + + def max_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._max_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, @@ -68,17 +74,19 @@ def max_pair(data_out, data_in, is_first_in_pair, idx): len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def min_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - data_out[i // 2] = min(data_in[idx[i]], data_in[idx[i + 1]]) - - 
@staticmethod - def min_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.min_pair_body( + @cached_property + def _min_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = min(data_in[idx[i]], data_in[idx[i + 1]]) + + return body + + def min_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._min_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, @@ -86,20 +94,28 @@ def min_pair(data_out, data_in, is_first_in_pair, idx): len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def sort_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - if data_in[idx[i]] < data_in[idx[i + 1]]: - data_out[i], data_out[i + 1] = data_in[idx[i + 1]], data_in[idx[i]] - else: - data_out[i], data_out[i + 1] = data_in[idx[i]], data_in[idx[i + 1]] - - @staticmethod - def sort_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.sort_pair_body( + @cached_property + def _sort_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + if data_in[idx[i]] < data_in[idx[i + 1]]: + data_out[i], data_out[i + 1] = ( + data_in[idx[i + 1]], + data_in[idx[i]], + ) + else: + data_out[i], data_out[i + 1] = ( + data_in[idx[i]], + data_in[idx[i + 1]], + ) + + return body + + def sort_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._sort_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, @@ -107,31 +123,35 @@ def sort_pair(data_out, data_in, is_first_in_pair, idx): len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def sort_within_pair_by_attr_body(idx, length, is_first_in_pair, attr): - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - if attr[idx[i]] < attr[idx[i + 1]]: - idx[i], idx[i + 1] = idx[i + 1], idx[i] - - @staticmethod - def sort_within_pair_by_attr(idx, is_first_in_pair, attr): - PairMethods.sort_within_pair_by_attr_body( + @cached_property + def _sort_within_pair_by_attr_body(self): + @numba.njit(**self.default_jit_flags) + def body(idx, length, is_first_in_pair, attr): + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + if attr[idx[i]] < attr[idx[i + 1]]: + idx[i], idx[i + 1] = idx[i + 1], idx[i] + + return body + + def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr): + self._sort_within_pair_by_attr_body( idx.data, len(idx), is_first_in_pair.indicator.data, attr.data ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def sum_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - data_out[i // 2] = data_in[idx[i]] + data_in[idx[i + 1]] - - @staticmethod - def sum_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.sum_pair_body( + @cached_property + def _sum_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length): # pylint: 
disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = data_in[idx[i]] + data_in[idx[i + 1]] + + return body + + def sum_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._sum_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, @@ -139,17 +159,19 @@ def sum_pair(data_out, data_in, is_first_in_pair, idx): len(idx), ) - @staticmethod - @numba.njit(**conf.JIT_FLAGS) - def multiply_pair_body(data_out, data_in, is_first_in_pair, idx, length): - data_out[:] = 0 - for i in numba.prange(length - 1): # pylint: disable=not-an-iterable - if is_first_in_pair[i]: - data_out[i // 2] = data_in[idx[i]] * data_in[idx[i + 1]] - - @staticmethod - def multiply_pair(data_out, data_in, is_first_in_pair, idx): - return PairMethods.multiply_pair_body( + @cached_property + def _multiply_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = data_in[idx[i]] * data_in[idx[i + 1]] + + return body + + def multiply_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._multiply_pair_body( data_out.data, data_in.data, is_first_in_pair.indicator.data, diff --git a/PySDM/backends/impl_numba/methods/physics_methods.py b/PySDM/backends/impl_numba/methods/physics_methods.py index 0d035c9bd..f83efa137 100644 --- a/PySDM/backends/impl_numba/methods/physics_methods.py +++ b/PySDM/backends/impl_numba/methods/physics_methods.py @@ -8,89 +8,65 @@ from numba import prange from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf class PhysicsMethods(BackendMethods): - def __init__(self): # pylint: disable=too-many-locals + def __init__(self): BackendMethods.__init__(self) - pvs_C = self.formulae.saturation_vapour_pressure.pvs_Celsius - pvi_C = self.formulae.saturation_vapour_pressure.ice_Celsius - phys_T = self.formulae.state_variable_triplet.T - phys_p = self.formulae.state_variable_triplet.p - phys_pv = self.formulae.state_variable_triplet.pv - explicit_euler = self.formulae.trivia.explicit_euler - phys_sigma = self.formulae.surface_tension.sigma - phys_volume = self.formulae.trivia.volume - phys_r_cr = self.formulae.hygroscopicity.r_cr - phys_mass_to_volume = self.formulae.particle_shape_and_density.mass_to_volume - phys_volume_to_mass = self.formulae.particle_shape_and_density.volume_to_mass - const = self.formulae.constants - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def explicit_euler_body(y, dt, dy_dt): - y[:] = explicit_euler(y, dt, dy_dt) - - self.explicit_euler_body = explicit_euler_body - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def critical_volume(*, v_cr, kappa, f_org, v_dry, v_wet, T, cell): + + @cached_property + def _critical_volume_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(*, v_cr, kappa, f_org, v_dry, v_wet, T, cell): for i in prange(len(v_cr)): # pylint: disable=not-an-iterable - sigma = phys_sigma(T[cell[i]], v_wet[i], v_dry[i], f_org[i]) - v_cr[i] = phys_volume( - phys_r_cr( + sigma = ff.surface_tension__sigma( + T[cell[i]], v_wet[i], v_dry[i], f_org[i] + ) + v_cr[i] = ff.trivia__volume( + ff.hygroscopicity__r_cr( kp=kappa[i], - rd3=v_dry[i] / const.PI_4_3, + rd3=v_dry[i] / ff.constants.PI_4_3, T=T[cell[i]], sgm=sigma, ) ) - self.critical_volume_body = critical_volume - - 
@numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def temperature_pressure_RH_body( - *, rhod, thd, water_vapour_mixing_ratio, T, p, RH - ): - for i in prange(T.shape[0]): # pylint: disable=not-an-iterable - T[i] = phys_T(rhod[i], thd[i]) - p[i] = phys_p(rhod[i], T[i], water_vapour_mixing_ratio[i]) - RH[i] = phys_pv(p[i], water_vapour_mixing_ratio[i]) / pvs_C( - T[i] - const.T0 - ) - - self.temperature_pressure_RH_body = temperature_pressure_RH_body - - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def a_w_ice_body( - *, T_in, p_in, RH_in, water_vapour_mixing_ratio_in, a_w_ice_out - ): - for i in prange(T_in.shape[0]): # pylint: disable=not-an-iterable - pvi = pvi_C(T_in[i] - const.T0) - pv = phys_pv(p_in[i], water_vapour_mixing_ratio_in[i]) - pvs = pv / RH_in[i] - a_w_ice_out[i] = pvi / pvs - - self.a_w_ice_body = a_w_ice_body + return body - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def volume_of_mass(volume, mass): - for i in prange(volume.shape[0]): # pylint: disable=not-an-iterable - volume[i] = phys_mass_to_volume(mass[i]) + def critical_volume(self, *, v_cr, kappa, f_org, v_dry, v_wet, T, cell): + self._critical_volume_body( + v_cr=v_cr.data, + kappa=kappa.data, + f_org=f_org.data, + v_dry=v_dry.data, + v_wet=v_wet.data, + T=T.data, + cell=cell.data, + ) - self.volume_of_mass_body = volume_of_mass + @cached_property + def _temperature_pressure_rh_body(self): + ff = self.formulae_flattened - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def mass_of_volume(mass, volume): - for i in prange(volume.shape[0]): # pylint: disable=not-an-iterable - mass[i] = phys_volume_to_mass(volume[i]) + @numba.njit(**self.default_jit_flags) + def body(*, rhod, thd, water_vapour_mixing_ratio, T, p, RH): + for i in prange(T.shape[0]): # pylint: disable=not-an-iterable + T[i] = ff.state_variable_triplet__T(rhod[i], thd[i]) + p[i] = ff.state_variable_triplet__p( + rhod[i], T[i], water_vapour_mixing_ratio[i] + ) + RH[i] = ff.state_variable_triplet__pv( + p[i], water_vapour_mixing_ratio[i] + ) / ff.saturation_vapour_pressure__pvs_Celsius(T[i] - ff.constants.T0) - self.mass_of_volume_body = mass_of_volume + return body - def temperature_pressure_RH( + def temperature_pressure_rh( self, *, rhod, thd, water_vapour_mixing_ratio, T, p, RH ): - self.temperature_pressure_RH_body( + self._temperature_pressure_rh_body( rhod=rhod.data, thd=thd.data, water_vapour_mixing_ratio=water_vapour_mixing_ratio.data, @@ -99,22 +75,26 @@ def temperature_pressure_RH( RH=RH.data, ) - def explicit_euler(self, y, dt, dy_dt): - self.explicit_euler_body(y.data, dt, dy_dt) + @cached_property + def _a_w_ice_body(self): + ff = self.formulae_flattened - def critical_volume(self, *, v_cr, kappa, f_org, v_dry, v_wet, T, cell): - self.critical_volume_body( - v_cr=v_cr.data, - kappa=kappa.data, - f_org=f_org.data, - v_dry=v_dry.data, - v_wet=v_wet.data, - T=T.data, - cell=cell.data, - ) + @numba.njit(**self.default_jit_flags) + def body(*, T_in, p_in, RH_in, water_vapour_mixing_ratio_in, a_w_ice_out): + for i in prange(T_in.shape[0]): # pylint: disable=not-an-iterable + pvi = ff.saturation_vapour_pressure__ice_Celsius( + T_in[i] - ff.constants.T0 + ) + pv = ff.state_variable_triplet__pv( + p_in[i], water_vapour_mixing_ratio_in[i] + ) + pvs = pv / RH_in[i] + a_w_ice_out[i] = pvi / pvs + + return body def a_w_ice(self, *, T, p, RH, water_vapour_mixing_ratio, a_w_ice): - self.a_w_ice_body( + self._a_w_ice_body( T_in=T.data, p_in=p.data, 
RH_in=RH.data, @@ -122,17 +102,39 @@ def a_w_ice(self, *, T, p, RH, water_vapour_mixing_ratio, a_w_ice): a_w_ice_out=a_w_ice.data, ) + @cached_property + def _volume_of_mass_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(volume, mass): + for i in prange(volume.shape[0]): # pylint: disable=not-an-iterable + volume[i] = ff.particle_shape_and_density__mass_to_volume(mass[i]) + + return body + def volume_of_water_mass(self, volume, mass): - self.volume_of_mass_body(volume.data, mass.data) + self._volume_of_mass_body(volume.data, mass.data) + + @cached_property + def _mass_of_volume_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(mass, volume): + for i in prange(volume.shape[0]): # pylint: disable=not-an-iterable + mass[i] = ff.particle_shape_and_density__volume_to_mass(volume[i]) + + return body def mass_of_water_volume(self, mass, volume): - self.mass_of_volume_body(mass.data, volume.data) + self._mass_of_volume_body(mass.data, volume.data) @cached_property def __air_density_body(self): formulae = self.formulae.flatten - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": formulae.fastmath}) + @numba.njit(**self.default_jit_flags) def body(output, rhod, water_vapour_mixing_ratio): for i in numba.prange(output.shape[0]): # pylint: disable=not-an-iterable output[i] = ( @@ -150,7 +152,7 @@ def air_density(self, *, output, rhod, water_vapour_mixing_ratio): def __air_dynamic_viscosity_body(self): formulae = self.formulae.flatten - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": formulae.fastmath}) + @numba.njit(**self.default_jit_flags) def body(output, temperature): for i in numba.prange(output.shape[0]): # pylint: disable=not-an-iterable output[i] = formulae.air_dynamic_viscosity__eta_air(temperature[i]) @@ -164,7 +166,7 @@ def air_dynamic_viscosity(self, *, output, temperature): def __reynolds_number_body(self): formulae = self.formulae.flatten - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": formulae.fastmath}) + @numba.njit(**self.default_jit_flags) def body( # pylint: disable=too-many-arguments output, cell_id, @@ -194,3 +196,16 @@ def reynolds_number( radius.data, velocity_wrt_air.data, ) + + @cached_property + def _explicit_euler_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(y, dt, dy_dt): + y[:] = ff.trivia__explicit_euler(y, dt, dy_dt) + + return body + + def explicit_euler(self, y, dt, dy_dt): + self._explicit_euler_body(y.data, dt, dy_dt) diff --git a/PySDM/backends/impl_numba/methods/terminal_velocity_methods.py b/PySDM/backends/impl_numba/methods/terminal_velocity_methods.py index 68d35f92d..ee0e14e6b 100644 --- a/PySDM/backends/impl_numba/methods/terminal_velocity_methods.py +++ b/PySDM/backends/impl_numba/methods/terminal_velocity_methods.py @@ -7,14 +7,13 @@ import numba from PySDM.backends.impl_common.backend_methods import BackendMethods -from PySDM.backends.impl_numba import conf class TerminalVelocityMethods(BackendMethods): @cached_property - def interpolation_body(self): - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def interpolation_body(output, radius, factor, b, c): + def _interpolation_body(self): + @numba.njit(**self.default_jit_flags) + def body(output, radius, factor, b, c): for i in numba.prange(len(radius)): # pylint: disable=not-an-iterable if radius[i] < 0: output[i] = 0 @@ -23,38 +22,40 @@ def interpolation_body(output, radius, factor, b, c): r_rest = ((factor * radius[i]) % 1) / factor output[i] = 
b[r_id] + r_rest * c[r_id] - return interpolation_body + return body def interpolation(self, *, output, radius, factor, b, c): - return self.interpolation_body(output.data, radius.data, factor, b.data, c.data) + return self._interpolation_body( + output.data, radius.data, factor, b.data, c.data + ) @cached_property - def terminal_velocity_body(self): + def _terminal_velocity_body(self): v_term = self.formulae.terminal_velocity.v_term - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def terminal_velocity_body(*, values, radius): + @numba.njit(**self.default_jit_flags) + def body(*, values, radius): for i in numba.prange(len(values)): # pylint: disable=not-an-iterable values[i] = v_term(radius[i]) - return terminal_velocity_body + return body def terminal_velocity(self, *, values, radius): - self.terminal_velocity_body(values=values, radius=radius) + self._terminal_velocity_body(values=values, radius=radius) @cached_property - def power_series_body(self): - @numba.njit(**{**conf.JIT_FLAGS, "fastmath": self.formulae.fastmath}) - def power_series_body(*, values, radius, num_terms, prefactors, powers): + def _power_series_body(self): + @numba.njit(**self.default_jit_flags) + def body(*, values, radius, num_terms, prefactors, powers): for i in numba.prange(len(values)): # pylint: disable=not-an-iterable values[i] = 0.0 for j in range(num_terms): values[i] = values[i] + prefactors[j] * radius[i] ** (powers[j] * 3) - return power_series_body + return body def power_series(self, *, values, radius, num_terms, prefactors, powers): - self.power_series_body( + self._power_series_body( values=values, radius=radius, num_terms=num_terms, diff --git a/PySDM/backends/impl_thrust_rtc/methods/physics_methods.py b/PySDM/backends/impl_thrust_rtc/methods/physics_methods.py index 9fa72bbdd..3ad60fdba 100644 --- a/PySDM/backends/impl_thrust_rtc/methods/physics_methods.py +++ b/PySDM/backends/impl_thrust_rtc/methods/physics_methods.py @@ -13,7 +13,7 @@ class PhysicsMethods(ThrustRTCBackendMethods): @cached_property - def _temperature_pressure_RH_body(self): + def _temperature_pressure_rh_body(self): return trtc.For( ("rhod", "thd", "water_vapour_mixing_ratio", "T", "p", "RH"), "i", @@ -108,10 +108,10 @@ def critical_volume(self, *, v_cr, kappa, f_org, v_dry, v_wet, T, cell): ) @nice_thrust(**NICE_THRUST_FLAGS) - def temperature_pressure_RH( + def temperature_pressure_rh( self, *, rhod, thd, water_vapour_mixing_ratio, T, p, RH ): - self._temperature_pressure_RH_body.launch_n( + self._temperature_pressure_rh_body.launch_n( T.shape[0], ( rhod.data, diff --git a/PySDM/backends/numba.py b/PySDM/backends/numba.py index 27793f993..7bbf53a4e 100644 --- a/PySDM/backends/numba.py +++ b/PySDM/backends/numba.py @@ -2,53 +2,54 @@ Multi-threaded CPU backend using LLVM-powered just-in-time compilation """ -from PySDM.backends.impl_numba.methods.chemistry_methods import ChemistryMethods -from PySDM.backends.impl_numba.methods.collisions_methods import CollisionsMethods -from PySDM.backends.impl_numba.methods.condensation_methods import CondensationMethods -from PySDM.backends.impl_numba.methods.displacement_methods import DisplacementMethods -from PySDM.backends.impl_numba.methods.freezing_methods import FreezingMethods -from PySDM.backends.impl_numba.methods.index_methods import IndexMethods -from PySDM.backends.impl_numba.methods.isotope_methods import IsotopeMethods -from PySDM.backends.impl_numba.methods.moments_methods import MomentsMethods -from PySDM.backends.impl_numba.methods.pair_methods import 
PairMethods -from PySDM.backends.impl_numba.methods.physics_methods import PhysicsMethods -from PySDM.backends.impl_numba.methods.terminal_velocity_methods import ( - TerminalVelocityMethods, -) +from PySDM.backends.impl_numba import methods from PySDM.backends.impl_numba.random import Random as ImportedRandom from PySDM.backends.impl_numba.storage import Storage as ImportedStorage from PySDM.formulae import Formulae +from PySDM.backends.impl_numba.conf import JIT_FLAGS class Numba( # pylint: disable=too-many-ancestors,duplicate-code - CollisionsMethods, - PairMethods, - IndexMethods, - PhysicsMethods, - CondensationMethods, - ChemistryMethods, - MomentsMethods, - FreezingMethods, - DisplacementMethods, - TerminalVelocityMethods, - IsotopeMethods, + methods.CollisionsMethods, + methods.FragmentationMethods, + methods.PairMethods, + methods.IndexMethods, + methods.PhysicsMethods, + methods.CondensationMethods, + methods.ChemistryMethods, + methods.MomentsMethods, + methods.FreezingMethods, + methods.DisplacementMethods, + methods.TerminalVelocityMethods, + methods.IsotopeMethods, ): Storage = ImportedStorage Random = ImportedRandom default_croupier = "local" - def __init__(self, formulae=None, double_precision=True): + def __init__(self, formulae=None, double_precision=True, override_jit_flags=None): if not double_precision: raise NotImplementedError() self.formulae = formulae or Formulae() - CollisionsMethods.__init__(self) - PairMethods.__init__(self) - IndexMethods.__init__(self) - PhysicsMethods.__init__(self) - CondensationMethods.__init__(self) - ChemistryMethods.__init__(self) - MomentsMethods.__init__(self) - FreezingMethods.__init__(self) - DisplacementMethods.__init__(self) - TerminalVelocityMethods.__init__(self) + self.formulae_flattened = self.formulae.flatten + + assert "fastmath" not in (override_jit_flags or {}) + self.default_jit_flags = { + **JIT_FLAGS, + **{"fastmath": self.formulae.fastmath}, + **(override_jit_flags or {}), + } + + methods.CollisionsMethods.__init__(self) + methods.FragmentationMethods.__init__(self) + methods.PairMethods.__init__(self) + methods.IndexMethods.__init__(self) + methods.PhysicsMethods.__init__(self) + methods.CondensationMethods.__init__(self) + methods.ChemistryMethods.__init__(self) + methods.MomentsMethods.__init__(self) + methods.FreezingMethods.__init__(self) + methods.DisplacementMethods.__init__(self) + methods.TerminalVelocityMethods.__init__(self) + methods.IsotopeMethods.__init__(self) diff --git a/PySDM/environments/impl/moist.py b/PySDM/environments/impl/moist.py index 70da1f049..9487a7c82 100644 --- a/PySDM/environments/impl/moist.py +++ b/PySDM/environments/impl/moist.py @@ -62,7 +62,7 @@ def sync(self): target["water_vapour_mixing_ratio"].ravel(self.get_water_vapour_mixing_ratio()) target["thd"].ravel(self.get_thd()) - self.particulator.backend.temperature_pressure_RH( + self.particulator.backend.temperature_pressure_rh( rhod=target["rhod"], thd=target["thd"], water_vapour_mixing_ratio=target["water_vapour_mixing_ratio"], diff --git a/PySDM/particulator.py b/PySDM/particulator.py index bce0fe714..1adcaa738 100644 --- a/PySDM/particulator.py +++ b/PySDM/particulator.py @@ -91,7 +91,7 @@ def normalize(self, prob, norm_factor): ) def update_TpRH(self): - self.backend.temperature_pressure_RH( + self.backend.temperature_pressure_rh( # input rhod=self.environment.get_predicted("rhod"), thd=self.environment.get_predicted("thd"), diff --git a/examples/PySDM_examples/Arabas_and_Shima_2017/simulation.py 
b/examples/PySDM_examples/Arabas_and_Shima_2017/simulation.py index a6572dacb..e0ed4b536 100644 --- a/examples/PySDM_examples/Arabas_and_Shima_2017/simulation.py +++ b/examples/PySDM_examples/Arabas_and_Shima_2017/simulation.py @@ -19,7 +19,14 @@ def __init__(self, settings, backend=CPU): self.n_substeps += 1 builder = Builder( - backend=backend(formulae=settings.formulae), + backend=backend( + formulae=settings.formulae, + **( + {"override_jit_flags": {"parallel": False}} + if backend == CPU + else {} + ) + ), n_sd=1, environment=Parcel( dt=dt_output / self.n_substeps, diff --git a/examples/PySDM_examples/Berry_1967/example_fig_6.py b/examples/PySDM_examples/Berry_1967/example_fig_6.py index b7ed85d40..0d9bf2673 100644 --- a/examples/PySDM_examples/Berry_1967/example_fig_6.py +++ b/examples/PySDM_examples/Berry_1967/example_fig_6.py @@ -5,7 +5,7 @@ from PySDM.backends import CPU from PySDM.physics import constants as const -backend = CPU +backend = CPU() um = const.si.um @@ -28,7 +28,7 @@ def print_collection_efficiency_portrait(params): pair[0] = r for i, _ in enumerate(x_values): pair[1] = x_values[i] * r - backend.linear_collection_efficiency_body( + backend._linear_collection_efficiency_body( params=full_params(params), output=Y_c[i : i + 1], radii=pair, @@ -71,7 +71,7 @@ def Y_c_portrait( pair[0] = radii[i] for j, __ in enumerate(p): pair[1] = p[j] * radii[i] - backend.linear_collection_efficiency_body( + backend._linear_collection_efficiency_body( params=full_params(params), output=Y_c[i : i + 1, j], radii=pair, diff --git a/examples/PySDM_examples/Jensen_and_Nugent_2017/simulation.py b/examples/PySDM_examples/Jensen_and_Nugent_2017/simulation.py index 92e1c0d7b..5ad0c9f22 100644 --- a/examples/PySDM_examples/Jensen_and_Nugent_2017/simulation.py +++ b/examples/PySDM_examples/Jensen_and_Nugent_2017/simulation.py @@ -48,7 +48,9 @@ def __init__( builder = Builder( n_sd=N_SD_NON_GCCN + n_gccn, - backend=CPU(formulae=settings.formulae), + backend=CPU( + formulae=settings.formulae, override_jit_flags={"parallel": False} + ), environment=env, ) diff --git a/examples/PySDM_examples/Kreidenweis_et_al_2003/simulation.py b/examples/PySDM_examples/Kreidenweis_et_al_2003/simulation.py index e855b17de..9ef5f6916 100644 --- a/examples/PySDM_examples/Kreidenweis_et_al_2003/simulation.py +++ b/examples/PySDM_examples/Kreidenweis_et_al_2003/simulation.py @@ -22,7 +22,11 @@ def __init__(self, settings, products=None): ) builder = Builder( - n_sd=settings.n_sd, backend=CPU(formulae=settings.formulae), environment=env + n_sd=settings.n_sd, + backend=CPU( + formulae=settings.formulae, override_jit_flags={"parallel": False} + ), + environment=env, ) attributes = env.init_attributes( diff --git a/examples/PySDM_examples/Lowe_et_al_2019/simulation.py b/examples/PySDM_examples/Lowe_et_al_2019/simulation.py index ccd255eab..c0ce346e9 100644 --- a/examples/PySDM_examples/Lowe_et_al_2019/simulation.py +++ b/examples/PySDM_examples/Lowe_et_al_2019/simulation.py @@ -22,7 +22,11 @@ def __init__(self, settings, products=None): ) n_sd = settings.n_sd_per_mode * len(settings.aerosol.modes) builder = Builder( - n_sd=n_sd, backend=CPU(formulae=settings.formulae), environment=env + n_sd=n_sd, + backend=CPU( + formulae=settings.formulae, override_jit_flags={"parallel": False} + ), + environment=env, ) attributes = { diff --git a/examples/PySDM_examples/Yang_et_al_2018/simulation.py b/examples/PySDM_examples/Yang_et_al_2018/simulation.py index 4fbacd261..6d5ca09b6 100644 --- 
a/examples/PySDM_examples/Yang_et_al_2018/simulation.py +++ b/examples/PySDM_examples/Yang_et_al_2018/simulation.py @@ -40,7 +40,11 @@ def __init__(self, settings, backend=CPU): z0=settings.z0, ) builder = Builder( - backend=backend(formulae=self.formulae), n_sd=settings.n_sd, environment=env + backend=backend( + formulae=self.formulae, override_jit_flags={"parallel": False} + ), + n_sd=settings.n_sd, + environment=env, ) environment = builder.particulator.environment diff --git a/tests/examples_tests/conftest.py b/tests/examples_tests/conftest.py index 9be8c2ef2..4268485e5 100644 --- a/tests/examples_tests/conftest.py +++ b/tests/examples_tests/conftest.py @@ -47,16 +47,16 @@ def findfiles(path, regex): "coagulation": ["Berry_1967", "Shima_et_al_2009"], "breakup": ["Bieli_et_al_2022", "deJong_Mackay_et_al_2023", "Srivastava_1982"], "multi-process_a": [ - "deJong_Azimi", "Arabas_et_al_2015", - "Bartman_2020_MasterThesis", + "Arabas_et_al_2023", + "deJong_Azimi", "Bulenok_2023_MasterThesis", + "Shipway_and_Hill_2012", ], "multi-process_b": [ - "Arabas_et_al_2023", + "Bartman_2020_MasterThesis", "Bartman_et_al_2021", "Morrison_and_Grabowski_2007", - "Shipway_and_Hill_2012", "Szumowski_et_al_1998", "utils", ], diff --git a/tests/unit_tests/attributes/test_area_radius.py b/tests/unit_tests/attributes/test_area_radius.py index c54cdbd1b..0c6d4f329 100644 --- a/tests/unit_tests/attributes/test_area_radius.py +++ b/tests/unit_tests/attributes/test_area_radius.py @@ -7,10 +7,10 @@ @pytest.mark.parametrize("volume", (np.asarray([44, 666]),)) -def test_radius(volume, backend_class): +def test_radius(volume, backend_instance): # arrange env = Box(dt=None, dv=None) - builder = Builder(backend=backend_class(), n_sd=volume.size, environment=env) + builder = Builder(backend=backend_instance, n_sd=volume.size, environment=env) builder.request_attribute("radius") particulator = builder.build( attributes={"volume": volume, "multiplicity": np.ones_like(volume)} @@ -25,10 +25,10 @@ def test_radius(volume, backend_class): @pytest.mark.parametrize("volume", (np.asarray([44, 666]),)) -def test_sqrt_radius(volume, backend_class): +def test_sqrt_radius(volume, backend_instance): # arrange env = Box(dt=None, dv=None) - builder = Builder(backend=backend_class(), n_sd=volume.size, environment=env) + builder = Builder(backend=backend_instance, n_sd=volume.size, environment=env) builder.request_attribute("radius") builder.request_attribute("square root of radius") particulator = builder.build( @@ -45,10 +45,10 @@ def test_sqrt_radius(volume, backend_class): @pytest.mark.parametrize("volume", (np.asarray([44, 666]),)) -def test_area(volume, backend_class): +def test_area(volume, backend_instance): # arrange env = Box(dv=None, dt=None) - builder = Builder(backend=backend_class(), n_sd=volume.size, environment=env) + builder = Builder(backend=backend_instance, n_sd=volume.size, environment=env) builder.request_attribute("area") particulator = builder.build( attributes={"volume": volume, "multiplicity": np.ones_like(volume)} diff --git a/tests/unit_tests/attributes/test_fall_velocity.py b/tests/unit_tests/attributes/test_fall_velocity.py index 44e4f4561..dd5bc9824 100644 --- a/tests/unit_tests/attributes/test_fall_velocity.py +++ b/tests/unit_tests/attributes/test_fall_velocity.py @@ -53,14 +53,14 @@ def default_attributes_fixture(request): return request.param -def test_fall_velocity_calculation(default_attributes, backend_class): +def test_fall_velocity_calculation(default_attributes, backend_instance): """ Test that 
fall velocity is the momentum divided by the mass. """ env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) @@ -79,14 +79,14 @@ def test_fall_velocity_calculation(default_attributes, backend_class): ) -def test_conservation_of_momentum(default_attributes, backend_class): +def test_conservation_of_momentum(default_attributes, backend_instance): """ Test that conservation of momentum holds when many super-droplets coalesce """ env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) @@ -122,14 +122,14 @@ def test_conservation_of_momentum(default_attributes, backend_class): assert np.isclose(total_final_momentum, total_initial_momentum) -def test_attribute_selection(backend_class): +def test_attribute_selection(backend_instance): """ Test that the correct velocity attribute is selected by the mapper. `PySDM.attributes.physics.relative_fall_velocity.RelativeFallVelocity` should only be selected when `PySDM.dynamics.RelaxedVelocity` dynamic exists. """ env = Box(dt=1, dv=1) - builder_no_relax = Builder(n_sd=1, backend=backend_class(), environment=env) + builder_no_relax = Builder(n_sd=1, backend=backend_instance, environment=env) builder_no_relax.request_attribute("relative fall velocity") # with no RelaxedVelocity, the builder should use TerminalVelocity @@ -137,7 +137,7 @@ def test_attribute_selection(backend_class): builder_no_relax.req_attr["relative fall velocity"], TerminalVelocity ) env = Box(dt=1, dv=1) - builder = Builder(n_sd=1, backend=backend_class(), environment=env) + builder = Builder(n_sd=1, backend=backend_instance, environment=env) builder.add_dynamic(RelaxedVelocity()) builder.request_attribute("relative fall velocity") @@ -146,6 +146,6 @@ def test_attribute_selection(backend_class): # requesting momentum with no dynamic issues a warning env = Box(dt=1, dv=1) - builder = Builder(n_sd=1, backend=backend_class(), environment=env) + builder = Builder(n_sd=1, backend=backend_instance, environment=env) with pytest.warns(UserWarning): builder.request_attribute("relative fall momentum") diff --git a/tests/unit_tests/attributes/test_isotopes.py b/tests/unit_tests/attributes/test_isotopes.py index 836968ca3..1f39945d2 100644 --- a/tests/unit_tests/attributes/test_isotopes.py +++ b/tests/unit_tests/attributes/test_isotopes.py @@ -22,12 +22,12 @@ def dummy_attrs(length): class TestIsotopes: @staticmethod @pytest.mark.parametrize("isotope", HEAVY_ISOTOPES) - def test_heavy_isotope_moles_attributes(backend_class, isotope): + def test_heavy_isotope_moles_attributes(backend_instance, isotope): # arrange values = [1, 2, 3] builder = Builder( n_sd=len(values), - backend=backend_class(), + backend=backend_instance, environment=Box(dt=np.nan, dv=np.nan), ) particulator = builder.build( diff --git a/tests/unit_tests/attributes/test_multiplicities.py b/tests/unit_tests/attributes/test_multiplicities.py index 43b3d875e..123383d52 100644 --- a/tests/unit_tests/attributes/test_multiplicities.py +++ b/tests/unit_tests/attributes/test_multiplicities.py @@ -27,11 +27,11 @@ def test_max_multiplicity_value(): ), ), ) - def test_max_multiplicity_assignable(backend_class, value): + def test_max_multiplicity_assignable(backend_instance, value): # arrange n_sd = 1 env = Box(dt=np.nan, dv=np.nan) - builder = Builder(n_sd=n_sd, backend=backend_class(), environment=env) + builder = Builder(n_sd=n_sd, 
backend=backend_instance, environment=env) # act particulator = builder.build( diff --git a/tests/unit_tests/backends/storage/test_basic_ops.py b/tests/unit_tests/backends/storage/test_basic_ops.py index 5af09f25c..ecd4b9e87 100644 --- a/tests/unit_tests/backends/storage/test_basic_ops.py +++ b/tests/unit_tests/backends/storage/test_basic_ops.py @@ -12,9 +12,9 @@ class TestBasicOps: ([1.0], [2], [3.0]), ], ) - def test_addition(backend_class, output, addend, expected): + def test_addition(backend_instance, output, addend, expected): # Arrange - backend = backend_class() + backend = backend_instance output = backend.Storage.from_ndarray(np.asarray(output)) if hasattr(addend, "__len__"): addend = backend.Storage.from_ndarray(np.asarray(addend)) diff --git a/tests/unit_tests/backends/storage/test_index.py b/tests/unit_tests/backends/storage/test_index.py index 2c89a053e..9e345a235 100644 --- a/tests/unit_tests/backends/storage/test_index.py +++ b/tests/unit_tests/backends/storage/test_index.py @@ -7,15 +7,14 @@ class TestIndex: # pylint: disable=too-few-public-methods @staticmethod - def test_remove_zero_n_or_flagged(backend_class): + def test_remove_zero_n_or_flagged(backend_instance): # Arrange - backend = backend_class() n_sd = 44 - idx = make_Index(backend).identity_index(n_sd) + idx = make_Index(backend_instance).identity_index(n_sd) data = np.ones(n_sd).astype(np.int64) data[0], data[n_sd // 2], data[-1] = 0, 0, 0 - data = backend.Storage.from_ndarray(data) - data = make_IndexedStorage(backend).indexed(storage=data, idx=idx) + data = backend_instance.Storage.from_ndarray(data) + data = make_IndexedStorage(backend_instance).indexed(storage=data, idx=idx) # Act idx.remove_zero_n_or_flagged(data) @@ -23,5 +22,5 @@ def test_remove_zero_n_or_flagged(backend_class): # Assert assert len(idx) == n_sd - 3 assert ( - backend.Storage.to_ndarray(data)[idx.to_ndarray()[: len(idx)]] > 0 + backend_instance.Storage.to_ndarray(data)[idx.to_ndarray()[: len(idx)]] > 0 ).all() diff --git a/tests/unit_tests/backends/test_collisions_methods.py b/tests/unit_tests/backends/test_collisions_methods.py index 7dca6e55f..7d54db948 100644 --- a/tests/unit_tests/backends/test_collisions_methods.py +++ b/tests/unit_tests/backends/test_collisions_methods.py @@ -67,9 +67,9 @@ def test_pair_indices(i, idx, is_first_in_pair, gamma, expected): ((4, 5, 4.5, 3, 0.1), (0, 1, 2, 3, 4, 5), 5), ), ) - def test_adaptive_sdm_end(backend_class, dt_left, cell_start, expected): + def test_adaptive_sdm_end(backend_instance, dt_left, cell_start, expected): # Arrange - backend = backend_class() + backend = backend_instance dt_left = backend.Storage.from_ndarray(np.asarray(dt_left)) cell_start = backend.Storage.from_ndarray(np.asarray(cell_start)) @@ -150,7 +150,7 @@ def test_adaptive_sdm_end(backend_class, dt_left, cell_start, expected): # pylint: disable=too-many-locals def test_scale_prob_for_adaptive_sdm_gamma( *, - backend_class, + backend_instance, gamma, idx, n, @@ -163,7 +163,7 @@ def test_scale_prob_for_adaptive_sdm_gamma( expected_n_substep, ): # Arrange - backend = backend_class() + backend = backend_instance _gamma = backend.Storage.from_ndarray(np.asarray(gamma)) _idx = make_Index(backend).from_ndarray(np.asarray(idx)) _n = make_IndexedStorage(backend).from_ndarray(_idx, np.asarray(n)) diff --git a/tests/unit_tests/backends/test_isotope_methods.py b/tests/unit_tests/backends/test_isotope_methods.py index f84ce3b06..05643be66 100644 --- a/tests/unit_tests/backends/test_isotope_methods.py +++ 
b/tests/unit_tests/backends/test_isotope_methods.py @@ -7,17 +7,17 @@ class TestIsotopeMethods: @staticmethod - def test_isotopic_fractionation(backend_class): + def test_isotopic_fractionation(backend_instance): # arrange - backend = backend_class() + backend = backend_instance # act backend.isotopic_fractionation() @staticmethod - def test_isotopic_delta(backend_class): + def test_isotopic_delta(backend_instance): # arrange - backend = backend_class() + backend = backend_instance arr2storage = backend.Storage.from_ndarray n_sd = 10 output = arr2storage(np.empty(n_sd)) diff --git a/tests/unit_tests/backends/test_pair_methods.py b/tests/unit_tests/backends/test_pair_methods.py index f43c82854..226f6af85 100644 --- a/tests/unit_tests/backends/test_pair_methods.py +++ b/tests/unit_tests/backends/test_pair_methods.py @@ -50,9 +50,9 @@ def test_sum_pair_body_out_of_bounds( idx = backend.Storage.from_ndarray(np.asarray(_idx)) sut = ( - backend.sum_pair_body + backend._sum_pair_body if "NUMBA_DISABLE_JIT" in os.environ - else backend.sum_pair_body.py_func + else backend._sum_pair_body.py_func ) # Act sut( @@ -78,9 +78,9 @@ def test_sum_pair_body_out_of_bounds( ), ), ) - def test_sum_pair(_data_in, _data_out, _idx, backend_class): + def test_sum_pair(_data_in, _data_out, _idx, backend_instance): # Arrange - backend = backend_class() + backend = backend_instance data_out = backend.Storage.from_ndarray(np.asarray(_data_out)) data_in = backend.Storage.from_ndarray(np.asarray(_data_in)) @@ -102,9 +102,9 @@ def test_sum_pair(_data_in, _data_out, _idx, backend_class): @staticmethod @pytest.mark.parametrize("length", (1, 2, 3, 4)) - def test_find_pairs_length(backend_class, length): + def test_find_pairs_length(backend_instance, length): # arrange - backend = backend_class() + backend = backend_instance n_sd = 4 cell_start = backend.Storage.from_ndarray(np.asarray([0, 0, 0, 0])) diff --git a/tests/unit_tests/backends/test_physics_methods.py b/tests/unit_tests/backends/test_physics_methods.py index bc6f69b4c..b1a546aa9 100644 --- a/tests/unit_tests/backends/test_physics_methods.py +++ b/tests/unit_tests/backends/test_physics_methods.py @@ -8,10 +8,10 @@ class TestPhysicsMethods: @staticmethod - def test_temperature_pressure_RH(backend_class): + def test_temperature_pressure_rh(backend_instance): # Arrange - backend = backend_class() - sut = backend.temperature_pressure_RH + backend = backend_instance + sut = backend.temperature_pressure_rh rhod = backend.Storage.from_ndarray(np.asarray((1, 1.1))) thd = backend.Storage.from_ndarray(np.asarray((300.0, 301))) water_vapour_mixing_ratio = backend.Storage.from_ndarray( diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 3983e2127..523474ada 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -7,3 +7,8 @@ @pytest.fixture(params=(CPU, GPU)) def backend_class(request): return request.param + + +@pytest.fixture(params=(CPU(), GPU()), scope="session") +def backend_instance(request): + return request.param diff --git a/tests/unit_tests/dynamics/collisions/test_sdm_breakup.py b/tests/unit_tests/dynamics/collisions/test_sdm_breakup.py index 303717f3a..e37408cc7 100644 --- a/tests/unit_tests/dynamics/collisions/test_sdm_breakup.py +++ b/tests/unit_tests/dynamics/collisions/test_sdm_breakup.py @@ -94,9 +94,9 @@ def test_nonadaptive_same_results_regardless_of_dt(dt, backend_class): ), ], ) - def test_single_collision_bounce(params, backend_class): + def test_single_collision_bounce(params, backend_instance): # 
Arrange - backend = backend_class() + backend = backend_instance n_sd = 2 env = Box(dv=np.NaN, dt=np.NaN) builder = Builder(n_sd, backend, environment=env) @@ -172,12 +172,14 @@ def test_single_collision_bounce(params, backend_class): }, ], ) - def test_breakup_counters(params, backend_class): # pylint: disable=too-many-locals + def test_breakup_counters( + params, backend_instance + ): # pylint: disable=too-many-locals # Arrange n_init = params["n_init"] n_sd = len(n_init) env = Box(dv=np.NaN, dt=np.NaN) - builder = Builder(n_sd, backend_class(), environment=env) + builder = Builder(n_sd, backend_instance, environment=env) particulator = builder.build( attributes={ "multiplicity": np.asarray(n_init), @@ -706,13 +708,13 @@ def test_same_multiplicity_overflow_no_substeps( ) @pytest.mark.parametrize("flag", ("multiplicity", "v", "conserve", "deficit")) def test_noninteger_fragments( - params, flag, backend_class + params, flag, backend_instance ): # pylint: disable=too-many-locals # Arrange n_init = params["n_init"] n_sd = len(n_init) env = Box(dv=np.NaN, dt=np.NaN) - builder = Builder(n_sd, backend_class(), environment=env) + builder = Builder(n_sd, backend_instance, environment=env) particulator = builder.build( attributes={ "multiplicity": np.asarray(n_init), diff --git a/tests/unit_tests/dynamics/collisions/test_sdm_single_cell.py b/tests/unit_tests/dynamics/collisions/test_sdm_single_cell.py index a17c29203..6b9ed7d67 100644 --- a/tests/unit_tests/dynamics/collisions/test_sdm_single_cell.py +++ b/tests/unit_tests/dynamics/collisions/test_sdm_single_cell.py @@ -213,9 +213,9 @@ def test_multi_step(backend_class): np.testing.assert_approx_equal(actual=actual, desired=desired, significant=8) @staticmethod - def test_compute_gamma(backend_class): + def test_compute_gamma(backend_instance): # Arrange - backend = backend_class() + backend = backend_instance n = 87 prob = np.linspace(0, 3, n, endpoint=True) rand = np.linspace(0, 1, n, endpoint=False) @@ -268,7 +268,7 @@ def expected(p, r): ) def test_rnd_reuse(backend_class, optimized_random, adaptive): if backend_class is ThrustRTC: - return # TODO #330 + pytest.skip("# TODO #330") # Arrange n_sd = 256 @@ -277,7 +277,10 @@ def test_rnd_reuse(backend_class, optimized_random, adaptive): n_substeps = 5 particles, sut = get_dummy_particulator_and_coalescence( - backend_class, n_sd, optimized_random=optimized_random, substeps=n_substeps + backend_class, + n_sd, + optimized_random=optimized_random, + substeps=n_substeps, ) attributes = {"multiplicity": n, "volume": v} particles.build(attributes) diff --git a/tests/unit_tests/dynamics/test_isotopic_fractionation.py b/tests/unit_tests/dynamics/test_isotopic_fractionation.py index fa1b2df9f..b6db5844f 100644 --- a/tests/unit_tests/dynamics/test_isotopic_fractionation.py +++ b/tests/unit_tests/dynamics/test_isotopic_fractionation.py @@ -30,10 +30,10 @@ ), ), ) -def test_ensure_condensation_executed_before(backend_class, dynamics, context): +def test_ensure_condensation_executed_before(backend_instance, dynamics, context): # arrange builder = Builder( - n_sd=1, backend=backend_class(), environment=Box(dv=np.nan, dt=1 * si.s) + n_sd=1, backend=backend_instance, environment=Box(dv=np.nan, dt=1 * si.s) ) for dynamic in dynamics: builder.add_dynamic(dynamic) diff --git a/tests/unit_tests/dynamics/test_relaxed_velocity.py b/tests/unit_tests/dynamics/test_relaxed_velocity.py index f34057661..09a5f9af9 100644 --- a/tests/unit_tests/dynamics/test_relaxed_velocity.py +++ 
b/tests/unit_tests/dynamics/test_relaxed_velocity.py @@ -42,7 +42,7 @@ def constant_timescale_fixture(request): return request.param -def test_small_timescale(default_attributes, constant_timescale, backend_class): +def test_small_timescale(default_attributes, constant_timescale, backend_instance): """ When the fall velocity is initialized to 0 and relaxation is very quick, the velocity should quickly approach the terminal velocity @@ -51,7 +51,7 @@ def test_small_timescale(default_attributes, constant_timescale, backend_class): env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) @@ -74,7 +74,7 @@ def test_small_timescale(default_attributes, constant_timescale, backend_class): ) -def test_large_timescale(default_attributes, constant_timescale, backend_class): +def test_large_timescale(default_attributes, constant_timescale, backend_instance): """ When the fall velocity is initialized to 0 and relaxation is very slow, the velocity should remain 0 @@ -83,7 +83,7 @@ def test_large_timescale(default_attributes, constant_timescale, backend_class): env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) @@ -106,7 +106,7 @@ def test_large_timescale(default_attributes, constant_timescale, backend_class): ) -def test_behavior(default_attributes, constant_timescale, backend_class): +def test_behavior(default_attributes, constant_timescale, backend_instance): """ The fall velocity should approach the terminal velocity exponentially """ @@ -114,7 +114,7 @@ def test_behavior(default_attributes, constant_timescale, backend_class): env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) @@ -153,7 +153,7 @@ def test_behavior(default_attributes, constant_timescale, backend_class): @pytest.mark.parametrize("c", [0.1, 10]) -def test_timescale(default_attributes, c, constant_timescale, backend_class): +def test_timescale(default_attributes, c, constant_timescale, backend_instance): """ The non-constant timescale should be proportional to the sqrt of the radius. The proportionality constant should be the parameter for the dynamic. @@ -163,7 +163,7 @@ def test_timescale(default_attributes, c, constant_timescale, backend_class): env = Box(dt=1, dv=1) builder = Builder( n_sd=len(default_attributes["multiplicity"]), - backend=backend_class(), + backend=backend_instance, environment=env, ) diff --git a/tests/unit_tests/initialisation/test_init_fall_momenta.py b/tests/unit_tests/initialisation/test_init_fall_momenta.py index 43d5e2cbc..488714367 100644 --- a/tests/unit_tests/initialisation/test_init_fall_momenta.py +++ b/tests/unit_tests/initialisation/test_init_fall_momenta.py @@ -36,13 +36,13 @@ def params_fixture(request): return request.param -def test_init_to_terminal_velocity(params, backend_class): +def test_init_to_terminal_velocity(params, backend_instance): """ Fall momenta correctly initialized to the terminal velocity * mass. 
""" env = Box(dt=1, dv=1) builder = Builder( - n_sd=len(params["multiplicity"]), backend=backend_class(), environment=env + n_sd=len(params["multiplicity"]), backend=backend_instance, environment=env ) builder.request_attribute("terminal velocity") particulator = builder.build( diff --git a/tests/unit_tests/initialisation/test_spectral_discretisation.py b/tests/unit_tests/initialisation/test_spectral_discretisation.py index 28014a2c0..009ffe2f5 100644 --- a/tests/unit_tests/initialisation/test_spectral_discretisation.py +++ b/tests/unit_tests/initialisation/test_spectral_discretisation.py @@ -24,10 +24,10 @@ pytest.param(spectral_sampling.UniformRandom(spectrum, m_range)), ), ) -def test_spectral_discretisation(discretisation, backend_class): +def test_spectral_discretisation(discretisation, backend_instance): # Arrange n_sd = 100000 - backend = backend_class() + backend = backend_instance # Act m, n = discretisation.sample(n_sd, backend=backend) diff --git a/tests/unit_tests/initialisation/test_spectro_glacial_discretisation.py b/tests/unit_tests/initialisation/test_spectro_glacial_discretisation.py index 839642054..00a72cc85 100644 --- a/tests/unit_tests/initialisation/test_spectro_glacial_discretisation.py +++ b/tests/unit_tests/initialisation/test_spectro_glacial_discretisation.py @@ -33,10 +33,10 @@ ), ), ) -def test_spectral_discretisation(discretisation, backend_class): +def test_spectral_discretisation(discretisation, backend_instance): # Arrange n_sd = 100000 - backend = backend_class() + backend = backend_instance # Act freezing_temperatures, immersed_surfaces, n = discretisation.sample( diff --git a/tests/unit_tests/products/test_collision_rates.py b/tests/unit_tests/products/test_collision_rates.py index 0841ade18..1d559ba37 100644 --- a/tests/unit_tests/products/test_collision_rates.py +++ b/tests/unit_tests/products/test_collision_rates.py @@ -69,17 +69,19 @@ class TestCollisionProducts: }, ], ) - def test_individual_dynamics_rates_nonadaptive(params, backend_class): - # TODO #744 - if backend_class.__name__ == "ThrustRTC" and params["enable_breakup"]: - return + def test_individual_dynamics_rates_nonadaptive(params, backend_instance): + if ( + backend_instance.__class__.__name__ == "ThrustRTC" + and params["enable_breakup"] + ): + pytest.skip("# TODO #744") # Arrange n_init = [5, 2] n_sd = len(n_init) env = Box(**ENV_ARGS) - builder = Builder(n_sd, backend_class(), environment=env) + builder = Builder(n_sd, backend_instance, environment=env) dynamic, products = _get_dynamics_and_products(params, adaptive=False) builder.add_dynamic(dynamic) diff --git a/tests/unit_tests/products/test_concentration_product.py b/tests/unit_tests/products/test_concentration_product.py index 1d8b56060..4356536ec 100644 --- a/tests/unit_tests/products/test_concentration_product.py +++ b/tests/unit_tests/products/test_concentration_product.py @@ -23,9 +23,9 @@ class TestParticleConcentration: @staticmethod @pytest.mark.parametrize("stp", (True, False)) - def test_stp(backend_class, stp): + def test_stp(backend_instance, stp): # arrange - builder = Builder(n_sd=N_SD, backend=backend_class(), environment=ENV) + builder = Builder(n_sd=N_SD, backend=backend_instance, environment=ENV) particulator = builder.build( attributes=ATTRIBUTES, products=(TotalParticleConcentration(stp=stp),) ) @@ -46,9 +46,9 @@ def test_stp(backend_class, stp): @staticmethod @pytest.mark.parametrize("specific", (True, False)) - def test_specific(backend_class, specific): + def test_specific(backend_instance, specific): # 
arrange - builder = Builder(n_sd=N_SD, backend=backend_class(), environment=ENV) + builder = Builder(n_sd=N_SD, backend=backend_instance, environment=ENV) particulator = builder.build( attributes=ATTRIBUTES, products=(ParticleConcentration(specific=specific),) )
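
The hunks above repeatedly apply one refactoring: module-level @staticmethod kernels compiled with conf.JIT_FLAGS become per-instance @cached_property factories compiled with self.default_jit_flags, which the Numba backend assembles from JIT_FLAGS plus formulae.fastmath plus an optional override_jit_flags argument. Below is a minimal, standalone sketch of that pattern under stated assumptions — it uses only numba and NumPy, and MiniBackend, JIT_FLAGS and sum_pair here are illustrative stand-ins rather than PySDM identifiers; the real backend also injects fastmath from formulae and a flattened formulae object, which the sketch omits.

from functools import cached_property

import numba
import numpy as np

# stand-in defaults playing the role of conf.JIT_FLAGS in the patch above
JIT_FLAGS = {"parallel": True, "fastmath": True, "error_model": "numpy", "cache": False}


class MiniBackend:
    """compiles each kernel lazily, once per instance, with per-instance JIT flags"""

    def __init__(self, override_jit_flags=None):
        # simplified version of the Numba backend __init__ above:
        # user-supplied overrides take precedence over the defaults
        self.default_jit_flags = {**JIT_FLAGS, **(override_jit_flags or {})}

    @cached_property
    def _sum_pair_body(self):
        # the njit-ed closure is built on first access and memoised by
        # cached_property, so later calls reuse the already-compiled kernel
        @numba.njit(**self.default_jit_flags)
        def body(data_out, data_in, is_first_in_pair, idx, length):
            data_out[:] = 0
            for i in numba.prange(length):  # pylint: disable=not-an-iterable
                if is_first_in_pair[i]:
                    data_out[i // 2] = data_in[idx[i]] + data_in[idx[i + 1]]

        return body

    def sum_pair(self, data_out, data_in, is_first_in_pair, idx):
        self._sum_pair_body(data_out, data_in, is_first_in_pair, idx, len(idx))


if __name__ == "__main__":
    # per-instance flag override, analogous to CPU(..., override_jit_flags={"parallel": False})
    backend = MiniBackend(override_jit_flags={"parallel": False})
    data_in = np.asarray([1.0, 2.0, 3.0, 4.0])
    idx = np.arange(len(data_in))
    is_first_in_pair = np.asarray([True, False, True, False])
    out = np.empty(len(data_in) // 2)
    backend.sum_pair(out, data_in, is_first_in_pair, idx)
    print(out)  # expected: [3. 7.]

As the example and test hunks suggest, tying the compiled kernels to the instance is what makes per-backend flag choices possible (e.g. disabling parallelism in the parcel-model examples) and what lets the session-scoped backend_instance fixture reuse a single compiled backend across unit tests instead of recompiling in every test via backend_class().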