From b6e808cfa0a61f8b60daa2dfd2fdd9cbde567f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Fri, 10 Jan 2025 18:20:59 +0100 Subject: [PATCH] feat(lib): implement Fresnel coefficients (#103) * feat(lib): implement Fresnel and UTD coefficients * chore(docs): add some docstring * fix(docs): typo * feat(lib): better Fresnel integrals impl. * fix(docs): rst syntax * fix(docs): escaping '+' ? * fix(docs): oops * feat(lib): implement Fresnel coef. * fix(docs): no :python: role :-( * chore(tests): added more tests for transition function F * chore(deps): bump jax to 0.4.32 for `jax.scipy.special.fresnel` See https://github.com/google/jax/pull/22843 * chore(lib): remove `differt.em.special` module * chore(lib): do not enforce complex dtype * wip * chore(lib): add antenna module * wip: radiation pattern * chore(ci): antenna dipole * wip: antenna stuff * wip testing antennas * fix(docs): typo * fixes * oops * fix(docs): typo * fmt and fixes * chore(fmt): imports * fix(lib): use `dict.setdefault` to allow overwriting shading * wip: docs * wip: example * wip: docs * fix(docs): bib ref. * chore(docs): more example to function's docstring * feat(lib): add utility to get sp-components * chore(docs): re-order text * chore(tests): add integration tests * fix(deps): add ref to deps * chore(tests): remove debugging NaNs * fix(lib): `t` was not correctly typed * chore(tests): mark integration tests as slow * fix(docs): typo in docstring and more tests * wip: sp directions * fix(lib): remove `normalize` arg * fix(docs): tmp * fix(docs): fresnel example * feat(lib): adding ITU materials * wip: docs * fix(tests): explicit dtype * fix: remove file and update motivation * fix(docs): typo * fix(docs): quotes and alignment * chore(docs): up * chore(lib): postpone UTD for a feature PR * chore(tests): add s and p rotation matrix (wip) * wip: testing sp-rotation matrix * wip: transition matrices * chore(docs): finally include coverage map example * chore(docs): cleanup --- .github/workflows/test.yml | 2 +- differt-core/pyproject.toml | 2 +- differt/src/differt/conftest.py | 5 + differt/src/differt/em/__init__.py | 57 +- differt/src/differt/em/_antenna.py | 662 + differt/src/differt/em/_constants.py | 11 +- differt/src/differt/em/_fresnel.py | 450 + differt/src/differt/em/_interaction_type.py | 13 + differt/src/differt/em/_material.py | 282 + differt/src/differt/em/_special.py | 227 - differt/src/differt/em/_utd.py | 283 +- differt/src/differt/em/_utils.py | 268 +- differt/src/differt/geometry/__init__.py | 2 + differt/src/differt/geometry/_paths.py | 64 +- .../src/differt/geometry/_triangle_mesh.py | 190 +- differt/src/differt/geometry/_utils.py | 84 +- differt/src/differt/plotting/_core.py | 10 +- differt/src/differt/rt/_image_method.py | 2 +- differt/src/differt/rt/_utils.py | 4 +- differt/src/differt/utils.py | 29 + differt/tests/benchmarks/test_rt.py | 3 - differt/tests/em/test_antenna.py | 133 + differt/tests/em/test_constants.py | 21 +- differt/tests/em/test_fresnel.py | 93 + differt/tests/em/test_interaction_type.py | 45 + differt/tests/em/test_material.py | 171 + differt/tests/em/test_special.py | 70 - differt/tests/em/test_utd.py | 86 +- differt/tests/em/test_utils.py | 73 +- differt/tests/geometry/test_paths.py | 47 + differt/tests/geometry/test_triangle_mesh.py | 11 + differt/tests/geometry/test_utils.py | 13 + differt/tests/rt/test_utils.py | 4 +- differt/tests/scene/test_triangle_scene.py | 31 +- differt/tests/test_integration.py | 211 + differt/tests/test_utils.py | 26 +- docs/source/_templates/autosummary/base.rst | 9 + docs/source/_templates/autosummary/class.rst | 11 +- docs/source/api_reference.rst | 8 + docs/source/conf.py | 15 + docs/source/index.rst | 1 + docs/source/motivations.md | 44 + docs/source/notebooks/multipath.ipynb | 40771 ++++++++-------- docs/source/notebooks/sampling_paths.ipynb | 1 + docs/source/reference/differt.em.rst | 105 +- docs/source/reference/differt.geometry.rst | 1 + docs/source/reference/differt.utils.rst | 1 + docs/source/references.bib | 85 +- pyproject.toml | 11 +- uv.lock | 777 +- 50 files changed, 24689 insertions(+), 20836 deletions(-) create mode 100644 differt/src/differt/em/_antenna.py create mode 100644 differt/src/differt/em/_fresnel.py create mode 100644 differt/src/differt/em/_interaction_type.py create mode 100644 differt/src/differt/em/_material.py delete mode 100644 differt/src/differt/em/_special.py create mode 100644 differt/tests/em/test_antenna.py create mode 100644 differt/tests/em/test_fresnel.py create mode 100644 differt/tests/em/test_interaction_type.py create mode 100644 differt/tests/em/test_material.py delete mode 100644 differt/tests/em/test_special.py create mode 100644 differt/tests/test_integration.py create mode 100644 docs/source/_templates/autosummary/base.rst create mode 100644 docs/source/motivations.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d1bd29af..0ef44683 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,7 @@ jobs: uses: dtolnay/rust-toolchain@stable - name: Run tests - run: uv run --python ${{ matrix.pyversion }} --frozen --no-dev --extra tests pytest + run: uv run --python ${{ matrix.pyversion }} --frozen --no-dev --extra tests-extended pytest - name: Upload to codecov.io uses: codecov/codecov-action@v5 diff --git a/differt-core/pyproject.toml b/differt-core/pyproject.toml index 5f82d35a..207daf8d 100644 --- a/differt-core/pyproject.toml +++ b/differt-core/pyproject.toml @@ -1,6 +1,6 @@ [build-system] build-backend = "maturin" -requires = ["maturin>=1.3,<2.0"] +requires = ["maturin>=1.6,<2"] [project] authors = [ diff --git a/differt/src/differt/conftest.py b/differt/src/differt/conftest.py index c525fddb..99eb04a8 100644 --- a/differt/src/differt/conftest.py +++ b/differt/src/differt/conftest.py @@ -1,6 +1,7 @@ import sys from typing import Any +import chex import jax import jax.numpy as jnp import matplotlib.pyplot as plt @@ -14,6 +15,10 @@ collect_ignore_glob = ["*", "**/*"] +def pytest_configure() -> None: + chex.set_n_cpu_devices(8) + + @pytest.fixture(autouse=True) def add_doctest_modules(doctest_namespace: dict[str, Any]) -> None: doctest_namespace["go"] = go diff --git a/differt/src/differt/em/__init__.py b/differt/src/differt/em/__init__.py index f0872f48..33863275 100644 --- a/differt/src/differt/em/__init__.py +++ b/differt/src/differt/em/__init__.py @@ -1,18 +1,59 @@ -"""Electromagnetic fields utilities.""" +"""Electromagnetic (EM) fields utilities.""" __all__ = ( + "Antenna", + "BaseAntenna", + "Dipole", "F", + "HWDipolePattern", + "InteractionType", + "L_i", + "Material", + "RadiationPattern", + "ShortDipole", + "ShortDipolePattern", "c", + "diffraction_coefficients", "epsilon_0", - "erf", - "erfc", - "fresnel", + "fresnel_coefficients", "lengths_to_delays", + "materials", "mu_0", "path_delays", + "pointing_vector", + "reflection_coefficients", + "refraction_coefficients", + "refractive_indices", + "sp_directions", + "sp_rotation_matrix", + "transition_matrices", + "z_0", ) -from ._constants import c, epsilon_0, mu_0 -from ._special import erf, erfc, fresnel -from ._utd import F -from ._utils import lengths_to_delays, path_delays +from ._antenna import ( + Antenna, + BaseAntenna, + Dipole, + HWDipolePattern, + RadiationPattern, + ShortDipole, + ShortDipolePattern, + pointing_vector, +) +from ._constants import c, epsilon_0, mu_0, z_0 +from ._fresnel import ( + fresnel_coefficients, + reflection_coefficients, + refraction_coefficients, + refractive_indices, +) +from ._interaction_type import InteractionType +from ._material import Material, materials +from ._utd import F, L_i, diffraction_coefficients +from ._utils import ( + lengths_to_delays, + path_delays, + sp_directions, + sp_rotation_matrix, + transition_matrices, +) diff --git a/differt/src/differt/em/_antenna.py b/differt/src/differt/em/_antenna.py new file mode 100644 index 00000000..e69f4075 --- /dev/null +++ b/differt/src/differt/em/_antenna.py @@ -0,0 +1,662 @@ +from abc import abstractmethod +from dataclasses import KW_ONLY +from typing import Any + +import equinox as eqx +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Array, ArrayLike, Float, Inexact, jaxtyped + +from differt.geometry import normalize +from differt.plotting import PlotOutput, draw_surface +from differt.utils import dot, safe_divide + +from ._constants import c, epsilon_0, mu_0 + + +@eqx.filter_jit +def pointing_vector( + e: Inexact[Array, "*#batch 3"], + b: Inexact[Array, "*#batch 3"], +) -> Inexact[Array, "*batch"]: + r""" + Compute the pointing vector in vacuum at from electric :math:`\vec{E}` and magnetic :math:`\vec{B}` fields. + + Args: + e: The electrical field. + b: The magnetical field. + + Returns: + The pointing vector :math:`\vec{S}`. + + It can be either real of complex-valued. + """ + h = b / mu_0 + + return jnp.cross(e, h) + + +@jaxtyped(typechecker=typechecker) +class BaseAntenna(eqx.Module): + """An antenna class, base class for :class:`Antenna` and :class:`RadiationPattern`.""" + + frequency: Float[Array, " "] = eqx.field(converter=jnp.asarray) + """The frequency :math:`f` at which the antenna is operating.""" + _: KW_ONLY + center: Float[Array, "3"] = eqx.field( + converter=jnp.asarray, default_factory=lambda: jnp.array([0.0, 0.0, 0.0]) + ) + """The center position of the antenna, from which the fields are radiated. + + Default value is the origin. + """ + + @property + @jaxtyped(typechecker=typechecker) + def period(self) -> Float[Array, " "]: + """The period :math:`T = 1/f`.""" + return 1 / self.frequency + + @property + @jaxtyped(typechecker=typechecker) + def angular_frequency(self) -> Float[Array, " "]: + r"""The angular frequency :math:`\omega = 2 \pi f`.""" + return 2 * jnp.pi * self.frequency + + @property + @jaxtyped(typechecker=typechecker) + def wavelength(self) -> Float[Array, " "]: + r"""The wavelength :math:`\lambda = c / f`.""" + return c * self.period + + @property + @jaxtyped(typechecker=typechecker) + def wavenumber(self) -> Float[Array, " "]: + r"""The wavenumber :math:`k = \omega / c`.""" + return self.angular_frequency / c + + +@jaxtyped(typechecker=typechecker) +class Antenna(BaseAntenna): + """An antenna class, must be subclassed.""" + + @property + @abstractmethod + def average_power(self) -> Float[Array, " "]: # TODO: provide default impl. + """The time-average power radiated by this antenna.""" + + @abstractmethod + def fields( + self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None + ) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]: + r""" + Compute electric and magnetic fields in vacuum at given position and (optional) time. + + Args: + r: The array of positions. + t: The array of time instants. + + If not provided, initial time instant + is assumed. + + Returns: + The electric :math:`\vec{E}` and magnetical :math:`\vec{B}` fields. + + Fields can be either real or complex-valued. + """ + + @eqx.filter_jit + @jaxtyped(typechecker=typechecker) + def pointing_vector( + self, + r: Float[Array, "*#batch 3"], + t: Float[Array, "*#batch"] | None = None, + ) -> Inexact[Array, "*batch 3"]: + r""" + Compute the pointing vector in vacuum at given position and (optional) time. + + Args: + r: The array of positions. + t: The array of time instants. + + If not provided, initial time instant + is assumed. + + Returns: + The pointing vector :math:`\vec{S}`. + + It can be either real of complex-valued. + """ + e, b = self.fields(r, t) + return pointing_vector(e, b) + + @jaxtyped(typechecker=typechecker) + def directivity( + self, + num_points: int = int(1e2), + ) -> tuple[ + Float[Array, " 2*{num_points}"], + Float[Array, " {num_points}"], + Float[Array, "2*{num_points} {num_points}"], + ]: + """ + Compute an estimate of the antenna directivity for azimutal and elevation angles. + + .. note:: + + Subclasses may provide a more accurate or exact + implementation. + + Args: + num_points: The number of points to sample along the elevation axis. + + Twice this number of points are sampled on the aximutal axis. + + Returns: + Azimutal and elevation angles, as well as corresponding directivity values. + + .. seealso:: + + :meth:`directive_gain` + """ + u, du = jnp.linspace(0, 2 * jnp.pi, num_points * 2, retstep=True) + v, dv = jnp.linspace(0, jnp.pi, num_points, retstep=True) + x = jnp.outer(jnp.cos(u), jnp.sin(v)) + y = jnp.outer(jnp.sin(u), jnp.sin(v)) + z = jnp.outer(jnp.ones_like(u), jnp.cos(v)) + + r = self.center + jnp.stack((x, y, z), axis=-1) + + s = self.pointing_vector(r) + + p = jnp.linalg.norm(s, axis=-1) + + ds = du * dv + + # Power per unit solid angle + U = p / ds # noqa: N806 + p_tot = jnp.sum(p * jnp.sin(v)) / (4 * jnp.pi) + + return u, v, U / p_tot + + @jaxtyped(typechecker=typechecker) + def directive_gain( + self, + num_points: int = int(1e2), + ) -> Float[Array, " "]: + """ + Compute an estimate of the antenna directive gain. + + .. note:: + + Subclasses may provide a more accurate or exact + implementation. + + Args: + num_points: The number of points used for the estimate. + + Returns: + The antenna directive gain. + + .. seealso:: + + :meth:`directivity` + """ + return self.directivity(num_points=num_points)[-1].max() + + def plot_radiation_pattern( + self, + num_points: int = int(1e2), + distance: Float[ArrayLike, " "] = 1.0, + num_wavelengths: Float[ArrayLike, " "] | None = None, + **kwargs: Any, + ) -> PlotOutput: + """ + Plot the radiation pattern (normalized power) of this antenna. + + The power is computed on points on an sphere around the antenna. + + Args: + num_points: The number of points to sample along the elevation axis. + + Twice this number of points are sampled on the aximutal axis. + distance: The distance from the antenna at which power samples + are evaluated. + num_wavelengths: If provided, supersedes ``distance`` by setting + the distance relatively to the :attr:`wavelength`. + kwargs: Keyword arguments passed to + :func:`draw_surface`. + + Returns: + The resulting plot output. + """ + if num_wavelengths is not None: + distance = jnp.asarray(num_wavelengths) * self.wavelength + else: + distance = jnp.asarray(distance) + + u = jnp.linspace(0, 2 * jnp.pi, num_points * 2) + v = jnp.linspace(0, jnp.pi, num_points) + x = jnp.outer(jnp.cos(u), jnp.sin(v)) + y = jnp.outer(jnp.sin(u), jnp.sin(v)) + z = jnp.outer(jnp.ones_like(u), jnp.cos(v)) + + r = self.center + distance * jnp.stack((x, y, z), axis=-1) + + s = self.pointing_vector(r) + + p = jnp.linalg.norm(s, axis=-1, keepdims=True) + + gain = p / p.max() + + r *= gain + gain = jnp.squeeze(gain, axis=-1) + + return draw_surface( + x=r[..., 0], y=r[..., 1], z=r[..., 2], colors=gain, **kwargs + ) + + +@jaxtyped(typechecker=typechecker) +class Dipole(Antenna): + r""" + A simple electrical (or Hertzian) dipole. + + Equations were obtained from :cite:`dipole,dipole-moment,dipole-antenna,directivity`, and assume + a constant current across the dipole length. + + Args: + frequency: The frequency at which the antenna is operating. + num_wavelengths: The length of the dipole, relative to the wavelength. + length: The absolute length of the dipole, supersedes ``num_wavelengths``. + moment: The dipole moment. + + By default, the dipole is aligned with the z-axis. + current: The current (in A) flowing in the dipole. + + If this is provided, which is the default, the only the direction of the moment + vector is used, and its insensity is set to match the dipole moment with + specified current. + charge: The dipole charge (in Coulomb), assuming opposite charges on either ends of the dipole. + + If this is provided, this takes precedence over ``current``. + center: The center position of the antenna, from which the fields are radiated. + + Examples: + The following example shows how to plot the radiation + pattern (antenna power) at 1 meter. + + .. plotly:: + :fig-vars: fig + + >>> from differt.em import Dipole + >>> + >>> ant = Dipole(frequency=1e9) + >>> fig = ant.plot_radiation_pattern(backend="plotly") + >>> fig # doctest: +SKIP + + The second example shows how to plot the radiation + pattern (antenna power) at 1 meter, but only + in the x-z plane, for multiple dipole lengths. + + .. plot:: + + >>> from differt.em import Dipole + >>> + >>> theta = jnp.linspace(0, 2 * jnp.pi, 200) + >>> r = jnp.stack( + ... (jnp.cos(theta), jnp.zeros_like(theta), jnp.sin(theta)), axis=-1 + ... ) + >>> fig = plt.figure() + >>> ax = fig.add_subplot( + ... projection="polar", facecolor="lightgoldenrodyellow" + ... ) + >>> for ratio in [0.5, 1.0, 1.25, 1.5, 2.0]: + ... ant = Dipole(1e9, ratio) + ... power = jnp.linalg.norm(ant.pointing_vector(r), axis=-1) + ... _ = ax.plot(theta, power, label=rf"$\ell/\lambda = {ratio:1.2f}$") + >>> + >>> ax.tick_params(grid_color="palegoldenrod") + >>> ax.set_rscale("log") + >>> angle = jnp.deg2rad(-10) + >>> ax.legend( # doctest: +SKIP + ... loc="upper left", + ... bbox_to_anchor=(0.5 + jnp.cos(angle) / 2, 0.5 + jnp.sin(angle) / 2), + ... ) + >>> plt.show() # doctest: +SKIP + """ + + length: Float[Array, " "] = eqx.field(converter=jnp.asarray) + """Dipole length (in meter).""" + moment: Float[Array, "3"] = eqx.field(converter=jnp.asarray) + """Dipole moment (in Coulomb-meter).""" + + @jaxtyped(typechecker=typechecker) + def __init__( + self, + frequency: Float[ArrayLike, " "], + num_wavelengths: Float[ArrayLike, " "] = 0.5, + *, + length: Float[ArrayLike, " "] | None = None, + moment: Float[ArrayLike, "3"] | None = jnp.array([0.0, 0.0, 1.0]), + current: Float[ArrayLike, " "] | None = 1.0, + charge: Float[ArrayLike, " "] | None = None, + center: Float[Array, "3"] = jnp.array([0.0, 0.0, 0.0]), + ) -> None: + super().__init__(jnp.asarray(frequency), center=center) + + if length is not None: + self.length = jnp.asarray(length) + else: + self.length = jnp.asarray(num_wavelengths) * self.wavelength + + moment = jnp.array(moment) + + if charge is not None: + moment *= jnp.asarray(charge) * self.length / jnp.linalg.norm(moment) + elif current is not None: + moment *= ( + jnp.asarray(current) + * self.length + / (jnp.linalg.norm(moment) * self.angular_frequency) + ) + + self.moment = moment # type: ignore[reportAttributeAccessIssue] + + @property + def average_power(self) -> Float[Array, " "]: + p_0 = jnp.linalg.norm(self.moment) + + # Equivalent to mu_0 * self.angular_frequency**4 * p_0**2 / (12 * jnp.pi * c) + # but avoids overflow + + r = mu_0 * self.angular_frequency + t = self.angular_frequency * p_0 + r *= t + r *= t + r *= self.angular_frequency / (12 * jnp.pi * c) + + return r + + @eqx.filter_jit + @jaxtyped(typechecker=typechecker) + def fields( + self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None + ) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]: + r_hat, r = normalize(r - self.center, keepdims=True) + p = self.moment + w = self.angular_frequency + k = self.wavenumber + k_k = k * k + r_inv = 1 / r + j_k_r = 1j * k * r + + factor = 1 / (4 * jnp.pi * epsilon_0) + + r_x_p = jnp.cross(r_hat, p) + r_d_p = jnp.sum(r_hat * p, axis=-1, keepdims=True) + + e = ( + factor + * ( + k_k * jnp.cross(r_x_p, r_hat) + + r_inv * r_inv * (r_inv - 1j * k) * (3 * r_hat * r_d_p - p) + ) + * r_inv + ) + b = (factor * k_k / c) * r_x_p * (1 - 1 / j_k_r) * r_inv + + exp = ( + jnp.exp(j_k_r - 1j * w * t[..., None]) if t is not None else jnp.exp(j_k_r) + ) + + e *= exp + b *= exp + + return e, b + + @jaxtyped(typechecker=typechecker) + def directivity( + self, + num_points: int = int(1e2), + ) -> tuple[ + Float[Array, " 2*{num_points}"], + Float[Array, " {num_points}"], + Float[Array, "2*{num_points} {num_points}"], + ]: + u = jnp.linspace(0, 2 * jnp.pi, num_points * 2) + v = jnp.linspace(0, jnp.pi, num_points) + x = jnp.outer(jnp.cos(u), jnp.sin(v)) + y = jnp.outer(jnp.sin(u), jnp.sin(v)) + z = jnp.outer(jnp.ones_like(u), jnp.cos(v)) + + r = jnp.stack((x, y, z), axis=-1) + + p = self.moment / jnp.linalg.norm(self.moment) + + sin_theta = jnp.cross(r, p) + + return u, v, 1.5 * jax.lax.integer_pow(sin_theta, 2) + + @jaxtyped(typechecker=typechecker) + def directive_gain( # noqa: PLR6301 + self, + num_points: int = int(1e2), # noqa: ARG002 + ) -> Float[Array, " "]: + return jnp.array(1.5) + + +class ShortDipole(Dipole): + """Short dipole. + + Like :class:`Dipole`, but accounts for the fact that the current is not constant across the dipole length, + which leads to more realistic results. + + However, fields are only derived for far field. + + Warning: + Not implemented yed. + """ + + @eqx.filter_jit + @jaxtyped(typechecker=typechecker) + def fields( + self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None + ) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]: + raise NotImplementedError + + @jaxtyped(typechecker=typechecker) + def directivity( + self, + num_points: int = int(1e2), + ) -> tuple[ + Float[Array, " 2*{num_points}"], + Float[Array, " {num_points}"], + Float[Array, "2*{num_points} {num_points}"], + ]: + # Bypass Dipole's specialized implementation + return Antenna.directivity(self, num_points=num_points) + + @jaxtyped(typechecker=typechecker) + def directive_gain( + self, + num_points: int = int(1e2), + ) -> Float[Array, " "]: + # Bypass Dipole's specialized implementation + return Antenna.directive_gain(self, num_points=num_points) + + +@jaxtyped(typechecker=typechecker) +class RadiationPattern(BaseAntenna): + """An antenna radiation pattern class, must be subclassed.""" + + @abstractmethod + def polarization_vectors( + self, + r: Float[Array, "*#batch 3"], + ) -> tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]: + r""" + Compute s and p polarization vectors. + + Args: + r: The array of positions. + + Returns: + The electric :math:`\vec{E}` and magnetical :math:`\vec{B}` fields. + + Fields can be either real or complex-valued. + """ + + @jaxtyped(typechecker=typechecker) + def directivity( + self, + num_points: int = int(1e2), + ) -> tuple[ + Float[Array, " 2*{num_points}"], + Float[Array, " {num_points}"], + Float[Array, "2*{num_points} {num_points}"], + ]: + """ + Compute an estimate of the antenna directivity for azimutal and elevation angles. + + .. note:: + + Subclasses may provide a more accurate or exact + implementation. + + Args: + num_points: The number of points to sample along the elevation axis. + + Twice this number of points are sampled on the aximutal axis. + + Returns: + Azimutal and elevation angles, as well as corresponding directivity values. + + .. seealso:: + + :meth:`directive_gain` + """ + u, du = jnp.linspace(0, 2 * jnp.pi, num_points * 2, retstep=True) + v, dv = jnp.linspace(0, jnp.pi, num_points, retstep=True) + x = jnp.outer(jnp.cos(u), jnp.sin(v)) + y = jnp.outer(jnp.sin(u), jnp.sin(v)) + z = jnp.outer(jnp.ones_like(u), jnp.cos(v)) + + r = self.center + jnp.stack((x, y, z), axis=-1) + + s, p = self.polarization_vectors(r) + + g = dot(s) + dot(p) + + # TODO: check if this is correct + + return u, v, g + + @jaxtyped(typechecker=typechecker) + def directive_gain( + self, + num_points: int = int(1e2), + ) -> Float[Array, " "]: + """ + Compute an estimate of the antenna directive gain. + + .. note:: + + Subclasses may provide a more accurate or exact + implementation. + + Args: + num_points: The number of points used for the estimate. + + Returns: + The antenna directive gain. + + .. seealso:: + + :meth:`directivity` + """ + return self.directivity(num_points=num_points)[-1].max() + + def plot_radiation_pattern( + self, + num_points: int = int(1e2), + distance: Float[ArrayLike, " "] = 1.0, + num_wavelengths: Float[ArrayLike, " "] | None = None, + **kwargs: Any, + ) -> PlotOutput: + """ + Plot the radiation pattern (normalized power) of this antenna. + + The power is computed on points on an sphere around the antenna. + + Args: + num_points: The number of points to sample along the elevation axis. + + Twice this number of points are sampled on the aximutal axis. + distance: The distance from the antenna at which power samples + are evaluated. + num_wavelengths: If provided, supersedes ``distance`` by setting + the distance relatively to the :attr:`wavelength`. + kwargs: Keyword arguments passed to + :func:`draw_surface`. + + Returns: + The resulting plot output. + """ + if num_wavelengths is not None: + distance = jnp.asarray(num_wavelengths) * self.wavelength + else: + distance = jnp.asarray(distance) + + u = jnp.linspace(0, 2 * jnp.pi, num_points * 2) + v = jnp.linspace(0, jnp.pi, num_points) + x = jnp.outer(jnp.cos(u), jnp.sin(v)) + y = jnp.outer(jnp.sin(u), jnp.sin(v)) + z = jnp.outer(jnp.ones_like(u), jnp.cos(v)) + + r = self.center + distance * jnp.stack((x, y, z), axis=-1) + + s = self.pointing_vector(r) + + p = jnp.linalg.norm(s, axis=-1, keepdims=True) + + gain = p / p.max() + + r *= gain + gain = jnp.squeeze(gain, axis=-1) + + return draw_surface( + x=r[..., 0], y=r[..., 1], z=r[..., 2], colors=gain, **kwargs + ) + + +@jaxtyped(typechecker=typechecker) +class HWDipolePattern(RadiationPattern): + """An half-wave dipole radiation pattern.""" + + direction: Float[Array, "3"] = eqx.field(converter=jnp.asarray) + """The dipole direction.""" + + def polarization_vectors( + self, + r: Float[Array, "*#batch 3"], + ) -> tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]: + r_hat, r = normalize(r - self.center, keepdims=True) + + cos_theta = dot(r_hat, self.direction) + sin_theta = jnp.sqrt(1 - cos_theta**2) + + d = 1.640922376984585 # Directive gain: 4 / Cin(2*pi) + + cos_theta = dot() + sin_theta = jnp.sin() + d = safe_divide(jnp.cos(0.5 * jnp.pi * cos_theta), sin_theta) + + +@jaxtyped(typechecker=typechecker) +class ShortDipolePattern(RadiationPattern): + """An short dipole radiation pattern.""" + + direction: Float[Array, "3"] = eqx.field(converter=jnp.asarray) + """The dipole direction.""" diff --git a/differt/src/differt/em/_constants.py b/differt/src/differt/em/_constants.py index 9d9a500e..7f8c8b89 100644 --- a/differt/src/differt/em/_constants.py +++ b/differt/src/differt/em/_constants.py @@ -1,10 +1,11 @@ -"""Physical constants used for EM fields computation.""" - -c = 299792458.0 +c: float = 299792458.0 """The speed of light in vacuum.""" -mu_0 = 1.25663706212e-06 +mu_0: float = 1.25663706212e-06 r"""The vacuum permeability :math:`\mu_0`.""" -epsilon_0 = 8.8541878128e-12 +epsilon_0: float = 8.8541878128e-12 r"""The vacuum permittivity :math:`\epsilon_0`.""" + +z_0: float = 376.73031341259 +r"""The impedance of free space :math:`Z_0`.""" diff --git a/differt/src/differt/em/_fresnel.py b/differt/src/differt/em/_fresnel.py new file mode 100644 index 00000000..803e3aa6 --- /dev/null +++ b/differt/src/differt/em/_fresnel.py @@ -0,0 +1,450 @@ +import equinox as eqx +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Array, ArrayLike, Float, Inexact, jaxtyped + +from differt.utils import safe_divide + + +@eqx.filter_jit +@jaxtyped(typechecker=typechecker) +def refractive_indices( + epsilon_r: Inexact[ArrayLike, " *#batch"], + mu_r: Inexact[ArrayLike, " *#batch"] | None = None, +) -> Inexact[Array, " *batch"]: + r""" + Compute the refractive indices corresponding to relative permittivities and relative permeabilities. + + The refractive index :math:`n` is simply defined as + + .. math:: + n = \sqrt{\epsilon_r\mu_r}, + + where :math:`\epsilon_r` is the relative permittivity, and :math:`\mu_r` is the relative permeability. + + Args: + epsilon_r: The relative permittivities. + mu_r: The relative permeabilities. If not provided, + a value of 1 is used. + + Returns: + The array of refractive indices. + + The output dtype will only be complex if any of the provided arguments + has a complex dtype. + + .. seealso:: + + :func:`fresnel_coefficients` + + :func:`reflection_coefficients` + + :func:`refraction_coefficients` + """ + return jnp.sqrt(epsilon_r if mu_r is None else epsilon_r * mu_r) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def fresnel_coefficients( + n_r: Inexact[ArrayLike, " *#batch"], + cos_theta_i: Float[Array, " *#batch"], +) -> tuple[ + tuple[Inexact[Array, " *batch"], Inexact[Array, " *batch"]], + tuple[Inexact[Array, " *batch"], Inexact[Array, " *batch"]], +]: + r""" + Compute the Fresnel reflection and refraction coefficients at an interface. + + The Snell's law describes the relationship between the angles of incidence + and refraction: + + .. math:: + n_i\sin\theta_i = n_t\sin\theta_t, + + where :math:`n` is the refraction index, :math:`\theta` is the angle of between the ray path + and the normal to the interface, and :math:`i` and :math:`t` indicate, + respectively, the first (i.e., incidence) and the second (i.e., transmission) + media. + + The s and p reflection coefficients are: + + .. math:: + r_s = \frac{n_i\cos\theta_i - n_t\cos\theta_t}{n_i\cos\theta_i + n_t\cos\theta_t}, + + and + + .. math:: + r_p = \frac{n_t\cos\theta_i - n_i\cos\theta_t}{n_t\cos\theta_i + n_i\cos\theta_t}. + + The s and p refraction coefficients are: + + .. math:: + t_s = \frac{2n_i\cos\theta_i}{n_i\cos\theta_i + n_t\cos\theta_t}, + + and + + .. math:: + t_p = \frac{2n_i\cos\theta_i}{n_t\cos\theta_i + n_i\cos\theta_t}. + + Then, we define :math:`n_r \triangleq \frac{n_t}{n_i}` and rewrite the four coefficients as: + + .. math:: + r_s &= \frac{\cos\theta_i - n_r\cos\theta_t}{\cos\theta_i + n_r\cos\theta_t},\\ + r_p &= \frac{n_r^2\cos\theta_i - n_r\cos\theta_t}{n_r^2\cos\theta_i + n_r\cos\theta_t},\\ + t_s &= \frac{2\cos\theta_i}{\cos\theta_i + n_r\cos\theta_t},\\ + t_p &= \frac{2n_r\cos\theta_i}{n_r^2\cos\theta_i + n_r\cos\theta_t}, + + where :math:`n_t\cos\theta_t` is obtained from: + + .. math:: + n_r\cos\theta_t = \sqrt{n_r^2 + \cos^2\theta_i - 1}. + + Args: + n_r: The relative refractive indices. + + This is the ratios of the refractive indices of the second + media over the refractive indices of the first media. + cos_theta_i: The (cosine of the) angles of incidence (or reflection). + + Returns: + The reflection and refraction coefficients for s and p polarizations. + + The output dtype will only be complex if any of the provided arguments + has a complex dtype. + + .. seealso:: + + :func:`reflection_coefficients` + + :func:`refraction_coefficients` + + :func:`refractive_indices` + + Examples: + .. plot:: + + The following example reproduces the air-to-glass Fresnel coefficient. + The Brewster angle (defined by :math:`r_p=0`) is indicated by the vertical + red line. + + >>> from differt.em import fresnel_coefficients + >>> + >>> n = 1.5 # Air to glass + >>> theta = jnp.linspace(0, jnp.pi / 2) + >>> cos_theta = jnp.cos(theta) + >>> (r_s, r_p), (t_s, t_p) = fresnel_coefficients(n, cos_theta) + >>> theta_d = jnp.rad2deg(theta) + >>> theta_b = jnp.rad2deg(jnp.arctan(n)) + >>> plt.plot(theta_d, r_s, "b:", label=r"$r_s$") # doctest: +SKIP + >>> plt.plot(theta_d, r_p, "r:", label=r"$r_p$") # doctest: +SKIP + >>> plt.plot(theta_d, t_s, "b-", label=r"$t_s$") # doctest: +SKIP + >>> plt.plot(theta_d, t_p, "r-", label=r"$t_p$") # doctest: +SKIP + >>> plt.axvline(theta_b, color="r", linestyle="--") # doctest: +SKIP + >>> plt.xlabel("Angle of incidence (°)") # doctest: +SKIP + >>> plt.ylabel("Amplitude") # doctest: +SKIP + >>> plt.xlim(0, 90) # doctest: +SKIP + >>> plt.ylim(-1.0, 1.0) # doctest: +SKIP + >>> plt.title("Fresnel coefficients") # doctest: +SKIP + >>> plt.legend() # doctest: +SKIP + >>> plt.tight_layout() # doctest: +SKIP + + .. plot:: + + The following example produces the same but glass-to-air interface. + The critical angle (total internal reflection) is indicated by the vertical + black line. + + >>> from differt.em import fresnel_coefficients + >>> + >>> n = 1/ 1.5 # Glass to air + >>> theta = jnp.linspace(0, jnp.pi / 2, 300) + >>> cos_theta = jnp.cos(theta) + >>> (r_s, r_p), (t_s, t_p) = fresnel_coefficients(n, cos_theta) + >>> theta_d = jnp.rad2deg(theta) + >>> theta_b = jnp.rad2deg(jnp.arctan(n)) + >>> theta_c = jnp.rad2deg(jnp.arcsin(n)) + >>> plt.plot(theta_d, r_s, "b:", label=r"$r_s$") # doctest: +SKIP + >>> plt.plot(theta_d, r_p, "r:", label=r"$r_p$") # doctest: +SKIP + >>> plt.plot(theta_d, t_s, "b-", label=r"$t_s$") # doctest: +SKIP + >>> plt.plot(theta_d, t_p, "r-", label=r"$t_p$") # doctest: +SKIP + >>> plt.axvline(theta_b, color="r", linestyle="--") # doctest: +SKIP + >>> plt.axvline(theta_c, color="k", linestyle="--") # doctest: +SKIP + >>> plt.xlabel("Angle of incidence (°)") # doctest: +SKIP + >>> plt.ylabel("Amplitude") # doctest: +SKIP + >>> plt.xlim(0, 90) # doctest: +SKIP + >>> plt.ylim(-0.5, 3.0) # doctest: +SKIP + >>> plt.title("Fresnel coefficients") # doctest: +SKIP + >>> plt.legend() # doctest: +SKIP + >>> plt.tight_layout() # doctest: +SKIP + """ + n_r_squared = jax.lax.integer_pow(n_r, 2) + cos_theta_i_squared = jax.lax.integer_pow(cos_theta_i, 2) + n_r_squared_cos_theta_i = n_r_squared * cos_theta_i + n_r_cos_theta_t = jnp.sqrt(n_r_squared + cos_theta_i_squared - 1) + two_cos_theta_i = 2 * cos_theta_i + + r_s = safe_divide( + cos_theta_i - n_r_cos_theta_t, + cos_theta_i + n_r_cos_theta_t, + ) + t_s = safe_divide( + two_cos_theta_i, + cos_theta_i + n_r_cos_theta_t, + ) + r_p = safe_divide( + n_r_squared_cos_theta_i - n_r_cos_theta_t, + n_r_squared_cos_theta_i + n_r_cos_theta_t, + ) + t_p = safe_divide( + n_r * two_cos_theta_i, + n_r_squared_cos_theta_i + n_r_cos_theta_t, + ) + + return (r_s, r_p), (t_s, t_p) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def reflection_coefficients( + n_r: Inexact[ArrayLike, " *#batch"], + cos_theta_i: Float[Array, " *#batch"], +) -> tuple[Inexact[Array, " *batch"], Inexact[Array, " *batch"]]: + r""" + Compute the Fresnel reflection coefficients at an interface. + + Args: + n_r: The relative refractive indices. + + This is the ratios of the refractive indices of the second + media over the refractive indices of the first media. + cos_theta_i: The (cosine of the) angles of incidence (or reflection). + + Returns: + The reflection coefficients for s and p polarizations. + + The output dtype will only be complex if any of the provided arguments + has a complex dtype. + + .. seealso:: + + :func:`fresnel_coefficients` + + :func:`refraction_coefficients` + + :func:`refractive_indices` + + Examples: + .. plot:: + :context: reset + + The following example show how to compute interference + patterns from line of sight and reflection on a glass + ground. + + >>> from differt.em import ( + ... c, + ... reflection_coefficients, + ... Dipole, + ... pointing_vector, + ... sp_directions, + ... ) + >>> from differt.geometry import normalize + >>> from differt.rt import image_method + >>> from differt.utils import dot + + The first step is to define the antenna and the geometry of the scene. + Here, we place a dipole antenna above the origin, and generate a + ``num_positions`` number of positions along the horizontal line, + where we will evaluate the EM fields. + + >>> tx_position = jnp.array([0.0, 2.0, 0.0]) + >>> rx_position = jnp.array([0.0, 2.0, 0.0]) + >>> num_positions = 1000 + >>> # [num_positions 3] + >>> x = jnp.logspace(0, 3, num_positions) # From close to very far + >>> rx_positions = ( + ... jnp.tile(rx_position, (num_positions, 1)).at[..., 0].add(x) + ... ) + >>> ant = Dipole(2.4e9) # 2.4 GHz + >>> plt.xscale("symlog", linthresh=1e-1) # doctest: +SKIP + >>> plt.plot( + ... [tx_position[0]], + ... [tx_position[1]], + ... "o", + ... label="TX", + ... ) # doctest: +SKIP + >>> plt.plot( + ... rx_positions[::50, 0], + ... rx_positions[::50, 1], + ... "o", + ... label="RXs", + ... ) # doctest: +SKIP + >>> plt.axhline(color="k", label="Ground") # doctest: +SKIP + >>> plt.xlabel("x-axis (m)") # doctest: +SKIP + >>> plt.ylabel("y-axis (m)") # doctest: +SKIP + >>> plt.legend() # doctest: +SKIP + >>> plt.tight_layout() # doctest: +SKIP + + .. plot:: + :context: close-figs + + Next, we compute the EM fields from the direct (line-of-sight) path. + + >>> # [num_positions 3] + >>> E_los, B_los = ant.fields(rx_positions - tx_position) + >>> # [num_positions] + >>> P_los = jnp.linalg.norm(pointing_vector(E_los, B_los), axis=-1) + >>> plt.semilogx( + ... x, + ... 10 * jnp.log10(P_los / ant.average_power), + ... label=r"$P_\text{los}$", + ... ) # doctest: +SKIP + + After, the :func:`image_method` + function is used to compute the reflection points. + + >>> ground_vertex = jnp.array([0.0, 0.0, 0.0]) + >>> ground_normal = jnp.array([0.0, 1.0, 0.0]) + >>> # [num_positions 3] + >>> reflection_points = image_method( + ... tx_position, + ... rx_positions, + ... ground_vertex[None, ...], + ... ground_normal[None, ...], + ... ).squeeze(axis=-2) # Squeeze because only one reflection + >>> # [num_positions 3], [num_positions 1] + >>> k_i, s_i = normalize(reflection_points - tx_position, keepdims=True) + >>> k_r, s_r = normalize(rx_positions - reflection_points, keepdims=True) + >>> # [num_positions 1] + >>> l = jnp.linalg.norm(rx_positions - tx_position, axis=-1, keepdims=True) + >>> tau = (s_i + s_r - l) / c # Delay between two paths + >>> tau = tau.squeeze(axis=-1) + + We then compute the EM fields at those points, and use the Fresnel + reflection coefficients to compute the reflected fields. + + >>> # [num_positions 3] + >>> E_i, B_i = ant.fields(reflection_points - tx_position, t=-tau) + >>> # [num_positions 1] + >>> cos_theta = dot(ground_normal, -k_i, keepdims=True) + >>> n_r = 1.5 # Air to glass + >>> # [num_positions 1] + >>> r_s, r_p = reflection_coefficients(n_r, cos_theta) + + To apply the coefficients correctly, we must determine the polarization + directions of both the incident and the reflected fields. + + .. important:: + + Reflection coefficients are returned based on s and p directions. + As a result, we need to first determine those local directions, and + apply the corresponding reflection coefficients to the projection + of the fields onto those directions + :cite:`utd-mcnamara{eq. 3.3-3.8 and 3.39, p. 70 and 77}`. + + >>> # [num_positions 3] + >>> (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(k_i, k_r, ground_normal) + + We then transform XYZ-components into local s and p components. + + >>> # [num_positions 1] + >>> E_i_s = dot(E_i, e_i_s, keepdims=True) + >>> E_i_p = dot(E_i, e_i_p, keepdims=True) + >>> B_i_s = dot(B_i, e_i_s, keepdims=True) + >>> B_i_p = dot(B_i, e_i_p, keepdims=True) + + Then, we apply reflection coefficients to the local s and p components. + + >>> # [num_positions 1] + >>> E_r_s = r_s * E_i_s + >>> E_r_p = r_p * E_i_p + >>> B_r_s = r_s * B_i_s + >>> B_r_p = r_p * B_i_p + + And we project back to XYZ-components. + + >>> E_r = E_r_s * e_r_s + E_r_p * e_r_p + >>> B_r = B_r_s * e_r_s + B_r_p * e_r_p + + Finally, we apply the spreading factor and phase shift due to the propagation + from the reflection points to the receiver :cite:`utd-mcnamara{eq. 3.1, p. 63}`. + + >>> spreading_factor = s_i / ( + ... s_i + s_r + ... ) # We assume that the radii of curvature are equal to 's_i' + >>> phase_shift = jnp.exp(1j * s_r * ant.wavenumber) + >>> E_r *= spreading_factor * phase_shift + >>> B_r *= spreading_factor * phase_shift + >>> P_r = jnp.linalg.norm(pointing_vector(E_r, B_r), axis=-1) + >>> plt.semilogx( + ... x, + ... 10 * jnp.log10(P_r / ant.average_power), + ... "--", + ... label=r"$P_\text{reflection}$", + ... ) # doctest: +SKIP + + We also plot the total field, to better observe the interference pattern. + + >>> E_tot = E_los + E_r + >>> B_tot = B_los + B_r + >>> P_tot = jnp.linalg.norm(pointing_vector(E_tot, B_tot), axis=-1) + >>> plt.semilogx( + ... x, + ... 10 * jnp.log10(P_tot / ant.average_power), + ... "-.", + ... label=r"$P_\text{total}$", + ... ) # doctest: +SKIP + >>> plt.xlabel("Distance to transmitter on x-axis (m)") # doctest: +SKIP + >>> plt.ylabel("Loss (dB)") # doctest: +SKIP + >>> plt.legend() # doctest: +SKIP + >>> plt.tight_layout() # doctest: +SKIP + + From the above figure, it is clear that the ground-reflection creates an interference + pattern in the received power. Moreover, we can clearly observe the Brewster angle + at a distance of 6 m. This can verified by computing the Brewster angle from the + relative refractive index, and matching it to the corresponding distance. + + >>> brewster_angle = jnp.arctan(n_r) + >>> print(f"Brewster angle: {jnp.rad2deg(brewster_angle):.1f}°") + Brewster angle: 56.3° + >>> cos_distance = jnp.abs(jnp.cos(brewster_angle) - cos_theta) + >>> distance = x[jnp.argmin(cos_distance)] + >>> print(f"Corresponding distance: {distance:.1f} m") + Corresponding distance: 6.0 m + """ + return fresnel_coefficients(n_r, cos_theta_i)[0] + + +@eqx.filter_jit +@jaxtyped(typechecker=typechecker) +def refraction_coefficients( + n_r: Inexact[ArrayLike, " *#batch"], + cos_theta_i: Float[Array, " *#batch"], +) -> tuple[Inexact[Array, " *batch"], Inexact[Array, " *batch"]]: + """ + Compute the Fresnel refraction coefficients at an interface. + + Args: + n_r: The relative refractive indices. + + This is the ratios of the refractive indices of the second + media over the refractive indices of the first media. + cos_theta_i: The (cosine of the) angles of incidence (or reflection). + + Returns: + The refraction coefficients for s and p polarizations. + + The output dtype will only be complex if any of the provided arguments + has a complex dtype. + + .. seealso:: + + :func:`fresnel_coefficients` + + :func:`reflection_coefficients` + + :func:`refractive_indices` + """ + return fresnel_coefficients(n_r, cos_theta_i)[1] diff --git a/differt/src/differt/em/_interaction_type.py b/differt/src/differt/em/_interaction_type.py new file mode 100644 index 00000000..2d1cf600 --- /dev/null +++ b/differt/src/differt/em/_interaction_type.py @@ -0,0 +1,13 @@ +from enum import IntEnum, unique + + +@unique +class InteractionType(IntEnum): + """Enumeration of interaction types.""" + + REFLECTION = 0 + """Specular reflection on an surface.""" + DIFFRACTION = 1 + """Diffraction on an edge.""" + SCATTERING = 2 + """Scattering on a surface.""" diff --git a/differt/src/differt/em/_material.py b/differt/src/differt/em/_material.py new file mode 100644 index 00000000..d4e6952c --- /dev/null +++ b/differt/src/differt/em/_material.py @@ -0,0 +1,282 @@ +# ruff: noqa: FURB152 +import operator +import sys +from collections.abc import Callable +from functools import partial + +import equinox as eqx +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Array, ArrayLike, Float, jaxtyped + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + +@jaxtyped(typechecker=typechecker) +class Material(eqx.Module): + """A class representing a material and it electrical properties.""" + + name: str = eqx.field(static=True) + """ + The name of the material. + """ + properties: Callable[ + [Float[ArrayLike, " *batch"]], + tuple[Float[Array, " *batch"], Float[Array, " *batch"]], + ] = eqx.field(static=True) + """ + The callable that computes the electrical properties of the material at the given frequency. + + The signature of the callable must be as follows. + + Args: + frequency: The frequency at which to compute the electrical properties. + + Returns: + A tuple containing the relative permittivity and conductivity of the material. + """ + aliases: tuple[str, ...] = eqx.field(default=(), static=True) + """ + A tuple of name aliases for the material. + """ + + @eqx.filter_jit + def relative_permittivity( + self, frequency: Float[ArrayLike, " *batch"] + ) -> Float[Array, " *batch"]: + """ + Compute the relative permittivity of the material at the given frequency. + + Args: + frequency: The frequency at which to compute the relative permittivity. + + Returns: + The relative permittivity of the material. + """ + return self.properties(frequency)[0] + + @eqx.filter_jit + def conductivity( + self, frequency: Float[ArrayLike, " *batch"] + ) -> Float[Array, " *batch"]: + """ + Compute the conductivity of the material at the given frequency. + + Args: + frequency: The frequency at which to compute the conductivity. + + Returns: + The conductivity of the material. + """ + return self.properties(frequency)[1] + + @classmethod + def from_itu_properties( + cls, + name: str, + *itu_properties: tuple[ + Float[ArrayLike, " "], + Float[ArrayLike, " "], + Float[ArrayLike, " "], + Float[ArrayLike, " "], + tuple[Float[ArrayLike, " "], Float[ArrayLike, " "]] | None, + ], + ) -> Self: + r""" + Create a material from ITU properties. + + The ITU-R Recommendation P.2040-3 :cite:`itu-r-2040` defines the electrical properties of a material + using 4 real-valued coefficients: **a**, **b**, **c**, and **c**. The :data:`materials` mapping + is already populated with values from :cite:`itu-r-2040{Tab. 3}`. + + Args: + name: The name of the material. + itu_properties: The list of material properties and corresponding frequency range. + + Each tuple must contain: + + * **a** (:class:`Float[ArrayLike, '']`): + The first coefficient for the real part of the relative permittivity. + * **b** (:class:`Float[ArrayLike, '']`): + The second coefficient for the real part of the relative permittivity. + * **c** (:class:`Float[ArrayLike, '']`): + The first coefficient for the conductivity. + * **d** (:class:`Float[ArrayLike, '']`): + The second coefficient for the conductivity. + * **frequency_range** + (:class:`tuple`\[:class:`Float[ArrayLike, '']`, + :class:`Float[ArrayLike, '']`\]): + The frequency range (in GHz) for which the electrical + properties are assumed to be correct. + + This parameter must either be an ordered 2-tuple of min. and max. frequencies, + or can be :data:`None`, in which case only one frequency range is allowed as + it will match all frequencies. + + Returns: + A new material. + + Raises: + ValueError: If you passed more than one frequency range and at least one was :data:`None`. + """ + f_ranges = [] + branches = [] + + dtype = jnp.result_type(*[x for prop in itu_properties for x in prop[:-1]]) + + aliases = ("itu_" + name.lower().replace(" ", "_"),) + + @partial(jax.jit, inline=True, static_argnums=(1, 2, 3, 4)) + @jaxtyped(typechecker=typechecker) + def callback( + f: Float[ArrayLike, " *batch"], + a: Float[ArrayLike, " "], + b: Float[ArrayLike, " "], + c: Float[ArrayLike, " "], + d: Float[ArrayLike, " "], + ) -> tuple[Float[Array, "*batch"], Float[Array, "*batch"]]: + f_ghz = jnp.asarray(f) / 1e9 + + if b == 0: + rel_perm = jnp.full_like(f_ghz, a, dtype=dtype) + else: + rel_perm = jnp.asarray(a * f_ghz**b, dtype=dtype) + + if d == 0: + cond = jnp.full_like(f_ghz, c, dtype=dtype) + else: + cond = jnp.asarray(c * f_ghz**d, dtype=dtype) + + return rel_perm, cond + + if any(prop[-1] is None for prop in itu_properties): + if len(itu_properties) != 1: + msg = "Only one frequency range can be used if 'None' is passed, as it will match any frequency." + raise ValueError(msg) + a, b, c, d, _ = itu_properties[0] + return cls( + name=name, + properties=partial(callback, a=a, b=b, c=c, d=d), + aliases=aliases, + ) + + props = sorted(itu_properties, key=operator.itemgetter(-1)) + + for a, b, c, d, f_range in props: + f_ranges.append(f_range) + branches.append(partial(callback, a=a, b=b, c=c, d=d)) + + # This callbacks is used when frequency is outside of range + branches.append( + lambda f: ( + -jnp.ones_like(f, dtype=dtype), + -jnp.ones_like(f, dtype=dtype), + ) + ) + i_range = jnp.arange(len(f_ranges)) + i_outside = len(branches) - 1 + + # NOTE: + # Checking f >= f_min_ghz * 1e9 + # leads to more accutate check than + # doing f / 1e9 >= f_min_ghz, + # hence we pre-multiply frequency ranges to be in Hz. + f_ranges = jnp.asarray(f_ranges) * 1e9 + f_min = f_ranges[:, 0] + f_max = f_ranges[:, 1] + + @jax.jit + def properties( + f: Float[ArrayLike, "*batch"], + ) -> tuple[Float[Array, "*batch"], Float[Array, "*batch"]]: + f = jnp.asarray(f) + + if jnp.ndim(f) == 0: + where = (f_min <= f) & (f <= f_max) + indices = jnp.min(i_range, where=where, initial=i_outside) + return jax.lax.switch( + indices, + branches, + f, + ) + + batch = f.shape + f = f.ravel() + + where = (f_min <= f[..., None]) & (f[..., None] <= f_max) + indices = jnp.min( + jnp.broadcast_to(i_range, where.shape), + where=where, + initial=i_outside, + axis=-1, + ) + + rel_perm, cond = jax.vmap( + lambda freq, i: jax.lax.switch( + i, + branches, + freq, + ), + )(f, indices) + + return rel_perm.reshape(batch), cond.reshape(batch) + + return cls( + name=name, + properties=properties, + aliases=aliases, + ) + + +# ITU-R P.2024-3 materials from Table 3. +_materials = [ + Material.from_itu_properties("Vacuum", (1.0, 0.0, 0.0, 0.0, None)), + Material.from_itu_properties("Concrete", (5.24, 0.0, 0.0462, 0.7822, (1, 100))), + Material.from_itu_properties("Brick", (3.91, 0.0, 0.0238, 0.16, (1, 40))), + Material.from_itu_properties("Plasterboard", (2.73, 0.0, 0.0085, 0.9395, (1, 100))), + Material.from_itu_properties("Wood", (1.99, 0.0, 0.0047, 1.0718, (0.001, 100))), + Material.from_itu_properties( + "Glass", + (6.31, 0.0, 0.0036, 1.3394, (0.1, 100)), + (5.79, 0.0, 0.0004, 1.658, (220, 450)), + ), + Material.from_itu_properties( + "Ceiling board", + (1.48, 0.0, 0.0011, 1.0750, (1, 100)), + (1.52, 0.0, 0.0029, 1.029, (220, 450)), + ), + Material.from_itu_properties("Chipboard", (2.58, 0.0, 0.0217, 0.7800, (1, 100))), + Material.from_itu_properties( + "Plywood", + ( + 2.71, + 0.0, + 0.33, + 0.0, + (1, 40), + ), + ), + Material.from_itu_properties("Marble", (7.074, 0.0, 0.0055, 0.9262, (1, 60))), + Material.from_itu_properties("Floorboard", (3.66, 0.0, 0.0044, 1.3515, (50, 100))), + Material.from_itu_properties("Metal", (1.0, 0.0, 1e7, 0.0, (1, 100))), + Material.from_itu_properties("Very dry ground", (3.0, 0.0, 0.00015, 2.52, (1, 10))), + Material.from_itu_properties( + "Medium dry ground", (15.0, -0.1, 0.035, 1.63, (1, 10)) + ), + Material.from_itu_properties("Wet ground", (30.0, -0.4, 0.15, 1.30, (1, 10))), +] + +materials: dict[str, Material] = { + name: material + for material in _materials + for name in (material.name, *material.aliases) +} +"""A dictionary mapping material names and corresponding object instances. + +Some materials, like ITU-R materials, have aliases to match the naming convention of Sionna.""" + +del _materials diff --git a/differt/src/differt/em/_special.py b/differt/src/differt/em/_special.py deleted file mode 100644 index b5566d2e..00000000 --- a/differt/src/differt/em/_special.py +++ /dev/null @@ -1,227 +0,0 @@ -import jax -import jax.numpy as jnp -from beartype import beartype as typechecker -from jax.scipy.special import erf as erfx -from jaxtyping import Array, Inexact, jaxtyped - - -@jax.jit -@jaxtyped(typechecker=typechecker) -def erf(z: Inexact[Array, " *batch"]) -> Inexact[Array, " *batch"]: - r""" - Evaluate the error function at the given points. - - The current implementation is written using - the real-valued error function :func:`jax.scipy.special.erf` - and the approximation as detailed in :cite:`erf-complex`. - - The output type (real or complex) is determined by the - input type. - - Warning: - Currently, we observe that - this function and :data:`scipy.special.erf` - starts to diverge for :math:`|z| > 6`. If you know - how to avoid this problem, please contact us! - - Args: - z: The array of real or complex points to evaluate. - - Returns: - The values of the error function at the given point. - - Notes: - Regarding performances, there are two possible outputs: - - 1. If ``z`` is real, then this function compiles to - :func:`jax.scipy.special.erf`, and will therefore - have the same performances (when JIT compilation - is done). Compared to the SciPy equivalent, we measured - that our implementation is **~ 10 times faster**. - 2. If ``z`` is complex, then our implementation is - **~ 3 times faster** than - :data:`scipy.special.erf`. - - Those results were measured on centered random uniform arrays - with :math:`10^5` elements. - - Examples: - The following plots the error function for real-valued inputs. - - .. plot:: - - >>> from differt.em import erf - >>> - >>> x = jnp.linspace(-3.0, +3.0) - >>> y = erf(x) - >>> plt.plot(x, y.real) # doctest: +SKIP - >>> plt.xlabel("$x$") # doctest: +SKIP - >>> plt.ylabel(r"$\text{erf}(x)$") # doctest: +SKIP - - The following plots the error function for complex-valued inputs. - - .. plotly:: - - >>> from differt.em import erf - >>> from scipy.special import erf - >>> - >>> x = y = jnp.linspace(-2.0, +2.0, 200) - >>> a, b = jnp.meshgrid(x, y) - >>> z = erf(a + 1j * b) - >>> fig = go.Figure( - ... data=[ - ... go.Surface( - ... x=x, - ... y=y, - ... z=jnp.abs(z), - ... colorscale="phase", - ... surfacecolor=jnp.angle(z), - ... colorbar=dict(title="Arg(erf(z))"), - ... ) - ... ] - ... ) - >>> fig.update_layout( - ... scene=dict( - ... xaxis=dict(title="Re(z)"), - ... yaxis=dict(title="Im(z)"), - ... zaxis=dict(title="Abs(erf(z))"), - ... ) - ... ) # doctest: +SKIP - >>> fig # doctest: +SKIP - """ - # TODO: remove this function as it is not needed anymore - if jnp.issubdtype(z.dtype, jnp.floating): - return erfx(z) - - if jnp.issubdtype(z.dtype, jnp.complex128): # double precision - N = 13 # noqa: N806 - M = 14 # noqa: N806 - else: # single precision - N = 9 # noqa: N806 - M = 10 # noqa: N806 - - r = z.real - i = jnp.abs(z.imag) - r_squared = r * r - - exp_r_squared = jnp.exp(-r_squared) - exp_2j_r_i = jnp.exp(-2j * r * i) - - f_sum = jnp.zeros_like(z) - g_sum = jnp.zeros_like(z) - h_sum = jnp.zeros_like(z) - - for n in range(1, N + 1): - n_squared = n * n - n_squared_over_four = n_squared / 4 - den = 1 / (n_squared_over_four + r_squared) - exp_f = jnp.exp(-n_squared_over_four) - exp_g = jnp.exp(+n * i - n_squared_over_four) - exp_h = jnp.exp(-n * i - n_squared_over_four) - - f_sum += exp_f * den - g_sum += exp_g * (r - 1j * n / 2) * den - h_sum += exp_h * (r + 1j * n / 2) * den - - for n in range(N + 1, N + M + 1): - n_squared = n * n - n_squared_over_four = n_squared / 4 - exp_g = jnp.exp(+n * i - n_squared_over_four) - - g_sum += exp_g * (r - 1j * n / 2) / (n_squared_over_four + r_squared) - - r_non_zero = jnp.where(r == 0.0, 1.0, r) - e = jnp.where( - r == 0.0, - 1j * i / jnp.pi, - (exp_r_squared * (1 - exp_2j_r_i)) / (2 * jnp.pi * r_non_zero), - ) # Fixes limit r -> 0 - f = r * exp_r_squared * f_sum / jnp.pi - g = exp_r_squared * g_sum / (2 * jnp.pi) - h = exp_r_squared * h_sum / (2 * jnp.pi) - - res = erfx(r) + e + f - exp_2j_r_i * (g + h) - return jnp.where(z.imag < 0, jnp.conj(res), res) - - -@jax.jit -@jaxtyped(typechecker=typechecker) -def erfc(z: Inexact[Array, " *batch"]) -> Inexact[Array, " *batch"]: - r""" - Evaluate the complementary error function at the given points. - - The output type (real or complex) is determined by the - input type. - - See :func:`erf` for more details. - - Args: - z: The array of real or complex points to evaluate. - - Returns: - The values of the complementary error function at the given point. - - Examples: - The following plots the complementary error function for real-valued inputs. - - .. plot:: - - >>> from differt.em import erfc - >>> - >>> x = jnp.linspace(-3.0, +3.0) - >>> y = erfc(x) - >>> plt.plot(x, y.real) # doctest: +SKIP - >>> plt.xlabel("$x$") # doctest: +SKIP - >>> plt.ylabel(r"$\text{erfc}(x)$") # doctest: +SKIP - """ - return 1.0 - erf(z) - - -@jax.jit -@jaxtyped(typechecker=typechecker) -def fresnel( - z: Inexact[Array, " *batch"], -) -> tuple[Inexact[Array, " *batch"], Inexact[Array, " *batch"]]: - """ - Evaluate the two Fresnel integrals at the given points. - - This current implementation is written using - the error function :func:`erf` - see :cite:`fresnel-integrals`. - - The output type (real or complex) is determined by the - input type. - - Args: - z: The array of real or complex points to evaluate. - - Returns: - A tuple of two arrays, one for each of the Fresnel integrals. - - Examples: - The following plots the Fresnel for real-valued inputs. - - .. plot:: - - >>> from differt.em import fresnel - >>> - >>> t = jnp.linspace(0.0, 5.0, 200) - >>> s, c = fresnel(t) - >>> plt.plot(t, s.real, label=r"$y=S(x)$") # doctest: +SKIP - >>> plt.plot(t, c.real, "--", label=r"$y=C(x)$") # doctest: +SKIP - >>> plt.xlabel("$x$") # doctest: +SKIP - >>> plt.ylabel("$y$") # doctest: +SKIP - >>> plt.legend() # doctest: +SKIP - """ - # Constant factors - sqrtpi_2_4 = 0.31332853432887503 # 0.25 * jnp.sqrt(0.5 * jnp.pi) - sqrt2 = 0.7071067811865476 # jnp.sqrt(0.5) - - # Erf function evaluations - ep = erf((1 + 1j) * sqrt2 * z) - em = erf((1 - 1j) * sqrt2 * z) - - s = sqrtpi_2_4 * (1 + 1j) * (ep - 1j * em) - c = sqrtpi_2_4 * (1 - 1j) * (ep + 1j * em) - - return s, c diff --git a/differt/src/differt/em/_utd.py b/differt/src/differt/em/_utd.py index f7d5f3cf..8be688d8 100644 --- a/differt/src/differt/em/_utd.py +++ b/differt/src/differt/em/_utd.py @@ -1,47 +1,199 @@ +# ruff: noqa: N802, N806 +from functools import partial +from typing import Any, Literal, overload + +import equinox as eqx import jax import jax.numpy as jnp +import jax.scipy.special as jsp from beartype import beartype as typechecker -from jaxtyping import Array, Complex, Inexact, jaxtyped +from jaxtyping import Array, Complex, Float, jaxtyped -from ._special import erfc +from differt.utils import dot -@jax.jit +@partial(jax.jit, inline=True) +@jaxtyped(typechecker=typechecker) +def _cot(x: Float[Array, " *batch"]) -> Float[Array, " *batch"]: + return 1 / jnp.tan(x) + + +@partial(jax.jit, inline=True) @jaxtyped(typechecker=typechecker) -def F(z: Inexact[Array, " *batch"]) -> Complex[Array, " *batch"]: # noqa: N802 +def _sign(x: Float[Array, " *batch"]) -> Float[Array, " *batch"]: + ones = jnp.ones_like(x) + return jnp.where(x >= 0, ones, -ones) + + +@partial(jax.jit, inline=True, static_argnames=("mode")) +@jaxtyped(typechecker=typechecker) +def _N( + beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"] +) -> Float[Array, " *batch"]: + if mode == "+": + return jnp.round((beta + jnp.pi) / (2 * n * jnp.pi)) + return jnp.round((beta + jnp.pi) / (2 * n * jnp.pi)) + + +@partial(jax.jit, inline=True, static_argnames=("mode")) +@jaxtyped(typechecker=typechecker) +def _a( + beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"] +) -> Float[Array, " *batch"]: + N = _N(beta, n, mode) + return 2.0 * jax.lax.integer_pow(jnp.cos(0.5 * (2 * n * jnp.pi * N - beta)), 2) + + +@overload +def L_i( + s_d: Float[Array, " *#batch"], + sin_2_beta_0: Float[Array, " *#batch"], + rho_1_i: None = None, + rho_2_i: None = None, + rho_e_i: None = None, + s_i: None = None, +) -> Float[Array, " *batch"]: ... + + +@overload +def L_i( + s_d: Float[Array, " *#batch"], + sin_2_beta_0: Float[Array, " *#batch"], + rho_1_i: None = None, + rho_2_i: None = None, + rho_e_i: None = None, + s_i: Float[Array, " *#batch"] | None = None, +) -> Float[Array, " *batch"]: ... + + +@overload +def L_i( + s_d: Float[Array, " *#batch"], + sin_2_beta_0: Float[Array, " *#batch"], + rho_1_i: Float[Array, " *#batch"], + rho_2_i: Float[Array, " *#batch"], + rho_e_i: Float[Array, " *#batch"], + s_i: None = None, +) -> Float[Array, " *batch"]: ... + + +@eqx.filter_jit +@jaxtyped(typechecker=typechecker) +def L_i( # noqa: PLR0917 + s_d: Float[Array, " *#batch"], + sin_2_beta_0: Float[Array, " *#batch"], + rho_1_i: Float[Array, " *#batch"] | None = None, + rho_2_i: Float[Array, " *#batch"] | None = None, + rho_e_i: Float[Array, " *#batch"] | None = None, + s_i: Float[Array, " *#batch"] | None = None, +) -> Float[Array, " *batch"]: r""" - Evaluate the transition function :cite:`utd-mcnamara{p. 184}` at the given points. + Compute the distance parameter associated with the incident shadow boundaries. + + .. note:: + + This function can also be used to compute the distance parameters + associated with the reflection shadow boundaries for the o- and n-faces, + by passing the corresponding radii of curvature, see + :cite:`utd-mcnamara{eq. 6.28, p. 270}`. - The transition function is defined as follows: + Its general expression is given by :cite:`utd-mcnamara{eq. 6.25, p. 270}`: .. math:: - F(z) = 2j \sqrt{z} e^{j z} \int\limits_\sqrt{z}^\infty e^{-j u^2} \text{d}u, + L_i = \frac{(\rho_e^i + s)\rho_1^i\rho_2^i}{\rho_e^i(\rho_1^i + s)(\rho_2^i + s)}\sin^2\beta_0, - where :math:`j^2 = -1`. + where :math:`s^d` is the distance from the point of diffraction (:math:`Q_d`) to the observer + point (:math:`P`), + :math:`\rho_1^i` is the principal radius of curvature of the incident wavewront at :math:`Q_d` + in the plane of incidence, + :math:`\rho_2^i` is the principal radius of curvature of the incident wavewront at :math:`Q_d` + in the place transverse to the plane of incidence, + :math:`\rho_e^i` is radius of curvature of the incident wavefront in the edge-fixed + plane of incidence., and :math:`\beta_0` is the angle of diffraction. - As detailed in :cite:`utd-mcnamara{p. 164}`, the integral can be expressed in - terms of Fresnel integrals (:math:`C(z)` and :math:`S(z)`), so that: + By default, when :math:`\rho_1^i`, :math:`\rho_e^i`, and :math:`\rho_2^i` are not provided, + a plane wave incidence is assumed and the expression simplifies to + :cite:`utd-mcnamara{eq. 6.27, p. 270}`: + + .. math:: + L_i = s^d\sin^2\beta_0. + + For spherical wavefront, you can pass :math:`s^i` (**s\_i**), the radius of curvature of + the spherical wavefront, where :math:`s^i = \rho_1^i = \rho_2^2 = \rho_e^i`, + and the expression will be simplified to + :cite:`utd-mcnamara{eq. 6.26, p. 270}`: + + .. math:: + L_i = \frac{s^ds^i}{s^d + s^i}\sin^2\beta_0. + + Args: + s_d: The distance from :math:`Q_d` to :math:`P`. + sin_2_beta_0: The squared sine of the angle of diffraction. + rho_1_i: The principal radius of curvature of the incident wavefront + in the plane of incidence. + rho_2_i: The principal radius of curvature of the incident wavefront + in the plane transverse to the plane of incidence. + rho_e_i: The radius of curvature of the incident wavefront in the edge-fixed + plane of incidence. + s_i: The radius of curvature of the incident spherical wavefront. + + If this is set, other radius parameters must be set to 'None'. + + Returns: + The values of the distance parameter :math:`L_i`. + + Raises: + ValueError: If 's_i' was provided along at least one of the other radius parameters, + or if one or the three 'rho' parameters was not provided. + """ + radii = (rho_1_i, rho_2_i, rho_e_i) + all_none = all(x is None for x in radii) + all_set = all(x is not None for x in radii) + if s_i is not None and any(x is not None for x in radii): + msg = "If 's_i' is provided, then 'rho_1_i', 'rho_2_i', and 'rho_e_i' must be left to 'None'." + raise ValueError(msg) + if (not all_none) and (not all_set): + msg = "All three of 'rho_1_i', 'rho_2_i', and 'rho_e_i' must be provided, or left to 'None'." + raise ValueError(msg) + + if s_i is not None: + return (s_d * s_i) * sin_2_beta_0 / (s_d + s_i) + if all_none: + return s_d * sin_2_beta_0 + return ( + (s_d * (rho_e_i + s_d) * rho_1_i * rho_2_i) + / (rho_e_i * (rho_1_i + s_d) * (rho_2_i + s_d)) + ) * sin_2_beta_0 + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def F(z: Float[Array, " *batch"]) -> Complex[Array, " *batch"]: + r""" + Evaluate the transition function at the given points. + + The transition function is defined as follows :cite:`utd-mcnamara{eq. 4.72, p. 184}`: .. math:: - C(z) - j S(z) = \int\limits_\sqrt{z}^\infty e^{-j u^2} \text{d}u. + F(x) = 2j \sqrt{x} e^{j x} \int\limits_\sqrt{x}^\infty e^{-j u^2} \text{d}u, - Because JAX does not provide a XLA implementation of - :data:`scipy.special.fresnel`, this implementation relies on a - custom complex-valued implementation of the error function and - the fact that: + where :math:`j^2 = -1`. + + As detailed in :cite:`utd-mcnamara{p. 164}`, the integral can be expressed in + terms of Fresnel integrals (:math:`C(x)` and :math:`S(x)`), so that: .. math:: - C(z) - j S(z) = \sqrt{\frac{\pi}{2}}\frac{1-j}{2}\text{erf}\left(\frac{1+j}{\sqrt{2}}z\right). + C(x) - j S(x) = \int\limits_\sqrt{x}^\infty e^{-j u^2} \text{d}u. - As a result, we can further simplify :math:`F(z)` to: + Thus, the transition function can be rewritten as: .. math:: - F(z) = \sqrt{\frac{\pi}{2}} \sqrt{z} e^{j z} (1 - j) \text{erfc}\left(\frac{1+j}{\sqrt{2}}z\right), + 2j \sqrt{z} e^{j z} \Big(\sqrt{\frac{\pi}{2}}\frac{1 - j}{2} - C(\sqrt{z}) + j S(\sqrt{z})\Big). - where :math:`\text{erfc}` is the complementary error function. + With Fresnel integrals computed by :data:`jax.scipy.special.fresnel`. Args: - z: The array of real or complex points to evaluate. + z: The array of real points to evaluate. Returns: The values of the transition function at the given point. @@ -50,7 +202,7 @@ def F(z: Inexact[Array, " *batch"]) -> Complex[Array, " *batch"]: # noqa: N802 .. plot:: The following example reproduces the same plot as in - :cite:`utd-mcnamara{fig. 4.16}`. + :cite:`utd-mcnamara{fig. 4.16, p. 185}`. >>> from differt.em import F >>> @@ -73,10 +225,87 @@ def F(z: Inexact[Array, " *batch"]) -> Complex[Array, " *batch"]: # noqa: N802 factor = jnp.sqrt(jnp.pi / 2) sqrt_z = jnp.sqrt(z) - return ( - (1 + 1j) - * factor - * sqrt_z - * jnp.exp(1j * z) - * erfc((1 + 1j) * sqrt_z / jnp.sqrt(2)) + s, c = jsp.fresnel(sqrt_z / factor) + return 2j * sqrt_z * jnp.exp(1j * z) * (factor * ((1 - 1j) / 2 - c + 1j * s)) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def diffraction_coefficients( + *_args: Any, +) -> None: + """ + Compute the diffraction coefficients based on the Uniform Theory of Diffraction. + + Warning: + This function is not yet implemented, as we are still thinking of the + best API for it. If you want to get involved in the implementation of UTD coefficients, + please reach out to us on GitHub! + + The implementation closely follows what is described + in :cite:`utd-mcnamara{p. 268-273}`. + + Unlike :func:`fresnel_coefficients`, diffraction + coefficients depend on the radii of curvature of the incident wave. + + Args: + sin_beta_0: ... + sin_beta: ... + sin_phi: ... + rho_1_i: ... + rho_1_i: ... + rho_e_i: ... + + Returns: + The soft and hard diffraction coefficients. + + Raises: + NotImplementedError: The function is not yet implemented. + """ + # ruff: noqa: ERA001, F821, F841 + raise NotImplementedError + + # Ensure input vectors are normalized + incident_ray = incident_ray / jnp.linalg.norm(incident_ray) + diffracted_ray = diffracted_ray / jnp.linalg.norm(diffracted_ray) + edge_vector = edge_vector / jnp.linalg.norm(edge_vector) + + # Compute relevant angles + beta_0 = jnp.arccos(jnp.dot(incident_ray, edge_vector)) + beta = jnp.arccos(jnp.dot(diffracted_ray, edge_vector)) + phi = jnp.arccos(jnp.dot(-incident_ray, diffracted_ray)) + + # Compute L parameters (distance parameters) + L = r * jnp.sin(beta) ** 2 / (r + r_prime) + L_prime = r_prime * jnp.sin(beta_0) ** 2 / (r + r_prime) + + phi_i = jnp.pi - (jnp.pi - jnp.arccos(dot(-s_t_i, t_o))) * _sign(dot(-s_t_i, n_o)) + phi_d = jnp.pi - (jnp.pi - jnp.arccos(dot(+s_t_d, t_o))) * _sign(dot(+s_d_i, n_o)) + + # Compute the angle differences + phi_1 = phi_d - phi_i + phi_2 = phi_d + phi_i + + # Compute the diffraction coefficients (without common mul. factor) + D_1 = _cot((jnp.pi + phi_1) / (2 * n)) * F(k * L_i * _a(phi_1, "+")) + D_2 = _cot((jnp.pi - phi_1) / (2 * n)) * F(k * L_i * _a(phi_1, "-")) + D_3 = _cot((jnp.pi + phi_2) / (2 * n)) * F(k * L_r_n * _a(phi_2, "+")) + D_4 = _cot((jnp.pi - phi_2) / (2 * n)) * F(k * L_r_o * _a(phi_2, "-")) + + factor = -jnp.exp(-1j * jnp.pi / 4) / ( + 2 * n * jnp.sqrt(2 * jnp.pi * k) * sin_beta_0 ) + + # Apply the Keller cone condition + # D_s = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_s, 0) + # D_h = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_h, 0) + + # TODO: below are assuming perfectly conducting surfaces + + D_12 = D_1 + D_2 + D_34 = D_3 + D_4 + + D_s = (D_12 - D_34) * factor + D_h = (D_12 + D_34) * factor + + return D_s, D_h diff --git a/differt/src/differt/em/_utils.py b/differt/src/differt/em/_utils.py index b9c6a377..14b95b5c 100644 --- a/differt/src/differt/em/_utils.py +++ b/differt/src/differt/em/_utils.py @@ -1,12 +1,15 @@ from typing import Any import jax +import jax.numpy as jnp from beartype import beartype as typechecker -from jaxtyping import Array, ArrayLike, Float, jaxtyped +from jaxtyping import Array, ArrayLike, Float, Int, jaxtyped -from differt.geometry import path_lengths +from differt.geometry import normalize, path_lengths, perpendicular_vectors +from differt.utils import dot from ._constants import c +from ._interaction_type import InteractionType @jax.jit @@ -81,3 +84,264 @@ def path_delays( lengths = path_lengths(paths) return lengths_to_delays(lengths, **kwargs) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def sp_directions( + k_i: Float[Array, "*#batch 3"], + k_r: Float[Array, "*#batch 3"], + normals: Float[Array, "*#batch 3"], +) -> tuple[ + tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]], + tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]], +]: + """ + Compute the directions of the local s and p components, before and after reflection, relative to the propagation direction and a local normal. + + Args: + k_i: The array of propagation direction of incident fields. + + Each vector must have a unit length. + k_r: The array of propagation direction of reflected fields. + + Each vector must have a unit length. + normals: The array of local normals. + + Each vector must have a unit length. + + Returns: + The array of s and p directions, before and after reflection. + + Examples: + The following example shows how to compute and display the direction of + the s and p components before and after reflection on a spherical surface. + + .. plotly:: + + >>> import plotly.graph_objects as go + >>> from differt.geometry import normalize, spherical_to_cartesian + >>> from differt.em import ( + ... sp_directions, + ... ) + + We generate a grid of points on a spherical surface. + + >>> u, v = jnp.meshgrid( + ... jnp.linspace(3 * jnp.pi / 8, 5 * jnp.pi / 8, 100), + ... jnp.linspace(-jnp.pi / 8, jnp.pi / 8), + ... ) + >>> pa = jnp.stack((u, v), axis=-1) + >>> xyz = spherical_to_cartesian(pa) + >>> fig = go.Figure() + >>> fig.add_trace( + ... go.Surface( + ... x=xyz[..., 0], + ... y=xyz[..., 1], + ... z=xyz[..., 2], + ... colorscale=["blue", "blue"], + ... opacity=0.5, + ... showscale=False, + ... ) + ... ) # doctest: +SKIP + + Plotly does not provide a nice function to draw 3D vectors, so we create one. + + >>> def add_vector(fig, orig, dest, color="red", name=None, dashed=False): + ... dir = dest - orig + ... end = orig + 0.9 * dir + ... fig.add_trace( + ... go.Scatter3d( + ... x=[orig[0], end[0]], + ... y=[orig[1], end[1]], + ... z=[orig[2], end[2]], + ... mode="lines", + ... line_color=color, + ... showlegend=False, + ... line_dash="dashdot" if dashed else None, + ... legendgroup=name, + ... ) + ... ) + ... dir = 0.1 * dir + ... fig.add_trace( + ... go.Cone( + ... x=[end[0]], + ... y=[end[1]], + ... z=[end[2]], + ... u=[dir[0]], + ... v=[dir[1]], + ... w=[dir[2]], + ... colorscale=[color, color], + ... sizemode="raw", + ... showscale=False, + ... showlegend=True, + ... name=name, + ... hoverinfo="name", + ... opacity=0.5 if dashed else None, + ... legendgroup=name, + ... ) + ... ) + + We then place TX and RX points, and determine direction vectors, + as well as direction of local s and p components. + + >>> reflection_point = jnp.array([1.0, 0.0, 0.0]) + >>> angle = jnp.pi / 4 + >>> cos = jnp.cos(angle) + >>> sin = jnp.sin(angle) + >>> normal = jnp.array([1.0, 0.0, 0.0]) + >>> tx = reflection_point + jnp.array([cos, 0.0, +sin]) + >>> rx = reflection_point + jnp.array([cos, 0.0, -sin]) + >>> k_i = reflection_point - tx + >>> k_r = rx - reflection_point + >>> (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(k_i, k_r, normal) + + Finally, we draw all the vectors and markers. + + >>> fig.add_trace( + ... go.Scatter3d( + ... x=[tx[0], rx[0]], + ... y=[tx[1], rx[1]], + ... z=[tx[2], rx[2]], + ... mode="markers+text", + ... text=["TX", "RX"], + ... marker_color="black", + ... showlegend=False, + ... ) + ... ) # doctest: +SKIP + >>> add_vector(fig, tx, reflection_point, color="magenta", name="incident") + >>> add_vector(fig, reflection_point, rx, color="magenta", name="reflected") + >>> add_vector( + ... fig, + ... reflection_point, + ... reflection_point + normal, + ... color="blue", + ... name="normal", + ... ) + >>> add_vector( + ... fig, + ... reflection_point, + ... reflection_point + 0.5 * e_i_s, + ... color="orange", + ... name="s-component (incident)", + ... ) + >>> add_vector( + ... fig, + ... reflection_point, + ... reflection_point + 0.5 * e_i_p, + ... color="orange", + ... name="p-component (incident)", + ... ) + + We do the same, but for the reflected field. + + >>> add_vector( + ... fig, + ... reflection_point, + ... reflection_point + 0.5 * e_r_s, + ... color="green", + ... name="s-component (reflected)", + ... dashed=True, + ... ) + >>> add_vector( + ... fig, + ... reflection_point, + ... reflection_point + 0.5 * e_r_p, + ... color="green", + ... name="p-component (reflected)", + ... ) + >>> fig # doctest: +SKIP + """ + e_i_s, e_i_s_norm = normalize(jnp.cross(k_i, normals), keepdims=True) + # Alternative vectors if normal is parallel to k_i + normal_incidence = e_i_s_norm == 0.0 + e_i_s: Array = jnp.where( + normal_incidence, + perpendicular_vectors(k_i), + e_i_s, + ) # type: ignore[reportTypeAssignment] + e_i_p = normalize(jnp.cross(e_i_s, k_i))[0] + e_r_s = e_i_s + e_r_p = normalize(jnp.cross(e_r_s, k_r))[0] + + return (e_i_s, e_i_p), (e_r_s, e_r_p) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def sp_rotation_matrix( + e_a_s: Float[Array, "*#batch 3"], + e_a_p: Float[Array, "*#batch 3"], + e_b_s: Float[Array, "*#batch 3"], + e_b_p: Float[Array, "*#batch 3"], +) -> Float[Array, "*batch 2 2"]: + """ + Return the rotation matrix to convert the s and p components from one base to another. + + All input vectors must have a unit length, and the direction of propagation must be the same. + The latter is equivalent to ensuring that all four vectors are coplanar. + + Args: + e_a_s: The array of s component directions of the incident field. + e_a_p: The array of p component directions of the incident field. + e_b_s: The array of s component directions of the reflected field. + e_b_p: The array of p component directions of the reflected field. + + Returns: + The array of rotation matrices. + """ + r11 = dot(e_b_s, e_a_s, keepdims=True) + r12 = dot(e_b_s, e_a_p, keepdims=True) + r21 = dot(e_b_p, e_a_s, keepdims=True) + r22 = dot(e_b_p, e_a_p, keepdims=True) + + r11, r12, r21, r22 = jnp.broadcast_arrays(r11, r12, r21, r22) + + batch = r11.shape[:-1] + + return jnp.concatenate((r11, r12, r21, r22), axis=-1).reshape(*batch, 2, 2) + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def transition_matrices( + vertices: Float[Array, "*batch path_length 3"], + objects: Float[Array, "*batch path_length"], + interaction_types: Int[Array, "*batch path_length-2"], + object_normals: Float[Array, "*batch path_length 3"], +) -> Float[Array, "*batch 2 2"]: + """ + Compute the transition matrix, ... + + Args: + k_i: The array of propagation direction of incident fields. + + Each vector must have a unit length. + k_r: The array of propagation direction of reflected fields. + + Each vector must have a unit length. + normals: The array of local normals. + + Each vector must have a unit length. + + Returns: + The array of s and p directions, before and after reflection. + """ + if any(x.dtype == jnp.float64 for x in (vertices, object_normals)): + cdtype = jnp.complex128 + else: + cdtype = jnp.complex64 + + # [*batch 2 2] + mat = jnp.zeros((vertices.shape[:-2], 2, 2), dtype=cdtype) + + v = jnp.diff(vertices, axis=-2) + k, s = normalize(v) + k_i, s_i = k[..., :-1, :], s[..., :-1, :] + k_r, s_r = k[..., +1:, :], s[..., +1:, :] + + mat_r = ... + + mat = jnp.where(interaction_types == InteractionType.REFLECTION, mat_r, mat) + + return mat diff --git a/differt/src/differt/geometry/__init__.py b/differt/src/differt/geometry/__init__.py index d8b6ed53..9811559e 100644 --- a/differt/src/differt/geometry/__init__.py +++ b/differt/src/differt/geometry/__init__.py @@ -12,6 +12,7 @@ "normalize", "orthogonal_basis", "path_lengths", + "perpendicular_vectors", "rotation_matrix_along_axis", "rotation_matrix_along_x_axis", "rotation_matrix_along_x_axis", @@ -35,6 +36,7 @@ normalize, orthogonal_basis, path_lengths, + perpendicular_vectors, rotation_matrix_along_axis, rotation_matrix_along_x_axis, rotation_matrix_along_y_axis, diff --git a/differt/src/differt/geometry/_paths.py b/differt/src/differt/geometry/_paths.py index c4ab9284..bcfdc085 100644 --- a/differt/src/differt/geometry/_paths.py +++ b/differt/src/differt/geometry/_paths.py @@ -1,6 +1,6 @@ import sys import warnings -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterator, Sequence from dataclasses import KW_ONLY from typing import Any @@ -103,6 +103,13 @@ class Paths(eqx.Module): batch ``*batch`` dimensions, which would not be possible if we were to directly store valid paths. """ + interaction_types: Int[Array, "*batch path_length-2"] | None = eqx.field( + converter=lambda x: jnp.asarray(x) if x is not None else None, default=None + ) + """An optional array to indicate the type of each interaction. + + If not specified, :attr:`InteractionType.REFLECTION` is assumed. + """ @jaxtyped( typechecker=None @@ -120,11 +127,62 @@ def reshape(self, *batch: int) -> Self: vertices = self.vertices.reshape(*batch, self.path_length, 3) objects = self.objects.reshape(*batch, self.path_length) mask = self.mask.reshape(*batch) if self.mask is not None else None + interaction_types = ( + self.interaction_types.reshape(*batch, self.path_length - 2) + if self.interaction_types is not None + else None + ) + + return eqx.tree_at( + lambda p: (p.vertices, p.objects, p.mask, p.interaction_types), + self, + (vertices, objects, mask, interaction_types), + is_leaf=lambda x: x is None, + ) + + @jaxtyped( + typechecker=None + ) # typing.Self is (currently) not compatible with jaxtyping and beartype + def squeeze(self, axis: int | Sequence[int] | None = None) -> Self: + """ + Return a copy by squeezing one or more axes of paths' batch dimensions. + + Args: + axis: See :func:`jax.numpy.squeeze` for allowed values. + + Returns: + A new paths instance with squeezed batch dimensions. + + Raises: + ValueError: If one of the provided axes is out-of-bounds, + or if trying to squeeze a 0-dimensional batch. + """ + ndim = self.vertices.ndim - 2 + if axis is not None and ndim == 0: + msg = "Cannot squeeze a 0-dimensional batch!" + raise ValueError(msg) + if isinstance(axis, int): + axis = (axis,) + if isinstance(axis, Sequence): + axis = tuple(a + ndim if a < 0 else a for a in axis) + + if any(ax >= ndim or ax < 0 for ax in axis): + msg = "One of the provided axes is out-of-bounds!" + raise ValueError(msg) + + mask = self.mask.squeeze(axis) if self.mask is not None else None + vertices = self.vertices.squeeze(axis) + objects = self.objects.squeeze(axis) + interaction_types = ( + self.interaction_types.squeeze(axis) + if self.interaction_types is not None + else None + ) return eqx.tree_at( - lambda p: (p.vertices, p.objects, p.mask), + lambda p: (p.vertices, p.objects, p.mask, p.interaction_types), self, - (vertices, objects, mask), + (vertices, objects, mask, interaction_types), is_leaf=lambda x: x is None, ) diff --git a/differt/src/differt/geometry/_triangle_mesh.py b/differt/src/differt/geometry/_triangle_mesh.py index 6138114e..e13d9819 100644 --- a/differt/src/differt/geometry/_triangle_mesh.py +++ b/differt/src/differt/geometry/_triangle_mesh.py @@ -119,12 +119,19 @@ class TriangleMesh(eqx.Module): If the present mesh contains multiple objects, usually as a result of appending multiple meshes together, this array contain start end end indices for each sub mesh. + + .. important:: + + The object indices must cover exactly all triangles in this mesh, + and be sorted in ascending order. Otherwise, some methods, like + the random object coloring with :meth:`set_face_colors`, may not + work as expected. """ assume_quads: bool = eqx.field(default=False) """Flag indicating whether triangles can be paired into quadrilaterals. Setting this to :data:`True` will not check anything, except that - :attr:`num_triangles` should is even, but each two consecutive + :attr:`num_triangles` is even, but each two consecutive triangles are assumed to represent a quadrilateral surface. """ @@ -359,13 +366,33 @@ def empty(cls) -> Self: """ return cls(vertices=jnp.empty((0, 3)), triangles=jnp.empty((0, 3), dtype=int)) - @jax.jit + @overload + def set_face_colors( + self, + colors: Float[Array, "#{self.num_triangles} 3"] | Float[Array, "3"], + *, + key: None = None, + ) -> Self: ... + + @overload + def set_face_colors( + self, + colors: None, + *, + key: PRNGKeyArray, + ) -> Self: ... + + @eqx.filter_jit @jaxtyped( typechecker=None ) # typing.Self is (currently) not compatible with jaxtyping and beartype def set_face_colors( self, - colors: Float[Array, "#{self.num_triangles} 3"] | Float[Array, "3"], + colors: Float[Array, "#{self.num_triangles} 3"] + | Float[Array, "3"] + | None = None, + *, + key: PRNGKeyArray | None = None, ) -> Self: """ Return a copy of this mesh, with new face colors. @@ -374,10 +401,135 @@ def set_face_colors( colors: The array of RGB colors. If one color is provided, it will be applied to all triangles. + This or ``key`` must be specified. + key: If provided, colors will be randomly generated. + + If :attr:`object_bounds` is not :data:`None`, then triangles + within the same object will share the same color. Otherwise, + a random color is generated for each triangle + (or quadrilateral if :attr:`assume_quads` is :data:`True`). + Returns: A new mesh with updated face colors. + + Raises: + ValueError: If ``colors`` or ``key`` is not specified. + + Examples: + The following example shows how this function paints the mesh, + for different argument types. + + First, we load a scene from Sionna :cite:`sionna`, that is + already colored, and extract the mesh from it. + + .. plotly:: + :context: + + >>> from differt.scene import ( + ... TriangleScene, + ... download_sionna_scenes, + ... get_sionna_scene, + ... ) + >>> + >>> download_sionna_scenes() # doctest: +SKIP + >>> file = get_sionna_scene("simple_street_canyon") + >>> mesh = TriangleScene.load_xml(file).mesh + >>> fig = mesh.plot(backend="plotly") + >>> fig # doctest: +SKIP + + Then, we could set the same color to all triangles. + + .. plotly:: + :context: + + >>> fig = mesh.set_face_colors(jnp.array([0.8, 0.2, 0.0])).plot( + ... backend="plotly" + ... ) + >>> fig # doctest: +SKIP + + We could also manually specify a different color for each triangle, but it can + become tedious as the number of triangles gets larger. Another option is to rely + on automatic random coloring, using the ``key`` argument. + + As our mesh is a collection of 7 distinct objects, as this was loaded from a Sionna + XML file, this utility will automatically detect it and color each + object differently. + + .. plotly:: + :context: + + >>> mesh.object_bounds + Array([[ 0, 12], + [12, 24], + [24, 36], + [36, 48], + [48, 60], + [60, 72], + [72, 74]], dtype=int32) + >>> fig = mesh.set_face_colors(key=jax.random.key(1234)).plot( + ... backend="plotly" + ... ) + >>> fig # doctest: +SKIP + + If you prefer to have per-triangle coloring, you can perform surgery on the mesh + to remove its :attr:`object_bounds` attribute. + + .. plotly:: + :context: + + >>> import equinox as eqx + >>> + >>> mesh = eqx.tree_at(lambda m: m.object_bounds, mesh, None) + >>> fig = mesh.set_face_colors(key=jax.random.key(1234)).plot( + ... backend="plotly" + ... ) + >>> fig # doctest: +SKIP + + Finally, you can also set :attr:`assume_quads` to :data:`True` to color quadrilaterals + instead. + + .. plotly:: + :context: + + >>> fig = ( + ... mesh.set_assume_quads() + ... .set_face_colors(key=jax.random.key(1234)) + ... .plot(backend="plotly") + ... ) + >>> fig # doctest: +SKIP """ - face_colors = jnp.broadcast_to(colors.reshape(-1, 3), self.triangles.shape) + if (colors is None) == (key is None): + msg = "You must specify one of 'colors' or `key`, not both." + raise ValueError(msg) + + if key is not None: + if self.object_bounds is not None: + object_colors = jax.random.uniform( + key, (self.object_bounds.shape[0], 3) + ) + repeats = jnp.diff(self.object_bounds, axis=-1) + colors = jnp.repeat( + object_colors, + repeats, + axis=0, + total_repeat_length=self.num_triangles, + ) + elif self.assume_quads: + quad_colors = jax.random.uniform(key, (self.num_quads, 3)) + repeats = jnp.full(self.num_objects, 2) + colors = jnp.repeat( + quad_colors, repeats, axis=0, total_repeat_length=self.num_triangles + ) + else: + colors = jax.random.uniform(key, (self.num_triangles, 3)) + + return self.set_face_colors(colors=colors) + + # TODO: understand why pyright cannot determine that colors is not None + face_colors = jnp.broadcast_to( + colors.reshape(-1, 3), # type: ignore[reportOptionalMemberAccess] + self.triangles.shape, + ) return eqx.tree_at( lambda m: m.face_colors, self, @@ -472,7 +624,6 @@ def plane( u, v = orthogonal_basis( normal, - normalize=True, ) s = 0.5 * side_length @@ -517,6 +668,35 @@ def box( Returns: A new box mesh. + + Examples: + The following example shows how to create a cube. + + .. plotly:: + + >>> from differt.geometry import TriangleMesh + >>> mesh = ( + ... TriangleMesh.box(with_top=True) + ... .set_assume_quads() + ... .set_face_colors(key=jax.random.key(1234)) + ... ) + >>> fig = mesh.plot(opacity=0.5, backend="plotly") + >>> fig # doctest: +SKIP + + The second example shows how to create a corridor-like + mesh, without the ceiling face. + + .. plotly:: + + >>> from differt.geometry import TriangleMesh + >>> mesh = ( + ... TriangleMesh.box(length=10.0, width=3.0, height=2.0) + ... .set_assume_quads() + ... .set_face_colors(key=jax.random.key(1234)) + ... ) + >>> fig = mesh.plot(opacity=0.5, backend="plotly") + >>> fig = fig.update_scenes(aspectmode="data") + >>> fig # doctest: +SKIP """ dx = jnp.array([length * 0.5, 0.0, 0.0]) dy = jnp.array([0.0, width * 0.5, 0.0]) diff --git a/differt/src/differt/geometry/_utils.py b/differt/src/differt/geometry/_utils.py index c6e8705d..362b8b1a 100644 --- a/differt/src/differt/geometry/_utils.py +++ b/differt/src/differt/geometry/_utils.py @@ -57,20 +57,19 @@ def normalize( def normalize( vector: Float[Array, "*batch 3"], keepdims: bool = False, -) -> ( - tuple[Float[Array, "*batch 3"], Float[Array, " *batch"]] - | tuple[Float[Array, "*batch 3"], Float[Array, " *batch 1"]] -): +) -> tuple[ + Float[Array, "*batch 3"], Float[Array, " *batch"] | Float[Array, " *batch 1"] +]: """ Normalize vectors and also return their length. This function avoids division by zero by checking vectors - with zero-length, and returning unit length instead. + with zero-length, dividing by one instead. Args: vector: An array of vectors. keepdims: If set to :data:`True`, the array of lengths - will have the same number of dimensions are the input. + will have the same number of dimensions as the input. Returns: The normalized vector and their length. @@ -89,19 +88,55 @@ def normalize( Array(1.7320508, dtype=float32)) >>> zero = jnp.array([0.0, 0.0, 0.0]) >>> normalize(zero) # Special behavior at 0. - (Array([0., 0., 0.], dtype=float32), Array(1., dtype=float32)) + (Array([0., 0., 0.], dtype=float32), Array(0., dtype=float32)) """ - length: Array = jnp.linalg.norm(vector, axis=-1, keepdims=True) - length = jnp.where(length == 0.0, jnp.ones_like(length), length) + length = jnp.linalg.norm(vector, axis=-1, keepdims=True) + + return vector / jnp.where(length == 0.0, jnp.ones_like(length), length), ( + length if keepdims else jnp.squeeze(length, axis=-1) + ) + + +@partial(jax.jit, inline=True) +@jaxtyped(typechecker=typechecker) +def perpendicular_vectors(u: Float[Array, "*batch 3"]) -> Float[Array, "*batch 3"]: + """ + Generate a vector perpendicular to the input vectors. + + Args: + u: The array of input vectors. + + Returns: + An array of vectors perpendicular to the input vectors. - return vector / length, (length if keepdims else jnp.squeeze(length, axis=-1)) + Examples: + The following example shows how this function works on basic input vectors. + + >>> from differt.geometry import ( + ... perpendicular_vectors, + ... ) + >>> + >>> u = jnp.array([1.0, 0.0, 0.0]) + >>> perpendicular_vectors(u) + Array([ 0., -0., 1.], dtype=float32) + >>> u = jnp.array([1.0, 1.0, 1.0]) + >>> perpendicular_vectors(u) + Array([ 0.8164966, -0.4082483, -0.4082483], dtype=float32) + """ + z = jnp.zeros_like(u[..., 0]) + v = jnp.where( + (jnp.abs(u[..., 0]) > jnp.abs(u[..., 1]))[..., None], + jnp.stack((-u[..., 1], u[..., 0], z), axis=-1), + jnp.stack((z, -u[..., 2], u[..., 1]), axis=-1), + ) + w = jnp.cross(u, v) + return w / jnp.linalg.norm(w, axis=-1, keepdims=True) -@partial(jax.jit, static_argnames=("normalize",), inline=True) +@partial(jax.jit, inline=True) @jaxtyped(typechecker=typechecker) def orthogonal_basis( u: Float[Array, "*batch 3"], - normalize: bool = True, ) -> tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]: """ Generate ``v`` and ``w``, two other arrays of unit vectors that form with input ``u`` an orthogonal basis. @@ -109,12 +144,6 @@ def orthogonal_basis( Args: u: The first direction of the orthogonal basis. It must have a unit length. - normalize: Whether the output vectors should be normalized. - - This may be needed, especially for vector ``v``, - as floating-point error can accumulate so much - that the vector lengths may diverge from the unit - length by 10% or even more! Returns: A pair of unit vectors, ``v`` and ``w``. @@ -129,19 +158,15 @@ def orthogonal_basis( >>> >>> u = jnp.array([1.0, 0.0, 0.0]) >>> orthogonal_basis(u) - (Array([ 0., -1., 0.], dtype=float32), Array([ 0., 0., -1.], dtype=float32)) + (Array([-0., 1., 0.], dtype=float32), Array([ 0., -0., 1.], dtype=float32)) >>> u, _ = normalize(jnp.array([1.0, 1.0, 1.0])) >>> orthogonal_basis(u) - (Array([ 0.4082483, -0.8164966, 0.4082483], dtype=float32), - Array([ 0.7071068, 0. , -0.7071068], dtype=float32)) + (Array([-0. , -0.7071068, 0.7071068], dtype=float32), + Array([ 0.8164966, -0.4082483, -0.4082483], dtype=float32)) """ - vp = jnp.stack((u[..., 2], -u[..., 0], u[..., 1]), axis=-1) - w = jnp.cross(u, vp, axis=-1) - v = jnp.cross(w, u, axis=-1) - - if normalize: - v = v / jnp.linalg.norm(v, axis=-1, keepdims=True) - w = w / jnp.linalg.norm(w, axis=-1, keepdims=True) + w = perpendicular_vectors(u) + v = jnp.cross(w, u) + v = v / jnp.linalg.norm(v, axis=-1, keepdims=True) return v, w @@ -388,6 +413,9 @@ def fibonacci_lattice( Args: n: The size of the lattice. dtype: The float dtype of the vertices. + grid: Whether to return a grid of shape ``{n} {n} 3`` + instead. This is mainly useful if you need to plot + a surface that is generated from a lattice. Unused if ``frustum`` is passed. frustum: The spatial region where to sample points. diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index 5622a718..dde4dcf7 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -99,9 +99,11 @@ def _( canvas, view = process_vispy_kwargs(kwargs) + kwargs.setdefault("shading", "flat") + vertices = np.asarray(vertices) triangles = np.asarray(triangles) - view.add(Mesh(vertices=vertices, faces=triangles, shading="flat", **kwargs)) + view.add(Mesh(vertices=vertices, faces=triangles, **kwargs)) view.camera.set_range() return canvas @@ -607,7 +609,7 @@ def draw_image( Matplotlib backend requires ``data`` to be either RGB or RGBA array. Examples: - The following example shows how plot a 2-D image, + The following example shows how to plot a 2-D image, without and with axis scaling. .. plotly:: @@ -766,7 +768,7 @@ def draw_contour( # noqa: PLR0917 a 2D figure instead. Examples: - The following example shows how plot a 2-D contour, + The following example shows how to plot a 2-D contour, without and with axis scaling, and filling. .. plotly:: @@ -963,7 +965,7 @@ def draw_surface( VisPy currently does not support colors. Examples: - The following example shows how plot a 3-D surface, + The following example shows how to plot a 3-D surface, without and with custom coloring. .. plotly:: diff --git a/differt/src/differt/rt/_image_method.py b/differt/src/differt/rt/_image_method.py index 31d0e253..d3430c1b 100644 --- a/differt/src/differt/rt/_image_method.py +++ b/differt/src/differt/rt/_image_method.py @@ -199,7 +199,7 @@ def image_method( intersections with mirrors are computed backward, `i.e.`, from last mirror to first, by joining the UE, then the intersections points, with the images of the BS. Finally, the valid path can be obtained by joining BS, the intermediary - intersection points, and the UE :cite:`mpt-eucap2023{fig. 5}`. + intersection points, and the UE :cite:`mpt-eucap2023{fig. 5, p. 3}`. Next, we show how to reproduce the above results using :func:`image_method`. diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 9878d0a0..c3e2cc2d 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -305,7 +305,7 @@ def rays_intersect_any_triangle( checking if at least one of the triangles is intersect. A triangle is considered to be intersected if - ``t < hit_threshold & hit`` evaluates to :data:`True`. + ``t < (1 - hit_tol) & hit`` evaluates to :data:`True`. Args: ray_origins: An array of origin vertices. @@ -330,7 +330,7 @@ def rays_intersect_any_triangle( dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) hit_tol = 10.0 * jnp.finfo(dtype).eps - hit_threshold = 1.0 - hit_tol + hit_threshold = 1.0 - jnp.asarray(hit_tol) # Put 'num_triangles' axis as leading axis triangle_vertices = jnp.moveaxis(triangle_vertices, -3, 0) diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index cccd8387..842f13ee 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -341,3 +341,32 @@ def sample_points_in_bounding_box( r = jax.random.uniform(key, shape=(*shape, 3)) return r * scale + amin + + +@jax.jit +@jaxtyped(typechecker=typechecker) +def safe_divide( + num: Num[Array, " *#batch"], den: Num[Array, " *#batch"] +) -> Num[Array, " *batch"]: + """ + Compute the elementwise division, but returns 0 when ``den`` is zero. + + Args: + num: The numerator. + den: The denominator. + + Returns: + The result of ``num / dev``, except that division by zero returns 0. + + Examples: + The following examples shows how division by zero is handled. + + >>> from differt.utils import safe_divide + >>> + >>> x = jnp.array([1, 2, 3, 4, 5]) + >>> y = jnp.array([0, 1, 2, 0, 2]) + >>> safe_divide(x, y) + Array([0. , 2. , 1.5, 0. , 2.5], dtype=float32) + """ + # TODO: add :python: rst role for x / y in docs + return jnp.where(den == 0, 0, num / den) diff --git a/differt/tests/benchmarks/test_rt.py b/differt/tests/benchmarks/test_rt.py index 582491f8..230edd35 100644 --- a/differt/tests/benchmarks/test_rt.py +++ b/differt/tests/benchmarks/test_rt.py @@ -75,7 +75,6 @@ def test_compute_paths_in_simple_street_canyon_scene( scene = simple_street_canyon_scene.set_assume_quads(assume_quads) if chunk_size: - @jax.debug_nans(False) # noqa: FBT003 def bench_fun() -> None: for path in scene.compute_paths( order, @@ -85,7 +84,6 @@ def bench_fun() -> None: else: - @jax.debug_nans(False) # noqa: FBT003 def bench_fun() -> None: scene.compute_paths( order, @@ -104,7 +102,6 @@ def test_compile_compute_paths( path_candidates = generate_all_path_candidates(scene.mesh.num_triangles, 2) @jax.jit - @jax.debug_nans(False) # noqa: FBT003 def fun(path_candidates: Array) -> Paths: return scene.compute_paths(path_candidates=path_candidates) diff --git a/differt/tests/em/test_antenna.py b/differt/tests/em/test_antenna.py new file mode 100644 index 00000000..c740fa65 --- /dev/null +++ b/differt/tests/em/test_antenna.py @@ -0,0 +1,133 @@ +from contextlib import AbstractContextManager +from contextlib import nullcontext as does_not_raise + +import chex +import jax.numpy as jnp +import pytest + +from differt.em import c, mu_0 +from differt.em._antenna import Antenna, Dipole + + +@pytest.fixture +def antenna() -> Dipole: + return Dipole( + frequency=1e9, + ) + + +class TestAntenna: + def test_frequency(self, antenna: Antenna) -> None: + chex.assert_trees_all_equal(antenna.frequency, 1e9) + + def test_center(self, antenna: Antenna) -> None: + chex.assert_trees_all_equal(antenna.center, jnp.zeros(3)) + + def test_period(self, antenna: Antenna) -> None: + chex.assert_trees_all_close(antenna.period, 1 / 1e9) + + def test_angular_frequency(self, antenna: Antenna) -> None: + chex.assert_trees_all_close(antenna.angular_frequency, 2 * jnp.pi * 1e9) + + def test_wavelength(self, antenna: Antenna) -> None: + chex.assert_trees_all_close(antenna.wavelength, c / 1e9) + + def test_wavenumber(self, antenna: Antenna) -> None: + chex.assert_trees_all_close(antenna.wavenumber, 2 * jnp.pi * 1e9 / c) + + def test_abstract(self) -> None: + with pytest.raises( + TypeError, + match="Can't instantiate abstract class Antenna", + ): + _ = Antenna(frequency=1e9) # type: ignore[reportAbstractUsage] + + @pytest.mark.parametrize("num_wavelengths", [None, 10.0]) + @pytest.mark.parametrize( + ("backend", "expectation"), + [ + ( + "vispy", + pytest.warns( + UserWarning, + match="VisPy does not currently support coloring like we would like", + ), + ), + ( + "matplotlib", + pytest.warns( + UserWarning, + match="Matplotlib requires 'colors' to be RGB or RGBA values", + ), + ), + ("plotly", does_not_raise()), + ], + ) + def test_plot_radiation_pattern( + self, + num_wavelengths: float | None, + backend: str, + expectation: AbstractContextManager[Exception], + antenna: Antenna, + ) -> None: + with expectation: + _ = antenna.plot_radiation_pattern( + num_wavelengths=num_wavelengths, backend=backend + ) + + +class TestDipole: + def test_init(self) -> None: + dipole = Dipole( + 1e9, + current=2.0, + length=4.0, + ) + chex.assert_trees_all_close( + jnp.linalg.norm(dipole.moment), (2.0 * 4.0 / dipole.angular_frequency) + ) + dipole = Dipole( + 1e9, + current=None, + ) + chex.assert_trees_all_close(jnp.linalg.norm(dipole.moment), 1.0) + dipole = Dipole(1e9, charge=3.0, length=2.0) + chex.assert_trees_all_close( + jnp.linalg.norm(dipole.moment), + 3.0 * 2.0, + ) + + def test_average_power(self) -> None: + f = 1e9 + w = 2 * jnp.pi * f + p_0 = 1.0 + dipole = Dipole( + frequency=f, + ) + p_0 = jnp.linalg.norm(dipole.moment) + chex.assert_trees_all_close( + dipole.average_power, mu_0 * w**4 * p_0**2 / (12 * jnp.pi * c) + ) + + @pytest.mark.parametrize( + ("ratio", "expected_gain"), + [(0.5, 1.5), (1.0, 1.5), (1.25, 1.5), (1.5, 1.5), (2.0, 1.5)], + ) + def test_directivity(self, ratio: float, expected_gain: float) -> None: + f = 1e9 + dipole = Dipole( + frequency=f, + num_wavelengths=ratio, + ) + directive_gain = dipole.directive_gain(1000) + chex.assert_trees_all_close(directive_gain, expected_gain) + + +class TestShortDipole: + @pytest.mark.skip + @pytest.mark.parametrize( + ("ratio", "expected_gain_dbi"), + [(0.5, 2.15), (1.0, 4.0), (1.25, 5.2), (1.5, 3.5), (2.0, 4.3)], + ) + def test_directivity(self, ratio: float, expected_gain_dbi: float) -> None: + pass diff --git a/differt/tests/em/test_constants.py b/differt/tests/em/test_constants.py index f1e6fa0e..c505e9d3 100644 --- a/differt/tests/em/test_constants.py +++ b/differt/tests/em/test_constants.py @@ -4,8 +4,21 @@ from differt.em import _constants -@pytest.mark.parametrize("constant_name", ["c", "epsilon_0", "mu_0"]) -def test_constants(constant_name: str) -> None: +@pytest.mark.parametrize( + ("constant_name", "value"), + [ + ("c", None), + ("epsilon_0", None), + ("mu_0", None), + ( + "z_0", + scipy.constants.physical_constants["characteristic impedance of vacuum"][0], + ), + ], +) +def test_constants(constant_name: str, value: float | None) -> None: got = getattr(_constants, constant_name) - expected = getattr(scipy.constants, constant_name) - assert got == expected + if value: + assert abs(got - value) < 1e-6 + else: + assert got == getattr(scipy.constants, constant_name) diff --git a/differt/tests/em/test_fresnel.py b/differt/tests/em/test_fresnel.py new file mode 100644 index 00000000..0f5c6a37 --- /dev/null +++ b/differt/tests/em/test_fresnel.py @@ -0,0 +1,93 @@ +import chex +import jax +import jax.experimental +import jax.numpy as jnp +import pytest +from jaxtyping import PRNGKeyArray + +from differt.em import materials +from differt.em._fresnel import ( + fresnel_coefficients, + reflection_coefficients, + refraction_coefficients, + refractive_indices, +) + + +@pytest.mark.parametrize( + ("mat_name", "expected"), + [ + ("Vacuum", 1.0), + ("Glass", 2.511971), + ], +) +@jax.experimental.enable_x64() +def test_refractive_indices(mat_name: str, expected: float) -> None: + frequency = 1e9 # Hz + mat = materials[mat_name] + eta = mat.relative_permittivity(frequency) + got = refractive_indices(eta) + chex.assert_trees_all_close(got, expected) + + +def test_fresnel_coefficients(key: PRNGKeyArray) -> None: + key_n_1, key_n_2 = jax.random.split(key, 2) + + n_1 = jax.random.uniform(key_n_1, (100,), minval=0.01, maxval=2.0) + n_2 = jax.random.uniform(key_n_2, (100,), minval=0.01, maxval=2.0) + + n_r = n_2 / n_1 + theta_i = jnp.linspace(0, jnp.pi / 2) + cos_theta_i = jnp.cos(theta_i) + n_r = n_r[..., None] + cos_theta_i = cos_theta_i[None, ...] + + (r_s, r_p), (t_s, t_p) = fresnel_coefficients(n_r, cos_theta_i) + + theta_c = jnp.arcsin(jnp.minimum(n_r, 1.0)) + + for array in (r_s, r_p, t_s, t_p): + chex.assert_tree_all_finite(jnp.where(theta_i <= theta_c, array, 0.0)) + + chex.assert_trees_all_equal((r_s, r_p), reflection_coefficients(n_r, cos_theta_i)) + + chex.assert_trees_all_equal((t_s, t_p), refraction_coefficients(n_r, cos_theta_i)) + + chex.assert_trees_all_close(t_s, r_s + 1, atol=1e-6) + chex.assert_trees_all_close(n_r * t_p, r_p + 1, atol=1e-6) + + +def test_reflection_coefficients() -> None: + n_r = jnp.array(1.5) # Glass + + # 1. Normal incidence + cos_theta_i = jnp.array(1.0) + + got_r_s, got_r_p = reflection_coefficients(n_r, cos_theta_i) + + chex.assert_trees_all_equal(got_r_s, -got_r_p) + + # 2. 45-degree incidence + cos_theta_i = jnp.cos(jnp.pi / 2) + + got_r_s, got_r_p = reflection_coefficients(n_r, cos_theta_i) + chex.assert_trees_all_close(got_r_s**2, -got_r_p) + + # 3. Brewster's angle + theta_b = jnp.arctan(n_r) + cos_theta_i = jnp.cos(theta_b) + + _, got_r_p = reflection_coefficients(n_r, cos_theta_i) + + chex.assert_trees_all_equal(got_r_p, 0 + 0j) + + # 4. Total reflection + n_r = 1 / jnp.array(1.5) + theta_i = jnp.arcsin(n_r / 1.0) + + cos_theta_i = jnp.cos(theta_i) + + got_r_s, got_r_p = reflection_coefficients(n_r, cos_theta_i) + + chex.assert_trees_all_equal(got_r_s, 1 + 0j) + chex.assert_trees_all_equal(got_r_p, 1 + 0j) diff --git a/differt/tests/em/test_interaction_type.py b/differt/tests/em/test_interaction_type.py new file mode 100644 index 00000000..fbd95574 --- /dev/null +++ b/differt/tests/em/test_interaction_type.py @@ -0,0 +1,45 @@ +import chex +import jax +import jax.experimental +import jax.numpy as jnp +import pytest +from jaxtyping import DTypeLike + +from differt.em._interaction_type import InteractionType + + +class TestInteractionType: + @pytest.mark.parametrize("dtype", [jnp.int32, jnp.int64]) + def test_array(self, dtype: DTypeLike) -> None: + with jax.experimental.enable_x64(dtype == jnp.int64): + arr = jnp.array(list(InteractionType), dtype=dtype) + assert arr.dtype == dtype + + for i_type in InteractionType: + assert jnp.where(arr == i_type, 1, 0).sum() == 1 + + arr = jnp.array([0, 1, 2, *list(InteractionType)], dtype=dtype) + assert arr.dtype == dtype + + def test_values(self) -> None: + # This is important to avoid breaking changes + assert InteractionType.REFLECTION == 0 + assert InteractionType.DIFFRACTION == 1 + assert InteractionType.SCATTERING == 2 + + def test_where(self) -> None: + interaction_types = jnp.array([0, 1, 2, 0, 1, 2]) + x = jnp.array([1, 2, 3, 1, 2, 3]) + + chex.assert_trees_all_equal( + jnp.where(interaction_types == InteractionType.REFLECTION, x, 0), + jnp.array([1, 0, 0, 1, 0, 0]), + ) + chex.assert_trees_all_equal( + jnp.where(interaction_types == InteractionType.DIFFRACTION, x, 0), + jnp.array([0, 2, 0, 0, 2, 0]), + ) + chex.assert_trees_all_equal( + jnp.where(interaction_types == InteractionType.SCATTERING, x, 0), + jnp.array([0, 0, 3, 0, 0, 3]), + ) diff --git a/differt/tests/em/test_material.py b/differt/tests/em/test_material.py new file mode 100644 index 00000000..740b226e --- /dev/null +++ b/differt/tests/em/test_material.py @@ -0,0 +1,171 @@ +# ruff: noqa: FURB152 +from typing import ClassVar + +import chex +import jax +import jax.experimental +import jax.numpy as jnp +import pytest +from jaxtyping import PRNGKeyArray + +from differt.em._material import Material, materials + + +class TestITU: + materials: ClassVar[dict[str, Material]] = { + name: material for name, material in materials.items() if name.startswith("itu") + } + + def test_constructor(self) -> None: + with pytest.raises( + ValueError, + match="Only one frequency range can be used if 'None' is passed, as it will match any frequency", + ): + _ = Material.from_itu_properties( + "test", (0, 0, 0, 0, None), (0, 0, 0, 0, None) + ) + + def test_num_materials(self) -> None: + assert len(self.materials) == 15 + + def test_vacuum(self, key: PRNGKeyArray) -> None: + mat = self.materials["itu_vacuum"] + + rel_perm, cond = mat.properties(1e9) + + chex.assert_trees_all_equal_shapes_and_dtypes(jnp.array(1e9), rel_perm, cond) + chex.assert_trees_all_close(rel_perm, 1.0) + chex.assert_trees_all_close(cond, 0.0) + + f = jax.random.randint(key, (10000, 30), 0, 100e9).astype(float) + + rel_perm, cond = mat.relative_permittivity(f), mat.conductivity(f) + + chex.assert_trees_all_equal_shapes_and_dtypes(f, rel_perm, cond) + chex.assert_trees_all_close(rel_perm, 1.0) + chex.assert_trees_all_close(cond, 0.0) + + def test_concrete(self) -> None: + mat = self.materials["itu_concrete"] + + f = jnp.array([0.1e9, 1e9, 10e9, 100e9, 1000e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([-1.0, 5.24, 5.24, 5.24, -1.0]) + expected_cond = jnp.array([-1.0, 0.0462, 0.279796, 1.694501, -1.0]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_concrete_scalar(self) -> None: + mat = self.materials["itu_concrete"] + + for f, expected_rel_perm, expected_cond in zip( + [0.1e9, 1e9, 10e9, 100e9, 1000e9], + [-1.0, 5.24, 5.24, 5.24, -1.0], + [-1.0, 0.0462, 0.279796, 1.694501, -1.0], + strict=False, + ): + got_rel_perm, got_cond = mat.properties(f) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_glass(self) -> None: + mat = self.materials["itu_glass"] + + f = jnp.array([0.01e9, 0.1e9, 10e9, 100e9, 150e9, 220e9, 350e9, 450e9, 500e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([ + -1.0, + 6.31, + 6.31, + 6.31, + -1.0, + 5.79, + 5.79, + 5.79, + -1.0, + ]) + expected_cond = jnp.array([ + -1.0, + 1.647792e-04, + 7.865069e-02, + 1.718314, + -1.0, + 3.060531, + 6.608833, + 1.002504e01, + -1.0, + ]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_ceiling_board(self) -> None: + mat = self.materials["itu_ceiling_board"] + + f = jnp.array([0.1e9, 1e9, 10e9, 100e9, 150e9, 220e9, 350e9, 450e9, 500e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([ + -1.0, + 1.48, + 1.48, + 1.48, + -1.0, + 1.52, + 1.52, + 1.52, + -1.0, + ]) + expected_cond = jnp.array([ + -1.0, + 1.100000e-03, + 1.307353e-02, + 1.553792e-01, + -1.0, + 7.460210e-01, + 1.202940, + 1.557951, + -1.0, + ]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_plywood(self) -> None: + mat = self.materials["itu_plywood"] + + f = jnp.array([0.1e9, 1e9, 10e9, 40e9, 100e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([-1.0, 2.71, 2.71, 2.71, -1.0]) + expected_cond = jnp.array([-1.0, 0.33, 0.33, 0.33, -1.0]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_metal(self) -> None: + mat = self.materials["itu_metal"] + + f = jnp.array([0.1e9, 1e9, 10e9, 100e9, 1000e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([-1.0, 1.0, 1.0, 1.0, -1.0]) + expected_cond = jnp.array([-1.0, 1e7, 1e7, 1e7, -1.0]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) + + def test_wet_ground(self) -> None: + mat = self.materials["itu_wet_ground"] + + f = jnp.array([0.1e9, 1e9, 10e9, 100e9]) + + got_rel_perm, got_cond = mat.relative_permittivity(f), mat.conductivity(f) + + expected_rel_perm = jnp.array([-1.0, 30.0, 11.943215, -1.0]) + expected_cond = jnp.array([-1.0, 0.15, 2.992893, -1.0]) + chex.assert_trees_all_close(got_rel_perm, expected_rel_perm) + chex.assert_trees_all_close(got_cond, expected_cond) diff --git a/differt/tests/em/test_special.py b/differt/tests/em/test_special.py deleted file mode 100644 index b73782af..00000000 --- a/differt/tests/em/test_special.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Iterator -from contextlib import contextmanager - -import chex -import jax -import jax.numpy as jnp -import numpy as np -import pytest -import scipy.special as sp -from chex import Array - -from differt.em._special import erf, erfc, fresnel - - -@contextmanager -def enable_double_precision(enable: bool) -> Iterator[None]: - enabled = jax.config.jax_enable_x64 # type: ignore[attr-defined] - try: - jax.config.update("jax_enable_x64", enable) - yield - finally: - jax.config.update("jax_enable_x64", enabled) - - -@pytest.mark.parametrize( - "double_precision", - [False, True], -) -def test_erf(double_precision: bool) -> None: - with enable_double_precision(double_precision): - t = jnp.linspace(-6.0, 6.0, 101) - a, b = jnp.meshgrid(t, t) - z = a + 1j * b - z = z.astype(dtype=jnp.complex128 if double_precision else jnp.complex64) - got = erf(z) - expected = jnp.asarray(sp.erf(np.asarray(z))) - chex.assert_trees_all_close( - got, - expected, - rtol=1e-12 if double_precision else 1e-4, - ) - - -@pytest.mark.parametrize( - "z", - [ - jnp.linspace(-5.0, 5.0, 101), - 1j * jnp.linspace(-5.0, 5.0, 101), - ], -) -def test_erfc(z: Array) -> None: - got = erfc(z) - expected = jnp.asarray(erfc(np.asarray(z))) - chex.assert_trees_all_close(got, expected) - - -@pytest.mark.parametrize( - "z", - [ - jnp.linspace(-5.0, 5.0, 101), - 1j * jnp.linspace(-5.0, 5.0, 101), - ], -) -def test_fresnel(z: Array) -> None: - got_s, got_c = fresnel(z) - expected = fresnel(np.asarray(z)) - expected_s = jnp.asarray(expected[0]) - expected_c = jnp.asarray(expected[1]) - chex.assert_trees_all_close(got_s, expected_s) - chex.assert_trees_all_close(got_c, expected_c) diff --git a/differt/tests/em/test_utd.py b/differt/tests/em/test_utd.py index 7cb0b3d5..d9d57b82 100644 --- a/differt/tests/em/test_utd.py +++ b/differt/tests/em/test_utd.py @@ -1,10 +1,71 @@ # pyright: reportMissingTypeArgument=false import chex +import jax import jax.numpy as jnp import numpy as np +import pytest import scipy.special as sp +from jaxtyping import PRNGKeyArray -from differt.em._utd import F +from differt.em._utd import F, L_i, diffraction_coefficients + + +def test_L_i(key: PRNGKeyArray) -> None: # noqa: N802 + key_s_d, key_sin, key_1_i, key_2_i, key_e_i, key_s_i = jax.random.split(key, 6) + + s_d = jax.random.uniform(key_s_d, (100,), minval=10.0, maxval=100.0) + sin_2_beta_0 = jax.random.uniform(key_sin, (100,), minval=0.0, maxval=1.0) + rho_1_i = jax.random.uniform(key_1_i, (100,), minval=10.0, maxval=100.0) + rho_2_i = jax.random.uniform(key_2_i, (100,), minval=10.0, maxval=100.0) + rho_e_i = jax.random.uniform(key_e_i, (100,), minval=10.0, maxval=100.0) + s_i = jax.random.uniform(key_s_i, (100,), minval=10.0, maxval=100.0) + + got = L_i(s_d, sin_2_beta_0) + expected = s_d * sin_2_beta_0 + + chex.assert_trees_all_close(got, expected) + + got = L_i(s_d, sin_2_beta_0, s_i=s_i) + expected = L_i(s_d, sin_2_beta_0, rho_1_i=s_i, rho_2_i=s_i, rho_e_i=s_i) + + chex.assert_trees_all_close(got, expected) + + got = L_i(s_d, sin_2_beta_0, rho_1_i=rho_1_i, rho_2_i=rho_2_i, rho_e_i=rho_e_i) + expected = ( + s_d + * (rho_e_i + s_d) + * rho_1_i + * rho_2_i + * sin_2_beta_0 + / (rho_e_i * (rho_1_i + s_d) * (rho_2_i + s_d)) + ) + + chex.assert_trees_all_close(got, expected) + + with pytest.raises( + ValueError, + match="If 's_i' is provided, then 'rho_1_i', 'rho_2_i', and 'rho_e_i' must be left to 'None'", + ): + _ = L_i( # type: ignore[reportCallIsuee] + s_d, + sin_2_beta_0, + rho_1_i=rho_1_i, + rho_2_i=rho_2_i, + rho_e_i=rho_e_i, + s_i=s_i, # type: ignore[reportArgumentType] + ) + + with pytest.raises( + ValueError, + match="If 's_i' is provided, then 'rho_1_i', 'rho_2_i', and 'rho_e_i' must be left to 'None'", + ): + _ = L_i(s_d, sin_2_beta_0, rho_1_i=rho_1_i, s_i=s_i) # type: ignore[reportCallIssue] + + with pytest.raises( + ValueError, + match="All three of 'rho_1_i', 'rho_2_i', and 'rho_e_i' must be provided, or left to 'None'", + ): + _ = L_i(s_d, sin_2_beta_0, rho_1_i=rho_1_i) # type: ignore[reportCallIssue] def scipy_F(x: np.ndarray) -> np.ndarray: # noqa: N802 @@ -17,8 +78,29 @@ def scipy_F(x: np.ndarray) -> np.ndarray: # noqa: N802 def test_F() -> None: # noqa: N802 - x = jnp.logspace(-3, 1, 100) + # Test case 1: 0.001 to 10.0 + x = jnp.logspace(-3, 1, 1000) got = F(x) expected = jnp.asarray(scipy_F(np.asarray(x))) chex.assert_trees_all_close(got, expected, rtol=1e-5) + + # Test case 2: F(x), x -> 0 + info = jnp.finfo(float) + got = F(info.eps) + mag = jnp.abs(got) + angle = jnp.angle(got, deg=True) + + chex.assert_trees_all_close(mag, 0.0, atol=1e-7) + chex.assert_trees_all_close(angle, 45) + + # Test case 3: F(x), x -> +oo + got = F(1e6) + mag = jnp.abs(got) + + chex.assert_trees_all_close(mag, 1.0, atol=1e-4) + + +def test_diffraction_coefficients() -> None: + with pytest.raises(NotImplementedError): + _ = diffraction_coefficients() diff --git a/differt/tests/em/test_utils.py b/differt/tests/em/test_utils.py index cf52740c..2475ede3 100644 --- a/differt/tests/em/test_utils.py +++ b/differt/tests/em/test_utils.py @@ -1,3 +1,4 @@ +# ruff: noqa: N806 from contextlib import AbstractContextManager from contextlib import nullcontext as does_not_raise @@ -7,7 +8,13 @@ from jaxtyping import Array from differt.em._constants import c -from differt.em._utils import lengths_to_delays, path_delays +from differt.em._utils import ( + lengths_to_delays, + path_delays, + sp_directions, + sp_rotation_matrix, +) +from differt.geometry import rotation_matrix_along_z_axis from ..utils import random_inputs @@ -61,3 +68,67 @@ def test_path_delays_random_inputs( ) chex.assert_trees_all_close(got, expected) + + +def test_sp_directions() -> None: + cos = jnp.cos(jnp.pi / 6) + sin = jnp.sin(jnp.pi / 6) + k_i = jnp.array([[cos, -sin, 0.0], [0.0, -1.0, 0.0]]) + k_r = jnp.array([[cos, +sin, 0.0], [0.0, +1.0, 0.0]]) + normals = jnp.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) + got = sp_directions(k_i, k_r, normals) + + chex.assert_trees_all_close( + got[0][0], got[1][0], custom_message="s-components should be equal" + ) + + for comp, k in zip(got, (k_i, k_r), strict=True): + s = comp[0] + p = comp[1] + + chex.assert_trees_all_close(jnp.cross(p, s), k) + chex.assert_trees_all_close(jnp.cross(k, p), s) + chex.assert_trees_all_close(jnp.cross(s, k), p) + + expected_e_i_s = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + expected_e_i_p = jnp.array([[+sin, cos, 0.0], [0.0, 0.0, -1.0]]) + expected_e_r_p = jnp.array([[-sin, cos, 0.0], [0.0, 0.0, 1.0]]) + + chex.assert_trees_all_close(got[0][0], expected_e_i_s) + chex.assert_trees_all_close(got[0][1], expected_e_i_p) + chex.assert_trees_all_close(got[1][1], expected_e_r_p) + + +def test_sp_rotation_matrix() -> None: + e_i_s = jnp.array([1.0, 0.0, 0.0]) + e_i_p = jnp.array([0.0, 1.0, 0.0]) + + e_r_s = jnp.array([+0.0, 1.0, 0.0]) + e_r_p = jnp.array([-1.0, 0.0, 0.0]) + + got_R = sp_rotation_matrix(e_i_s, e_i_p, e_r_s, e_r_p) + expected_R = rotation_matrix_along_z_axis(-jnp.pi / 2) + + chex.assert_trees_all_close(got_R, expected_R[:-1, :-1], atol=1e-7) + chex.assert_trees_all_close(got_R @ got_R.mT, jnp.eye(2)) + + e_r_s = jnp.array([+1.0, 1.0, 0.0]) * jnp.sqrt(2) / 2 + e_r_p = jnp.array([-1.0, 1.0, 0.0]) * jnp.sqrt(2) / 2 + + got_R = sp_rotation_matrix(e_i_s, e_i_p, e_r_s, e_r_p) + expected_R = rotation_matrix_along_z_axis(-jnp.pi / 4) + + chex.assert_trees_all_close(got_R, expected_R[:-1, :-1]) + chex.assert_trees_all_close(got_R @ got_R.mT, jnp.eye(2), atol=1e-7) + + # We test normal incidence + e_r_s = +e_i_s + e_r_p = -e_i_p + + got_R = sp_rotation_matrix(e_i_s, e_i_p, e_r_s, e_r_p) + expected_R = rotation_matrix_along_z_axis(0.0).at[1, 1].set(-1.0) + + # Improper rotation matrix, determinant should be -1.0 + chex.assert_trees_all_close(jnp.linalg.det(got_R), -1.0) + chex.assert_trees_all_close(got_R, expected_R[:-1, :-1]) + chex.assert_trees_all_close(got_R @ got_R.mT, jnp.eye(2)) diff --git a/differt/tests/geometry/test_paths.py b/differt/tests/geometry/test_paths.py index 0c5459d3..2029b3e1 100644 --- a/differt/tests/geometry/test_paths.py +++ b/differt/tests/geometry/test_paths.py @@ -1,4 +1,6 @@ import math +from contextlib import AbstractContextManager +from contextlib import nullcontext as does_not_raise import chex import equinox as eqx @@ -43,6 +45,51 @@ def random_paths( class TestPaths: + @pytest.mark.parametrize("with_mask", [False, True]) + @pytest.mark.parametrize( + ("batch", "axis", "expectation"), + [ + ((), None, does_not_raise()), + ( + (), + -1, + pytest.raises( + ValueError, match="Cannot squeeze a 0-dimensional batch!" + ), + ), + ((1,), -1, does_not_raise()), + ((1,), 0, does_not_raise()), + ((10, 1), -1, does_not_raise()), + ((10, 1), 1, does_not_raise()), + ((1, 1), (0, 1), does_not_raise()), + ( + (1,), + 1, + pytest.raises( + ValueError, match="One of the provided axes is out-of-bounds!" + ), + ), + ], + ) + def test_squeeze( + self, + with_mask: bool, + batch: tuple[int, ...], + axis: int | tuple[int, ...] | None, + expectation: AbstractContextManager[Exception], + key: PRNGKeyArray, + ) -> None: + path_length = 10 + num_objects = 30 + with expectation: + _ = random_paths( + path_length, + *batch, + num_objects=num_objects, + with_mask=with_mask, + key=key, + ).squeeze(axis=axis) + @pytest.mark.parametrize("path_length", [3, 5]) @pytest.mark.parametrize("batch", [(), (1,), (1, 2, 3, 4)]) @pytest.mark.parametrize("num_objects", [1, 10]) diff --git a/differt/tests/geometry/test_triangle_mesh.py b/differt/tests/geometry/test_triangle_mesh.py index 1df5b33f..a4d21b07 100644 --- a/differt/tests/geometry/test_triangle_mesh.py +++ b/differt/tests/geometry/test_triangle_mesh.py @@ -346,6 +346,17 @@ def test_set_face_colors( mesh = two_buildings_mesh.set_face_colors(colors) assert mesh.face_colors is not None + def test_set_face_colors_wrong_args( + self, + two_buildings_mesh: TriangleMesh, + key: PRNGKeyArray, + ) -> None: + colors = jax.random.uniform(key, (two_buildings_mesh.num_triangles, 3)) + with pytest.raises( + ValueError, match="You must specify one of 'colors' or `key`, not both" + ): + _ = two_buildings_mesh.set_face_colors(colors, key=key) # type: ignore[reportCallIssue] + def test_load_obj(self, two_buildings_obj_file: str) -> None: mesh = TriangleMesh.load_obj(two_buildings_obj_file) assert mesh.triangles.shape == (24, 3) diff --git a/differt/tests/geometry/test_utils.py b/differt/tests/geometry/test_utils.py index a35f60cc..38b37f04 100644 --- a/differt/tests/geometry/test_utils.py +++ b/differt/tests/geometry/test_utils.py @@ -1,6 +1,7 @@ from collections.abc import Callable from contextlib import AbstractContextManager from contextlib import nullcontext as does_not_raise +from itertools import product import chex import jax @@ -17,6 +18,7 @@ orthogonal_basis, pairwise_cross, path_lengths, + perpendicular_vectors, rotation_matrix_along_axis, rotation_matrix_along_x_axis, rotation_matrix_along_y_axis, @@ -82,6 +84,17 @@ def test_normalize_random_inputs( chex.assert_trees_all_close(u, nu * lu[..., None]) +def test_perpendicular_vectors() -> None: + all_vectors = list(product([0.0, 1.0, -1.0], repeat=3)) + # Drop [0, 0, 0] case + all_vectors = all_vectors[1:] + u = jnp.array(all_vectors) + v = perpendicular_vectors(u) + + chex.assert_trees_all_close(jnp.linalg.norm(v, axis=-1), 1.0) + chex.assert_trees_all_close(jnp.sum(u * v, axis=-1), 0.0) + + @pytest.mark.parametrize( "u", [ diff --git a/differt/tests/rt/test_utils.py b/differt/tests/rt/test_utils.py index 851d9fe8..3777b586 100644 --- a/differt/tests/rt/test_utils.py +++ b/differt/tests/rt/test_utils.py @@ -3,7 +3,6 @@ from contextlib import nullcontext as does_not_raise import chex -import jax import jax.numpy as jnp import pytest from jaxtyping import Array @@ -249,7 +248,7 @@ def test_rays_intersect_any_triangle( ) -> None: if hit_tol is None: dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) - hit_tol = jnp.finfo(dtype).eps # type: ignore[reportAssigmentType] + hit_tol = jnp.finfo(dtype).eps hit_threshold = 1.0 - hit_tol # type: ignore[reportOperatorIssue] with expectation: @@ -332,7 +331,6 @@ def test_triangles_visible_from_vertices( ) @pytest.mark.parametrize("epsilon", [None, 1e-2]) @random_inputs("ray_origins", "ray_directions", "triangle_vertices") -@jax.debug_nans(False) # noqa: FBT003 def test_first_triangles_hit_by_rays( ray_origins: Array, ray_directions: Array, diff --git a/differt/tests/scene/test_triangle_scene.py b/differt/tests/scene/test_triangle_scene.py index e684249d..0690150d 100644 --- a/differt/tests/scene/test_triangle_scene.py +++ b/differt/tests/scene/test_triangle_scene.py @@ -162,8 +162,7 @@ def test_compute_paths_on_advanced_path_tracing_example( if assume_quads: expected_objects -= expected_objects % 2 - with jax.debug_nans(False): # noqa: FBT003 - got = scene.compute_paths(order, method=method, max_dist=1e-1) + got = scene.compute_paths(order, method=method, max_dist=1e-1) if method == "sbr": masked_vertices = got.masked_vertices @@ -257,8 +256,7 @@ def test_compute_paths_on_simple_street_canyon( if assume_quads: expected_objects -= expected_objects % 2 - with jax.debug_nans(False): # noqa: FBT003 - got = scene.compute_paths(order) + got = scene.compute_paths(order) chex.assert_trees_all_close( got.masked_vertices, expected_path_vertices, atol=1e-5 @@ -355,14 +353,13 @@ def test_compute_paths_on_empty_scene( ) with expectation: - with jax.debug_nans(False): # noqa: FBT003 - got = scene.compute_paths( # type: ignore[reportCallIssue] - order=order, - chunk_size=chunk_size, # type: ignore[reportArgumentType] - path_candidates=path_candidates, - parallel=parallel, - method=method, # type: ignore[reportArgumentType] - ) + got = scene.compute_paths( # type: ignore[reportCallIssue] + order=order, + chunk_size=chunk_size, # type: ignore[reportArgumentType] + path_candidates=path_candidates, + parallel=parallel, + method=method, # type: ignore[reportArgumentType] + ) paths = next(got) if isinstance(got, Iterator) else got @@ -382,8 +379,7 @@ def test_compute_paths_on_grid( scene = scene.with_transmitters_grid(m_tx, n_tx) scene = scene.with_receivers_grid(m_rx, n_rx) - with jax.debug_nans(False): # noqa: FBT003 - paths = scene.compute_paths(order=1) + paths = scene.compute_paths(order=1) if n_tx is None: n_tx = m_tx @@ -450,10 +446,9 @@ def test_compute_paths_parallel( num_rays = m_rx * n_rx with expectation: - with jax.debug_nans(False): # noqa: FBT003 - paths = scene.compute_paths( - order=1, method=method, num_rays=num_rays, parallel=True - ) + paths = scene.compute_paths( + order=1, method=method, num_rays=num_rays, parallel=True + ) # TODO: fix this when 'hybrid' is implemented num_path_candidates = ( diff --git a/differt/tests/test_integration.py b/differt/tests/test_integration.py new file mode 100644 index 00000000..1a2c84d4 --- /dev/null +++ b/differt/tests/test_integration.py @@ -0,0 +1,211 @@ +import chex +import equinox as eqx +import jax.numpy as jnp +import numpy as np +import pytest + +from differt.em import materials +from differt.geometry import ( + TriangleMesh, + assemble_paths, + fibonacci_lattice, +) +from differt.rt import ( + first_triangles_hit_by_rays, + rays_intersect_any_triangle, + rays_intersect_triangles, +) +from differt.scene import TriangleScene + + +@pytest.mark.slow +def test_ray_casting() -> None: + o3d = pytest.importorskip("open3d") + + knot_mesh = o3d.data.KnotMesh() + o3d_mesh = o3d.io.read_triangle_mesh(knot_mesh.path).translate([50, 20, 10]) + + o3d_mesh = o3d.t.geometry.TriangleMesh.from_legacy(o3d_mesh) + o3d_mesh = o3d_mesh.compute_vertex_normals() # This avoids a warning from Open3D + o3d_mesh = o3d_mesh.compute_triangle_normals() + + mesh = TriangleMesh( + vertices=o3d_mesh.vertex.positions.numpy(), + triangles=o3d_mesh.triangle.indices.numpy(), + ) + + chex.assert_trees_all_close( + mesh.bounding_box, + jnp.stack( + [ + o3d_mesh.get_min_bound().numpy(), + o3d_mesh.get_max_bound().numpy(), + ], + axis=0, + ), + ) + + chex.assert_trees_all_close( + mesh.normals, o3d_mesh.triangle.normals.numpy(), atol=1e-6 + ) + + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(o3d_mesh) + + ray_directions = fibonacci_lattice(1_000) + ray_directions = fibonacci_lattice(50) + ray_origins = jnp.zeros_like(ray_directions) + + o3d_rays = o3d.core.Tensor( + np.concatenate((ray_origins, ray_directions), axis=-1), + dtype=o3d.core.Dtype.Float32, + ) + + triangle_vertices = mesh.triangle_vertices + + triangles, t_hit = first_triangles_hit_by_rays( + ray_origins, ray_directions, triangle_vertices + ) + hit = triangles != -1 + triangles = triangles.astype(jnp.uint32) + + ans = scene.cast_rays(o3d_rays, nthreads=1) # codespell:ignore ans + + chex.assert_trees_all_close( + t_hit, + ans["t_hit"].numpy(), # codespell:ignore ans + atol=1e-4, + ) + chex.assert_trees_all_equal( + jnp.where(hit, triangles, jnp.asarray(scene.INVALID_ID, dtype=jnp.uint32)), + ans["primitive_ids"].numpy(), # codespell:ignore ans + ) + + got_counts = rays_intersect_triangles( + ray_origins[..., None, :], ray_directions[..., None, :], triangle_vertices + )[1].sum(axis=-1) + + expected_counts = scene.count_intersections(o3d_rays, nthreads=1).numpy() + + chex.assert_trees_all_equal( + got_counts, + expected_counts, + ) + + scale = 100.0 + + got_hit = rays_intersect_any_triangle( + ray_origins, + scale * ray_directions, + triangle_vertices, + ) + + expected_hit = scene.test_occlusions(o3d_rays, tfar=scale, nthreads=1).numpy() + + chex.assert_trees_all_equal( + got_hit, + expected_hit, + ) + + +@pytest.mark.slow +def test_simple_street_canyon() -> None: + sionna = pytest.importorskip("sionna") + file = sionna.rt.scene.simple_street_canyon + + sionna_scene = sionna.rt.load_scene(file) + differt_scene = TriangleScene.load_xml(file) + + sionna_scene.tx_array = sionna.rt.PlanarArray( + num_rows=1, + num_cols=1, + vertical_spacing=0.5, + horizontal_spacing=0.5, + pattern="tr38901", + polarization="V", + ) + + sionna_scene.rx_array = sionna.rt.PlanarArray( + num_rows=1, + num_cols=1, + vertical_spacing=0.5, + horizontal_spacing=0.5, + pattern="dipole", + polarization="cross", + ) + + tx = sionna.rt.Transmitter(name="tx", position=[-33.0, 0.0, 32.0]) + + sionna_scene.add(tx) + + rx = sionna.rt.Receiver(name="rx", position=[20.0, 0.0, 2.0], orientation=[0, 0, 0]) + + sionna_scene.add(rx) + + tx.look_at(rx) + + differt_scene = eqx.tree_at( + lambda s: s.transmitters, + differt_scene, + replace=jnp.asarray(tx.position.numpy()), + ) + + differt_scene = eqx.tree_at( + lambda s: s.receivers, + differt_scene, + replace=jnp.asarray(rx.position.numpy()), + ) + + max_order = 4 + + sionna_paths = sionna_scene.compute_paths(max_depth=max_order, method="exhaustive") + sionna_path_objects = sionna_paths.objects.numpy() + sionna_path_vertices = sionna_paths.vertices.numpy() + + max_depth = sionna_path_objects.shape[0] # May differ from 'max_order' + + for order in range(max_depth + 1): + paths = differt_scene.compute_paths(order=order) + select = (sionna_path_objects == -1).sum(axis=0) == (max_depth - order) + vertices = sionna_path_vertices[:order, select, :] + vertices = np.moveaxis(vertices, 0, -2) + vertices = assemble_paths( + differt_scene.transmitters.reshape(1, 3), + jnp.asarray(vertices), + differt_scene.receivers.reshape(1, 3), + ) + chex.assert_trees_all_close( + paths.masked_vertices, + vertices, + atol=1e-5, + custom_message=f"Mismatch for paths {order = }.", + ) + + +def test_itu_materials() -> None: + sionna = pytest.importorskip("sionna") + sionna_scene = sionna.rt.scene.Scene("__empty__") + + for mat_name, differt_mat in materials.items(): + if not mat_name.startswith("itu_"): + continue + + if mat_name == "itu_vacuum": + sionna_mat = sionna_scene.get("vacuum") + else: + sionna_mat = sionna_scene.get(mat_name) + + for f in np.logspace(9 - 2, 9 + 3, 21): + sionna_scene.frequency = f + + chex.assert_trees_all_close( + differt_mat.relative_permittivity(f), + sionna_mat.relative_permittivity, + custom_message=f"Mismatch for {mat_name = } @ {f / 1e9} GHz.", + ) + + chex.assert_trees_all_close( + differt_mat.conductivity(f), + sionna_mat.conductivity, + custom_message=f"Mismatch for {mat_name = } @ {f / 1e9} GHz.", + ) diff --git a/differt/tests/test_utils.py b/differt/tests/test_utils.py index 4d71b017..3f56bb91 100644 --- a/differt/tests/test_utils.py +++ b/differt/tests/test_utils.py @@ -1,9 +1,16 @@ import chex +import jax import jax.numpy as jnp import pytest from jaxtyping import Array, PRNGKeyArray -from differt.utils import dot, minimize, sample_points_in_bounding_box, sorted_array2 +from differt.utils import ( + dot, + minimize, + safe_divide, + sample_points_in_bounding_box, + sorted_array2, +) from .utils import random_inputs @@ -135,3 +142,20 @@ def assert_in_bounds(a: Array, bounds: Array) -> None: assert_in_bounds(got, bounding_box) assert got.shape == (4, 5, 3) + + +def test_safe_divide(key: PRNGKeyArray) -> None: + key_x, key_y = jax.random.split(key, 2) + x = jax.random.uniform(key_x, (30, 20)) + y = jax.random.randint(key_y, (30, 20), minval=0, maxval=3) + + assert y.sum() > 0, "We need at least one division by zero" + + got = safe_divide(x, y) + + assert not jnp.all(jnp.isnan(got)), "We don't want any NaN" + + expected = jnp.where(y != 0, x / y, 0) + + chex.assert_trees_all_equal_shapes_and_dtypes(got, expected) + chex.assert_trees_all_equal(got, expected) diff --git a/docs/source/_templates/autosummary/base.rst b/docs/source/_templates/autosummary/base.rst new file mode 100644 index 00000000..a4ecd522 --- /dev/null +++ b/docs/source/_templates/autosummary/base.rst @@ -0,0 +1,9 @@ +{% if module.rsplit('.', 1)[1].startswith('_') -%} +{{ (module.rsplit('.', 1)[0] + '.' + objname) | escape | underline}} +{%- else -%} +{{ fullname | escape | underline}} +{%- endif %} + +.. currentmodule:: {{ module }} + +.. auto{{ objtype }}:: {{ objname }} diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst index 3027baff..b2c8634c 100644 --- a/docs/source/_templates/autosummary/class.rst +++ b/docs/source/_templates/autosummary/class.rst @@ -1,9 +1,17 @@ +{% if module.rsplit('.', 1)[1].startswith('_') -%} +{{ (module.rsplit('.', 1)[0] + '.' + objname) | escape | underline}} +{%- else -%} {{ fullname | escape | underline}} +{%- endif %} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :members: + :show-inheritance: + {% if objname == 'InteractionType' -%} + :member-order: bysource + {%- else -%} :inherited-members: {% block attributes %} @@ -22,10 +30,11 @@ .. rubric:: {{ _('Methods') }} .. autosummary:: - {% for item in methods if item != '__init__' %} + {% for item in methods if item != '__init__' %} ~{{ name }}.{{ item }} {%- endfor %} {% endif %} {% endblock %} .. rubric:: Detailed documentation + {%- endif %} diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 890447b4..ae46d233 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -1,6 +1,14 @@ API Reference ============= +DiffeRT comes as two-package project: :mod:`differt`, the main Python module, that +contains most of the features, and :mod:`differt_core`, a lower-level Python module written +in Rust, for performance reason. The second module (:mod:`differt_core`) is a direct dependency +of the former (:mod:`differt`). However, you can also decide to install :mod:`differt_core` directly, +if you only needs its features. + +You can find the documentation for both packages by clicking on the links below. + .. toctree:: :maxdepth: 1 diff --git a/docs/source/conf.py b/docs/source/conf.py index 47f1ea09..9e626cf6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any +import jaxtyping from docutils import nodes from sphinx.addnodes import pending_xref from sphinx.application import Sphinx @@ -77,6 +78,13 @@ linkcheck_report_timeouts_as_broken = False # Default value in Sphinx >= 8 +# -- MathJax settings + +mathjax3_config = { + "loader": {"load": ["[tex]/boldsymbol"]}, + "tex": {"packages": {"[+]": ["boldsymbol"]}}, +} + numfig = True # -- Intersphinx mapping @@ -299,6 +307,13 @@ def fix_reference( def setup(app: Sphinx) -> None: + # Patch to avoid expanding the ArrayLike union type, which takes a lot + # of space and is less readable. + class ArrayLike(jaxtyping.Array): + pass + + jaxtyping.ArrayLike = ArrayLike + download_sionna_scenes() # Put this here so that download does not occur during notebooks execution app.connect("autodoc-before-process-signature", fix_sionna_folder) diff --git a/docs/source/index.rst b/docs/source/index.rst index 8c875879..b2431c8b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,6 +20,7 @@ Contents installation conventions + motivations .. toctree:: :caption: Basic Tutorials diff --git a/docs/source/motivations.md b/docs/source/motivations.md new file mode 100644 index 00000000..5749d6b2 --- /dev/null +++ b/docs/source/motivations.md @@ -0,0 +1,44 @@ +# Why use DiffeRT? + +Why should you use DiffeRT? For what purpose? + +Those are two good questions, and we will try to motivate in this document the reasons to use DiffeRT. + +## What is DiffeRT + +DiffeRT is a Python, array-oriented, Differentiable Ray Tracing (RT) library that aims to provide fast and easy-to-use tools to model the propagation of radio waves. The long-term objective of DiffeRT is to provide: + +- fast methods to load large scenes from various formats; +- a large set of performant RT utilities (e.g., ray launching, {func}`image_method`); +- the ability to easily compute electromagnetic fields and relevant metrics (e.g., power delay profile, angular spread); +- and the ability to differentiate any of the previous parts with respect to arbitrary input parameters for optimization or Machine Learning applications. + +## History + +The development of DiffeRT began around 2021 as a collection of unorganized code projects during a PhD program. Later, a 2D version of DiffeRT, [DiffeRT2d](https://github.com/jeertmans/DiffeRT2d), was created and published in an open-access journal in 2024 {cite}`differt2d`. + +While 2D RT is excellent for developing toy examples---especially when leveraging object-oriented programming---it often scales poorly to large scenes, limiting DiffeRT2d's use to fundamental research rather than realistic radio propagation scenarios. + +DiffeRT builds on some of the principles behind DiffeRT2d while prioritizing performance and scalability for any scene size. Most utilities provided by DiffeRT work directly on arrays to avoid unnecessary abstractions associated with object-oriented programming[^1]. + +[^1]: DiffeRT stills uses object-oriented programming in some places, but those classes are immutable dataclasses, and JAX-compatible PyTree, which makes them compatible with many of the JAX features. + +## DiffeRT vs. Sionna + +In terms of features, DiffeRT does not aim to match the extensive functionality of Sionna. Instead, DiffeRT focuses on RT-specific applications similar to what `sionna.rt` offers, but with four main differences: + +1. **Public lower-level RT Routines[^2]:** Many internal RT mechanisms in Sionna are hidden or undocumented, making it challenging to modify the pipeline. DiffeRT, on the other hand, ensures that most RT utilities are public and well-documented, enabling users to customize or replace parts of the RT algorithms without re-implementing or copy-pasting code. +2. **JAX Integration:** Unlike Sionna, which uses TensorFlow, DiffeRT leverages JAX for efficient array-based programming. JAX offers powerful features like automatic differentiation, just-in-time (JIT) compilation, and compatibility with GPU/TPU acceleration, making it highly suitable for optimization and Machine Learning tasks. +3. **Minimal Abstraction with Immutable Dataclasses:** Sionna internally represents scenes using Mitsuba, which, while powerful, imposes restrictions on the types of scenes it can handle. Moreover, Sionna's classes are relatively complex, with many hidden attributes. In contrast, DiffeRT uses immutable dataclasses that can be created using simple constructor or convenient class methods (e.g., for reading scenes from files). Following JAX principles, all classes are immutable PyTrees, ensuring compatibility with JAX while avoiding unnecessary memory allocations through JIT optimization. +4. **Lightweight and Broadcastable Design:** DiffeRT's design philosophy prioritizes transparency and usability for RT applications, avoiding the heavier abstractions often seen in other libraries. Classes aim to store as few attributes as possibles, and most utilities accept input arrays with arbitrary sized inputs, which makes it very easy, e.g., to compute the same operation for one receiving (RX) antenna, or on a two-dimensional grid of RXs. + +[^2]: There are some exceptions, like the internal machinery behind + {meth}`TriangleScene.compute_paths`, + but we then provide detailed tutorials to help the user understand and build their version of the function, + if they which to do so, e.g., with {ref}`advanced_path_tracing`. + +We acknowledge the work of Sionna, and would recommend users to try both tools, and use the one that best fits their needs! If you want to reuse scenes files from Sionna, check out the {meth}`TriangleScene.load_xml` method, as it supports reading the same file format as the one use by Sionna, i.e., the XML file format used by Mitsuba. + +## What's Next? + +If you have any question, remark, or recommendation regarding DiffeRT, or its comparison with Sionna, please feel free to reach out on [GitHub discussion](https://github.com/jeertmans/DiffeRT/discussions)! diff --git a/docs/source/notebooks/multipath.ipynb b/docs/source/notebooks/multipath.ipynb index cc55b9e1..78e7b305 100644 --- a/docs/source/notebooks/multipath.ipynb +++ b/docs/source/notebooks/multipath.ipynb @@ -85,18 +85,26 @@ "from plotly.colors import convert_to_RGB_255\n", "from plotly.subplots import make_subplots\n", "\n", + "from differt.em import (\n", + " Dipole,\n", + " materials,\n", + " pointing_vector,\n", + " reflection_coefficients,\n", + " sp_directions,\n", + ")\n", "from differt.geometry import (\n", " TriangleMesh,\n", " merge_cell_ids,\n", " min_distance_between_cells,\n", - " path_lengths,\n", + " normalize,\n", ")\n", "from differt.plotting import draw_image, draw_markers, reuse, set_defaults\n", "from differt.scene import (\n", " TriangleScene,\n", " download_sionna_scenes,\n", " get_sionna_scene,\n", - ")" + ")\n", + "from differt.utils import dot" ] }, { @@ -846,10 +854,10 @@ 32.356605529785156, 32.356605529785156, 32.356605529785156, - 63.478111267089844, - 63.478111267089844, - 63.478111267089844, - 63.478111267089844, + 63.47811126708984, + 63.47811126708984, + 63.47811126708984, + 63.47811126708984, 32.356605529785156, 32.356605529785156, 32.356605529785156, @@ -906,31 +914,31 @@ -8.613334655761719, 37.45787048339844, 37.45787048339844, - 9.571563720703125, - 9.571563720703125, - 9.571563720703125, - 9.571563720703125, + 9.571563720703123, + 9.571563720703123, + 9.571563720703123, + 9.571563720703123, 37.45787048339844, 37.45787048339844, 37.45787048339844, 37.45787048339844, - 9.571563720703125, + 9.571563720703123, 37.45787048339844, - 9.571563720703125, + 9.571563720703123, 37.45787048339844, 37.45787048339844, 37.45787048339844, - 9.571563720703125, - 9.571563720703125, - 9.571563720703125, - 9.571563720703125, + 9.571563720703123, + 9.571563720703123, + 9.571563720703123, + 9.571563720703123, 37.45787048339844, 37.45787048339844, 37.45787048339844, 37.45787048339844, - 9.571563720703125, + 9.571563720703123, 37.45787048339844, - 9.571563720703125, + 9.571563720703123, 37.45787048339844, 38.223602294921875, 38.223602294921875, @@ -1871,10 +1879,11 @@ } } }, + "image/png": "iVBORw0KGgoAAAANSUhEUgAABAIAAAFoCAYAAADNd1kuAAAAAXNSR0IArs4c6QAAIABJREFUeF7svQm4ZGV57/vWXHvsiW4GBRS7pRGZm6GlTaNoQ0QBDxgTEw0eOfeeqKB5zLlg0Dz3Gr3i1VwVh+Re8aSjJwJHOEDOFRsQ1EPrgVZpBrEbUZRB6HnY89413ef/Va3aq+ZVtav2VL+V7HT33qu+9X2/b23yvP/vff9vKJfL5YwLAhCAAAQgAAEIQAACEIAABCAAga4gEEII6Ip9ZpEQgAAEIAABCEAAAhCAAAQgAAFHACGAFwECEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCEAAAhCAAAQgAAEIQAAhgHcAAhCAAAQgAAEIQAACEIAABCDQRQQQArpos1kqBCAAAQhAAAIQgAAEIAABCEAAIYB3AAIQgAAEIAABCEAAAhCAAAQg0EUEEAK6aLNZKgQgAAEIQAACEIAABCAAAQhAACGAdwACEIAABCAAAQhAAAIQgAAEINBFBBACumizWSoEIAABCECgnQQmUxmLRcIWDofaOSxjQQACEIAABCDQYQIIAR0GzPAQgAAEIACBxUQgncnayHjaJlNZy+ZylsuZrRhMWC6Xs3gUUWAx7TVrgQAEIACBxUsAIWDx7i0rgwAEIAABCLSFgIL/qXTORsdTlsrkrCcRsd5ExGUDZHNmSgg4PJay8cmMxWNh64lH3BeZAm3BzyAQgAAEIACBthNACGg7UgaEAAQgAAEILA4CE1MZ09eYAvxo2AX/8VjEplL5702lsxYJh2zlkoRlsjkLhUI2Npm2qZSEgyyiwOJ4DVgFBCAAAQgsQgIIAYtwU1kSBCAAAQhAoFUC/tR/jaET/oGeqAv0FfxLGFA5gP9atTRp+4cm3beULZCIhS0SDlcVBfqS0VanxucgAAEIQAACEGgTAYSANoFkGAhAAAIQgMBCJVAt9T8ZV+p/yCamsjY6kXZCQK1LQsCeQxMlP1amQL6EIB/4e5kCRyxJ2L6hScoHFurLwrwhAAEIQGBREEAIWBTbyCIgAAEIQAACzRPwp/5HIyHrT0ZdBoBS+73U/yCjHrOix17aP17zVpUVaFyJAnrO0FiqonxAWQTRSDjI47gHAhCAAAQgAIEZEkAImCFAPg4BCEAAAhBYSARUuz82kXf916UAXQKArpGJfH1/vdP/ams9alnSdh0szQiodp+yBFYMxp3IIFEgk826echkUM/UXCQIyGgQUWAhvVXMFQIQgAAEFhoBhICFtmPMFwIQgAAEINAkgfLU/2njv/zpvwSAdKZ26n+1x3leABOTGRvsi1eUBlT7jJ7b1xO1g8NTeRHClylQTRSQWJCM0ZKwye3mdghAAAIQgEBDAggBDRFxAwQgAAEIQGBhElBwrzR8lQDoNF6mfzp1VwtAfU8n8c1c3hixaNhS6awbJxnXKX7ERsbTbkxlHNS6PPHg0Eiq4pZqooBuCodCNpHKkCnQzEZxLwQgAAEIQKABAYQAXhEIQAACEIDAIiLgpf4r4M9mzVKZrGv7pyB+eLy11P+86d/0GP7OARp3+UDcxqfywboCd/1dBoPl3QU0TiwSduJEvUuigMwK9UyVDGi88vIBiQ+6jwsCEIAABCAAgeYJIAQ0z4xPQAACEIAABOYVgWqp/0rBT0TD7pS+GeM/b2H+DIJ65oEKxgd6o7Z/qDTdX74DEiHUdUAdAyQKKCNBlwSJINfS/pgrWZC4IFFA4/k9BTTH3qTaFSIKBOHJPRCAAAQgAAGPAEIA7wIEIAABCEBggRLwgnwv9V+n6H3J/Cm6uv2l09nAQbeHoPz0v5F5oPMbSEasVrq/d7KvIF5zaqYkYeWShB0cmSr6F/gzBTxRwMs88ESB/mTMQqEFuqFMGwIQgAAEIDBLBBACZgk0j4EABCAAAQi0g4Df9V8Bvz9w1+m7AmPv+0HS8DUnBdF9yair92+2dWC9uv/iqUPIXEmATvh1uq8MhUZ+AvrsqqVJ23t4oqLEQD8rFwW8zIMjlyZt/9CkJRNkCrTjfWMMCEAAAhBYnAQQAhbnvrIqCEAAAhBYRAS81P/hsZQL8qORkGv5Jx+AWoF7kABdjvw9yajFIs37B3j1+ZFIyHkDVMsIKN8CtQ4cHU9bJCIPgLATB7yWhdVMBoO2JfSLAuFwyA6PporlCJQPLKJfBJYCAQhAAAJtI4AQ0DaUDAQBCEAAAhBoLwHv5NxL/Vfgrzp71dt7AbSEgWpXLSFAgbE3jlz/xyfSNpGq7fRfPrbG1Ry856o+X2Z+EikatSBUqv+B4aniZ5XCr0wEtQn02gf6U/1XDCYCtSX05iiBRMaFyg7wPAX8HgV6np6lNWA02N53ldEgAAEIQGBhEUAIWFj7xWwhAAEIQGCREyhP/depversvdN/CQCNAm4hKq/dd/9WAByTgeB0CUEQnOXGgTL784QAiQKan+4pNwcsH7uZVP9MJmehcMgODudNCINc5eJHrfIBCSmIAkGIcg8EIAABCCxWAggBi3VnWRcEIAABCCwYAuWp//7AW6f2zRjseYv23PyVVeCd4OvvasMX9PKLB7VaD2rstFr8TWaK4kM8GnFzLvcBOGZFj720f7zh4725q3RA4sLwWNqqlQ6UDzTYG3MChbIKyq96okAyFrGBvnymBZkCDbeHGyAAAQhAYBEQQAhYBJvIEiAAAQhAYGES8Kf+64Tac9iXEFAr8A6yUgWzah+o2n95CPhP8IN83p/+30g8kAGg19LPG9s7bZcPQCQcdvX6EgqaSfWXwBAK5bMMlMkgUaCRyaDmMiEjwgalDuWigKorlIEwNJYiUyDIC8I9EIAABCCw4AkgBCz4LWQBEIAABCCwkAhMpvKn8gqedXrtnbpLBPDaAQY5/S5fs7/2X+Om0jmLRUO2fyhYan1554Cg4kE1IcA/N61P5QgyN9QljwAF9zp9r3eVj+sXFzyTQXH0eySUexAEeS80vyV9MYtFw6a9wVMgCDXugQAEIACBhU4AIWCh7yDzhwAEIACBeU+gWuq/1/ZPgax30t0oOK620PIOAl4A76XXNxICgqT/V3uuvAtUEtDfE7WxiUzD1H3PyE/rDXK6v2wgXtPIUGNJOJG44PclUOvA3Yeqtxus95JIQDg4MuVaG/YmI6ZSgXK/AzwF5v2vGROEAAQgAIEmCCAENAGLWyEAAQhAAALNEFCAr5P20fGUS1f3gn99r1nDPv9zvTIC75RdBoLltf+NhADNRZ/P5vJCRFDvAK9sQN4FiWjY8u36pmx0or73gDM9TERcm0GvW0BPPOKe7z+F99YZ9HTfn+avzx4eS7nMiqCiiuZy9PIee/nAePEz+p7ECkSBZt527oUABCAAgYVEACFgIe0Wc4UABCAAgXlPoDz1f2lfzKXG61K9voLuVlL/9flpE8GITaUyroVgrQ4Culfp9f6MgPL0/6AdCLzPqdRAtfT+soEVg3F3kq57NJ7X/q98o2q1M/QCeZ3wq4WgZwx41LLmTvclNAz0xpywIF+CqXQmcKbCsv647T08WfXdaiQKqP2hBA0LmUXDISeMcEEAAhCAAATmOwGEgPm+Q8wPAhCAAATmPYFqqf8K/mV4pyB5Mp21odFUST17M4vyZxIENRH0CwH+8gEF6jqB99fW15qLv3tBreeqJeD+oUkXgCso9tL1y53+9TPFyBqn2lXuAaCAevfBiUDz1HgaX/P1G/7JrLBRGUItgaLWHMszBbS/XkaG1iBRoDcZRRRo5gXnXghAAAIQmHUCCAGzjpwHQgACEIDAYiGg03054uukX5dOpXuS0279ChAVKMrB/+BwMNM+j40/CG8lk0Cf12m9An5/sBqEfblvQL2yAQkBew+X1uX72/95WQIDPTFXdx+kBEHCxRGDCTfV8lr9WvOv1TrQK0NIxKY7GPizFuq1HKzHyssUOGJJwrIyZ8zksz28sgREgSBvGvdAAAIQgMBcEUAImCvyPBcCEIAABBYkAS/1f7xQh+4P2FU3r0DQH+wqqJUrfSPTPg9GK6f/z23dYkedvt4S/UvcqbhOwvMn8CHn0h+0FMHvG1B+ol9rs5TCv+vgRNUfe0G438tgpEZGgH8AT1CReDLdCjHs2Oqr2nrqmQt6Y5e3DVRmhFoTHh5NBWbkn2debEk4IWS69WPYlTggCizIX28mDQEIQKBrCCAEdM1Ws1AIQAACEGiVgFL/1e5PQaxO2L1WfQpwFezWS9fXz+VmXytY1pxaOf0f3vWCPXX7N+z5rVtsZNeLFu8ftNPe/b/Yhqs+aLnEgAuYl/bHbc+h6kG6x8IzHlQZgzIPgrQN1HwVQIuJTsRf2j/eEK2yE1S7r/KAel4CGkiChFLwlebvn2dvIupEDgkcEmL8J/teiUKQkgeN6YkC6nogcUfCR5C2hv6FSmyRb4IMEMt5io/WW0sUkHCgOeAp0PDV4QYIQAACEOgAAYSADkBlSAhAAAIQWBwEylP/vZR5BXEKtIMa/9U6NW/l9P/l7T+1Z7b8V/clg7piAFr4iwSB1115tZ185X+wY195ZE0hwBMfIpFKA8Bau+cXLJT1IB8EBeyHRqZcV4R6l9cFoJGXgMaQKJHNmQv0q13lJ/uj42knetQTW6qN453oD4+nnLDRyE+gfAyZMUoMqVXu4Iks9UQBlSxo7pYzRIHF8Z8NVgEBCEBgQRBACFgQ28QkIQABCEBgtghUS/3Xya9OonXa7E/5Djonfys8BZ+eAKBSgnGZ9zUIovWcX3//Ntu++R/c6X/FVUMQOPWy99ir3/F+Gzjq2OJHghgAlo/v1fzr+/5WgxpLa9NJuoLoeif91U7sq3kJqO1fowC7KH4U2vwN9EYtEYu4jI1apQPV9qrcKLDcsFDrUZZErdKKZrIQaokCYqgMB4kRsWjY+hJaC5kCQX+3uA8CEIAABFojgBDQGjc+BQEIQAACi4hAeeq/luY/rVctuU6ng6adl6NRWrwC1HhMp86hwCn4wy+/YNs3f8FefuynPgGgQXu6KqKA/APOev/f2NoNG93zlf4fxLTPY6D1VPMMUBC7fCDfes/vB6AuCTql9wfQ9bwE5KOgLAuv44AC42bq9jVPfT6VVs1/1KXja89k5ChhodZVzyhwWrDJj6eTf38pgieCNJuFoLn4RQEJGJq3RAfPaFCsKBlYRP+BYSkQgAAE5iEBhIB5uClMCQIQgAAEOk9ATu86ife7/rea+l9rtp6XgFLdFZB6beYarU7p/78s1P/Xvrd5QeDYs86317/3r+3o099QdwoKrD3PAM05nakeTVczQvQHuXqIF5A38knwAuREtJAqb429BLxFlLcm9EoHnDiQqRQlvM9JoJFg0Sgjo5rJoMYo9wdotK/lP/f8IyTMeK0OJTYpM4ALAhCAAAQg0EkCCAGdpMvYEIAABCAw7whUS/1Xrbs/WFc6eKun/1qwv/2e11pQZQC1at71maGXXzAJAL/Y/AV3+q8wX4Fig3C/wLfOXTUyBM646mMVgoBfAAhiGug5+tdqjej9XMG9rt2HStsM1no5dCK+b2iyJEugXheDWqUE1VL9/af6ek7QOXlz9a9J74gyF5o1GfTGUncEtZb0OkpovisGEs57gQsCEIAABCDQSQIIAZ2ky9gQgAAEIDAvCCj1Xy7zYxOZYoDvpZN7qfr1Tr6DLMJ/+u95CXjp9zqx1s/9LvjemBIAtv3nL9jvt37fpkaGLOT+Z1oE6IQgoGdr3P6jXmkSBE588yV2xKoVgUoWytsl5nK5Etf8aqwUPKuEQJeC5npBvYLhlUuSRZNDfzBfrVuAxpRPwcGRqZqZC7qnvPxAqf4qI2jUVaHW3uuZGkNZAc2aDHpjVitNUBcG8eKCAAQgAAEIdJIAQkAn6TI2BCAAAQjMGYFqqf8KBlWHrhNXndCrJjtIrXy9RZSf/lc7Sa92cv7UPbfZjntus5e2/7Ti9D+fCdAhQcBTAQqLkiCQGFhiGz/+JXvlGy6qudRqJoODfTHTaX+jdoCOUTLiBAPvFFwPqlbHrz1a1p/3HSi/aqX8N3uyr3GW9MWcOOOZP9YyBKwGpLwlpOePIJM/zzRR71atkgpvTM1b6/SyTzSfI5cl5+x3hgdDAAIQgED3EEAI6J69ZqUQgAAEuoJArdR/CQAK2BSozzT1v9naf38t/QuP/sTu/fRHbKiQ/h8upP+Xn/znv18mBhTS/OuXCwT3DvA76Skz4Mz3/03FO+IJAD989pDd+tgeO+PofjtnybitPu4YW7ak14kpoVBeYKlmEqgBy9359T1/UC+fBgkyCoglFCQTedGg1lWeJaA5NpvirywNZYNIEFJ9fq1sg2pzKE/p99+jdUlo8owPa5kW6p1YPpAoyUjQnCRQcEEAAhCAAAQ6TQAhoNOEGR8CEIAABDpOQKn/OtFVQOjVW7fb+E+L8DIKZEKnADhoOYE/3f0nN3/eHvnPX7BIqPqJf7il4F+za00A8DanXAjwCwA3P/KyvTw8VdzHWC5tJ0f32KUnrbDVJ7zG+gcGS5zwdaM//b/czM//Qvi7DahsIJXOGxNWK6Oo9iJ5vgZa/1Qq4074g5zuLxuIl7RuLDcErFe+UK/bQLkooHelN5E3LfS3nuzviTrxwb9OygI6/p8KHgABCEAAAgUCCAG8ChCAAAQgsCAJlKf+e+na+TZ9YXe67A+8ZrJIBZszzSg4ZkWPq2N/9qEt9l8+8hcWC4VMpeDeyb+qwj1zQP0ZILQvW1IAw0AXY1fvAOAJAZ4A8OCzh+zmh6cFgJLRff84PvOynbMqbG8+7bV25NGvcHNSUD3Qm/dF0B4o4NV+1DNL9D4n4z9d+pzf2K/W/nnZBjLtU82/Tvd1NWofuGpp0vYPTafle+OLvcodZOLnpfkrg8QvLqjbQD2hoHyu5WNOpDJubH+LRO/9pW3gTH5T+SwEIAABCAQlgBAQlBT3QQACEIDAvCDgpf4rUHTBcsicu7xOXdWTXaaAQ6OpGbn+a1x/TbwCwaAnzeWQvDKCJb0xlzr/1NYf27f/18udCODEgJBZNBwyF74G7hJQayvKxACfoKD2hbV4p1veAAAgAElEQVREAP1k3b//G9v4weutXACoCGr93/A97ojsQXvTyqkSQUAZFArOtTcSaA6PpfzVCFUXISFAtfUSD7zMC3221il/tdN5f9mBPlv++fIa/5o0Q2bKZvCn+WusFYMJ23s4WAeEaqKAxtS8JWZ5YpUEAM9QcV78ojEJCEAAAhBY1AQQAhb19rI4CEAAAouDgFL/RyfyAZ1nrFYt9V+BY5CT53pUdMKswFUB/Ez8BDwhwQtm1RteJ8DP/+oJ+8/vfbPFwnkRIP9n/u9eJkDwXasd+JeP0UgI6H/HB+3Ha/7UXh6uXZtfQwMoqUqQIHByz6hddNIqO2H1SdbTk3QBrlLjk7FITR8Bb2ydtouTxACvbED7oX2tdgpfnuJfMseQVc0S0Gm8Mha8MpIgvD3Dw3g04rI4DgxPBSpBqDa2947JByH/96iFXXZIsGaRQebLPRCAAAQgAIF6BBACeD8gAAEIQGBeEqjm+q/gWoF1XzLiTpZVo+83/pOJW08yarX62tdaaLtO/zW+ZyRXLiTopHtiMmOHDhy0z1+4xgkAsaIQkM8KiNSMA2v8oGEJgW7IFU7hq5cEeEyePv3P7ddnvNcUkYZC4YaqxLk9v7WJbMwenzyu1J2gMKdYLmWvi+yxPzvzSFuzeq2FIvG6PgLePGql7HvlBl66vlc2UOv+8r2e7u4gUSHrhIV6hoS13hWd5nvlB+VzCfqLpHdBXgj+Ugl1EKAsIChB7oMABCAAgZkSQAiYKUE+DwEIQAACbSVQnvqvwf2n9Ar8a5n0KfjWifKeQ5Wt56pN0j+ugjLVlXsZB80syt9FwGUkjKdtIpUtGWKgJ2rZnLng7/9682pLjQ650oCYSgS8rADF3yWfCn7i7/+YXPy9DIB8RUB9EUC37Dztz0xigEQAJwQ4QcBLU5iex6mJ521j305bGhlzj5zMxuyR8de4r8mcz/G+TBDYdEK/nfK6k52xoK5yHwEvsJeXwkv7x2viV7lBvhQk6jJEZLpX7/7ygbQkZSgoiFeWQiMvgfLP6/3y9re8Q0BQT4rytoHytDhiMNHMK8e9EIAABCAAgRkRQAiYET4+DAEIQAAC7SBQLfXfc+hXkKQWbwr61GKu0bVqacKlfNcK6P2n/xp3XAJAWdDe6Bnez8vT/1VKUOu5OknW/XKJ/3///E22+5mniqUB/uwApZ1PX9NH/kGSxl3g7jkBFIL/RiUB3rOePOvf229PurSQDRC2kHLV1dmgkCFwavIF29j3dFEAKGek7ACJAdv8goBv0prH+tQTdsaRCTv1jHOKxoKej4BXQqE/9xyasMnxQ3bgD09a//LjbGD58RVb4pUNqNZe4pGCc3kwBNA8zMsi0H54poDVvASqvQcK4qu1KvRnLHiCQDVfA923tF9i1URxeLUM1PvBBQEIQAACEJgtAggBs0Wa50AAAhCAQAmBfMs/maXl0/t1+U/W9e9WavRr1Yx7ZQPqHa/n1QvaG21VrfT/ep9TAKigU2UL3/qPl9tzj/7Unf47ESBsFnelAiGLlGQFBBECQiXeAtOegMFKArw5P7r+Gnv+NW8pBv5eicBpvS8WBIDap/T+dUsQkBjwyEQ+Q8AfmK9L/8peld3lbl911CtKBAEF9gM9MXfCv/PJrfb8Mw/b1MRhd2//suPs6NVvrBAEvD2VCKCa/yCp+tWMAstbGNbKEpBosaw/bnsP18448cbqiUcsm8vZZEr+Fukih2ptAyVeRbXxXBCAAAQgAIFZIoAQMEugeQwEIAABCOQJVEv9V0Cnk2DvVLhVh36N7z95l7Dgpf977QSDZBVU2yuvO4FS/FsZywsiU+msfffvPmwP3/Gv7jHKAJAIkBcC8sLAdFZAPSFgusWAd/BeIgJo8CDH44XFbj//o/bCmk1mKg0Ih+1V8QO2aeCXdlQ0H4y3cj0+cZw9MvYa251Z4j6+MfWorcwdKhmqr3/AmQqefOrpFp540XY+tsXGR0vv8T4QiSZs5fFn26rjz7FoLFmy17rHH9ArO0ACQfmpfKNgvt7Jvv/dCsLD614gs0OVIcjsUIKF3/RQ7+iRy5JBhuMeCEAAAhCAQNsIIAS0DSUDQQACEIBALQI6/R9R6nZqugbfn6LvBdZKz24idq36OAVfSrVWwK2ygpme/pcbCbaSSeBlELje8WMp2/LVG+1/fOPzxfmrlWDCZQaELF4tK6CiraAv595fM+BrERi0JMCbxGNv/Ji98NqLbGl0wjb1/9LWJl5u2wv93NQR9uOxtfa60V9bn02nxHsPiEWy9sqVZsnISKBnShBY/opT7VWrz7R436qKkhEJAvIQ8Ez9/Cf8QYP5alkCep+Clqj4F+Kfj1pcqjxE76VECs1H7ysXBCAAAQhAYDYJIATMJm2eBQEIQKCLCFRL/fdO1b32fArQlDbdikFfOUp/WYHSrA+PTrmWg61eXiq/SglaKVHQc5WN4GUQ6BT4iCUJZ2z3yC3/j933xU8Wp6ZYPh5RVsB0iUDjrICylgEzEAE0kd+/6UP26tOPt9OTL0wjC2JM0ATg0cmI7RtK2NjUdD38EQOTtrxvyiKRxoaG1R5Vq2zAu7e824DaODYbzHtjKIiXoKVyFrU3bPbyslMkiPUnoy5LIBIOm7IUuCAAAQhAAAKzSQAhYDZp8ywIQAACXUCgWuq/17qtNxm1sQkFUpmWe7CXI5xuCzd9+q/2bNV6zjfCP9P0f43vFyTKsxFWLknYvqFJ2/HDe+y7/9tVJdNRLKisACcIlGQF1CoPqBY85ixolwDv4bHBAVt1/jpbfspJrkuA62Xv/rdG1kEjiAF+LkFgeDzqBIB4tEpA3UJcLEFg6ZGvteXHnOrKBsqvcnPBZt+PfEeKhBMBvCC+GYNCzae8baAyFpYP0C0gwCvDLRCAAAQg0GYCCAFtBspwEIAABLqRQK3Uf+8EVGnqavmnwLjdp//e2P6yAjnJuzZ+E+lA29GO9H9vDJ04qyyhWgmBWs8pAP31Iw/Zt//qnRVz8zICEpG8cWD9rIDyaDkfUActCYgO9tsR551ly16/1gkA+U6B+jMvBuRbB5YJAnpAC0F6rU3IeXUgufyzK6P3QNtXKqiU+QiUjyDXf+2NgnCtU6KUgvtGJSl6lxOxsB0aSbkhyzMN9P41yhIo7zigzBh1EOCCAAQgAAEIzDYBhIDZJs7zIAABCCwSAl7q/+h4yrX38y4v+FdgPNP6/HJU5ZkFEheqBV9eTb7aCNa7/On/Eg1UqtCsUOFvc9iohEAnwhOTGdv13O/tK5evqxr3Jl2JQMhlB8g7IH/VywqYZh9EBIgO9NvASWts8HWvtfiSgYrg3y8KLGRBQNSWrHqtrTr+7GK3gXKjwOlskogTA+oF8+Wn+cWdCeUNKr0sgVqZBp5w4H8n1YFAvy9cEIAABCAAgdkmgBAw28R5HgQgAIEFTsDrt66TVO/yp+dLFGi2BrseEn+6vu5rFGy7sDlkduTSpO06WGlMp5/7a/e1jlY6CfgFD80pyBjyC8jmzLWV+9vTVlRdtkoEegpigASB8qyAagfnGihISUDfia+xpWefYfHBggAQzp/8e9kAXuDvZQaUf382MgRy2VxRnKgKqIWMBM9HYNXRJ5Sc6jcTzK9amrT9Q5N1haLyLIHytoF6njwGvEsZAuHpDV7g/2Vg+hCAAAQgsJAIIAQspN1irhCAAATmiECt1H+5qCu4DRqgNzN976S91ZaCqsc/ODJVzBgoT/+vlU3QaI7lBoDl7enqfX5pX8zkkyAfhevWHWeTI0NVb1c2QD4zwCyqYL14l5euX/oxf9vAagPGViyzFZveZPIDqAj8i2KAVxZQKBFwIoQnFHjeASXOAe0rE/DZBKTSmfxzC89vixiQM4v3LLE3/PFHG5aM+Ds86B3xykvqCUvlcyzvOFCtbaB+d44YxB+g0e8bP4cABCAAgc4QQAjoDFdGhQAEILDgCdRK/Veg1JOMWqLgvt5O4z9BKz9pb9VXQEG3eslnMjlTLXa+9VtrXQokIqimXCngrZQ7+MUDrVHp4d/6j5fbc4/+tOp7okNiCQFJmQeWZwWUtRJsJAI4pqtPsOVv2mDqpuAF2PrT7wdQ/L7nE+AJBC7DouwYvoVT+Zq/ED4RQJ4B+vIyEio+08pzC+MvO/K1tm7jewKbSOp0XyKUvjLZrJtKo1KTaoKAfk/6eqLmtQ30sgTUMlDvExcEIAABCEBgLgggBMwFdZ4JAQhAYB4TyGZzrs+5P/Xff5quGvpW0+lrLbv8tL4d4oJ3+i7jPp3sBkndL5/fTE0EPQHAEw/ceL3RhkKA5hEL58UAZQfEamQFBBEBNFZy9att2cbzLRqNWCwasYhfECgP/Gepc0C+liHg1YoAoKF9zzj+tefZKWdf4lLzh8dTDc0BvZlJA1k+EHeBvDI5mu024ISYRMSJUf62gbFImLKAgNvPbRCAAAQg0H4CCAHtZ8qIEIAABBY8gd2F2nqdossETcFQq6fp9WD4T/89caFZsz7/+P7AXV4FCrb2HKruE1BvXn7PgyCeBOVzqJU9II4rlyTdnP7t/7jGHv/ebXXfFZcVUMgMqNZBoGAM0PB9i7/mVbb0jW+wWCxi0WjUiQESBZwgUOgW4C8DKBmw1SC8fFYKyv1jBRECZvLssvHXrrvMXrX6TJfREY/lzQH9Nfz1IKrM5PBovluAhBy9V17ZQKNuA/qMjAYlAkiMEm+JAkv66BbQ8MXlBghAAAIQ6BgBhICOoWVgCEAAAguXgIJxBdVjE2l3+t9MHXyjVZcH6+Ny60/lU69bvfyBu1+wKG/X1mj8VgwAvTG1LqV6SwSoVz5wzIoee2n/uP34G5+3//GNz9edkoL/3kjImQcqK2D68v4eJJo2i736VTb4xvOcABCLTQsBEgP0JcO6sur/9tX/V1tho2nPRADQ86qMv+7Cv7JQYrnzCPDX8Kt8ZHQ8XfMdr2Y86S8bmEqp/WD935Hy91DviUoDuCAAAQhAAAJzRQAhYK7I81wIQAAC85iAygP2Hq7vkN7s9DvRVtBLvddcqjn3rxiMB0rlnokBoF/Y0ImvgsJ6WQ1yn997eMIe+++32r996tqGGGUY2Kd6dVez3/D2qjfEXn289a4/x+ISAdyXSgSiFnV/RiwSVmbAdIvC1p5S+1P5kUOW8yL0ekJAi2ssPr3K2JlsyH69a8BOWL3WTlhzkh159CvyM3Kn81F3Qq9ODtXS/qu1/fOe5QkK+rxXMqOOGf4sgWqfP2JJwvR9LghAAAIQgMBcEUAImCvyPBcCEIDAPCdwaGSqxCeglen6T8nb5S3QjJ+AOhoowJXnQfmlcRTASQRoxQDQPw+JEC8enLCkjOHi9fvCS5xQC8HnfvET+/JfXNIQq+LivmjIfUVaVAIiq1Za75s3utP/vBhQyAwoCAEuKyAUbmsWgBf8FzQAU3Q87WlQZ9ktCAFFoaFGnv5EKmy/29tffOiqo15hp55xTlEQ0A+8gN1lwkxmimUDg70xJxL42/5Vm32t1oHln9cWHr28p+G+cwMEIAABCECgkwQQAjpJl7EhAAEILGACMkZr1iXdW67XWSAWCbUUZNcKtDz3/6CBe7XT2JkaAJYLAM/tH7fbntxnv9w9ZqNTGXv3qSvtHWuXVxUeJEyofeDwWMqeefwx+8ZfvDnQGyLDQAkBiUhFAn+gz3tCgIJQZQRIDIjHVSKQFwXyfgGRfAl/C4F46ceqdBjIKVvf/Z/GVxPP94sN9cY/MBK33UPJimdLEDhhzVo79rgTLJ7It/LzykzULUA+AIlYOFBWiTe43g+JS/LWSGWyrsTm0EiqWHqgd3hpP/4AjV8E7oAABCAAgU4SQAjoJF3GhgAEILDACcg0MKh5nwIemQsq2G3X6b/wNUr/r4fYX989EwNAL0BUGziJG8oAyAsAe+2Hzx72TSEfmp68qteuWX+0reqPl5w063OuRV5q3B68/3578G//faA3RKP2R0PWGw1bpIlA2TuBjx6ZzwjQFY6ELaGsgHheEPDKBFQi4MoDmhh/+tbKD2koL+7Xmt3VSAgI+OyiAKC/BBAZXjrYY4fHa9fkx+JxW/u6023tyacVBQHNf6AnZv09URufytT1Eai2iV7ZgDICJKrJu0IGhUv74u6d5oIABCAAAQjMJQGEgLmkz7MhAAEIzHMCckqXs3q9y3+CqvpoBbtBxYNa45Yb7zUyY6s3P9Xk53vTV/cRaLQFym6QAKA5FQWAJwoCQNXANf/N3ljYPv6mY+1Na5YV/QtGhofsye3b7Nnf7HT3PPtPn2r0+OLPlRUgMSAeMCtAsXcml6+Dj/uEAA2o4F8ZAfFYzDnoe74BkXAhQG0QkJek/VdZgV8ECBKolwxR59mV5QbBMg2e3dNnk+nGwXe5ICBRS+KWgvi+ZN4HoJn2gTIFjEVlupkpdhvQWmXOyAUBCEAAAhCYSwIIAXNJn2dDAAIQaILAH17eazd+7Rbb+cxzdvkfv9E+dNXlTXy6tVvTmaztOTRZ8WH/6b8CTqVQl5uktfJE/6l90PT/Ws/xZxIokKvmE1Bvjn4DQbnKP39wwm71BIDyDPo6gsA5r+y3958yYM/+8udFAcB77vP/+mVLD/szCmrPKGhWgPZDPRhU1y4vgqgyNcqEACcOxGOWkBigEoFCuUBU5QF1sgIaCQAV2QQBTuu9FefFmvqGhcXCiICZAN7YT788YNlcc8H38a8+wc484yzrW3Zk0fyvlo9ArV2TH4TXFlP3qNxg+QBlAa38t4HPQAACEIBAewkgBLSXJ6NBAAIQ6AiBoZExu+IDn7SPX/Pntu70tfa5r/yrDQz02fUffk9HnucfdN/QpKvz1+UF6qpzV2tBCQBpHTvP8PJqqls9tfceX80AUKn8PcmoHRyeCjTL8g4CLx4qFQD8g1SEljUEgVguZevSO+yY7N6SObz0b/9iEy89F2hejn9YJQLhql4B2gUF/vmvnBMD8p+pFAL0/YhKBOIxJwRIEMhnBSjzob5pYFWXgmrr9kQAPSzAK5LNZl1ZgjMt1FVlTP+zg3oOpNJh+82eaaPAwLALN9YyFlStv+cjoKyZcp9Cryxl96GJ4s/UMlBZAlwQgAAEIACBuSaAEDDXO8DzIQABCFQh8ODWR+3OLVvtFUcd4YL9u7ZstZ89ttM+c/3V7m5lB1z115+zzV+8zl5x9MqOMtRJv2qklSatS+nxEgbamf6fyuSarsH2L7qeAaB+ppPZapkN/jE8AcDLRHj85RGXAfDU7rGGfEti1jrZAcdk9tpZ6V9Z3PLlFru23GZjv3+64fjeDRpapoG90ZBFC6fnirG9MgBPAJAYoO8rpFZJQfyoaY8A/8PkEZBIKCsgLwjo342zAqqYAVZbQRPZAPp4KpV2AobmUCszwBMCgooAGvfQWMxePjRzl/5ly49wxoInrD6pxEdAgb2MASfT2ZJ3uJpR5aqlCYtGaBsY+IXnRghAAAIQ6BgBhICOoWVgCEAAAs0TuOHGm23bYzvtnNPX2mUXnW/nnHGSG0TCwLduv882f+n64qC6V5cnDjT/tOCfkBig9mlTae+cOfhny++MRkIucFLt9UzT/4MaACoAUweEcvFCIkEyHnantEUB4KVhu/XxPfbUnvF8gX3Aq5XsgH0/udeGnnwk4BPyt8W8rIBwyAX7XglAvhwg/6VAWT+LhUIuIyBRQwhQwC0hQJkBiYT8AvJiQLhOVkBFeUAbsgG0rpGxCctkMtaTTFjRuFA/qLUFAbIM9PHdh5N2YLR96fi1jAV7E9ESHwFlC/jbDupdO3JZZeeCpjafmyEAAQhAAAJtIoAQ0CaQDAMBCECgHQSu+uiNdvnFG9yX/1IGwKY/+0923y2fL2YA7HjmObv2k1+x+2/9QjseXXeMQyNTrtZ5JpdO3JVO7ZnuzcRToHys8QZzWzYQt3H5GBRKHMrNCJXlcN/Te+3Wx3bbntGUhZSeHgoV/qwXjVYSCZodsDr9vB31yGYb/fkDTWHV+MoISBaFgLwnQN4bIFf4M58N4MoCwiGLDfRa/zveVvU5kWjYkgn5BcTz2QHKClAHAUXgNYLwkvKA8nuaLAnwJrVn32HLZLPW35e03p5EvkShFvqAIoA+/ty+Xhuban86vicIKEugf2CwyNbLBEjEIq585vBYyu2JxCaVBnBBAAIQgAAE5gMBhID5sAvMAQIQgECBgL8EwPv7hRvOtDdvONOqiQTnXfJXdsc3/96VEHTyUibAvsOVpoGNnuk/cddpvEz3vGC80Wer/by8fj9ohoKCMM1FtdyeE7yXAXDfjt12y/aXbM/ItABgoXA+PT1c+NMFxcGzAyri1xrlAvGhl+24+260gT881hSOaNisNxIuCf5ddkBBENDfo6FCWUADIUAPVllAMhm3pK9EoH5WgG9B/rUVuwQWahOaWNXLuw/YxGTKBgZ6rL+vx/kWlJgDemNVEQFKjAbLntmKUWAT03a3lvsIKOvliMGEKxdIRMPOS0MZAhIJuCAAAQhAAALzgQBCwHzYBeYAAQhAoEBAJ/9XXP13LrBfu+Z4O/v0tfa1zXfZOy/eYOtOO9Fu+Nw37Y6bP2WD/b0mA8FN7/6YPfy9f5wVfnsPTZhq+YNcXvq/gp+ZthSsdnrfrD+BWgAu7Y+7VG1lDyi74d5f7bJbfvGC7RmZsmLgX5IJ4GUESAQoCAJtFwPyksHK7d+1ox/ZbJHJkSB43T3JSD4C12mz8wko/j3nDvJjKgkIkBHgZqASiUTMepLxooFgvayAijZ+biL5qeeLEoIZBPoX+/sX9lg6nXGixJLBPuvvTTpDw+JV6BRQDsiZDLqWfJVB9kyNAgNvRuFGCQJnnX2unbj61XZwZMoZaeqVGeiJWX/BY6PZMbkfAhCAAAQg0AkCCAGdoMqYEIAABGZA4Jobvmzve9dFTgTQtW37DicAqATgxq9+x3b+5nlXOvDA1kdt7erjZqWNoOah4FklAvWu8pT9mZgK1jMADIrXP0Y4FDI5uN//q5ftX7c9Z3tGJstKAAoBv3q8FwN/f4lAa6UCmmvx0LyOkWDs8Mv2+s1/GnRpVtABijG3C78LooB+5pUFSAyoVxrgPTAaDbv6fFcmIL+AeLSmg3+tNoKtigCaw7PP7bKsDA5kiNiXtCWDvdaTiE8bB1Yhk85knMlgssZ9w+NRe/Fgb2Cm7bjxjzZeYK9ac0qJHwXdAtpBljEgAAEIQKCdBBAC2kmTsSAAAQh0gIBO/te//YP21I82u9FVMvDSrn1OBFDJgLwC1DlAWQKdvNKZbFXn/fL0f3/f9FbmE9QAsN7YGqOvJ2pqHaj6f2UBPPGH/fal+3fYriEJAGXBvlcCUEsAcOZ5ygrwWusFKxMI4hegtHbvOvWf3m7RqeBZAbUYtCIEaCwJAMoK8MSA+lkBpQxmIgLIG+B3z+0uLiccDtngYK97p+VZUH6J2VQqbZNTKedn0JtMVEVxYCRuu4dmx6AvkUjYposutmWrjkMEaOUXn89AAAIQgMCsEkAImFXcPAwCEIBAMAIK9hXkKxD6+ua7XCcBf8cAjeK1EFQZwc5nnnNZBB+86vJgD2jxLr9poHfa3puMOlM0BdzNpuz7p6FsAnUTUIzuBe/NTlMlABIAPENCv4ngp+5+1B5+7pBLgy8G9WHvlN8L8kMWKgT902aBzZcFNCsAeOtc+52rrXffb5pddsX9WqJXFhA0I0CDiE1vMm49PSoR0Fehg4D7YeljSgwDpwsCiiUCzSxicjJlL7y0r+QjsVjUlg72Wn9/z7RxoJkzFNT9EgGUQSBzQXU9qHZ1yiiw/FmDg4N28SWXWrx3qSvVcLhCZsv65bsQaQYF90IAAhCAAARmhQBCwKxg5iEQgAAEmiOgEoC7vv9Q0Svgug+/p+LE328e6IkCH7rq8oqOA809uf7dk6l8fb3f/X8m6f8KlhQoycBPIsLwWGstCj0TQQVhMmar1kXgviefty/+YIdFEj1m4ch0WUDJaX9BGGjRJLCxAJCv6a91rbnjo00bB1YbS0KAaxtY8AiI1+kaUP75aCx/wp4XA2IWi1XvIOCVBxQzATRQMAuJiimPjE7Yrj0HK77f2xN3mQF9PflT/al02iYn8iJAKp1xmQsD/T0Wi5ZmDWSyIdt9OGGHx9vXNrDWnkkEePtlV1ooNp2RIyFqxWDC5JXBBQEIQAACEJiPBBAC5uOuMCcIQAACVQhIHBgeGbPPXH+1++lb3/0x2/zljxc7BiiLQMaCnW4nmEpnawbbQTeuHfX/elYzXQR+u/uQfXDzjyzS02+RRJ+FItF8i8BqQkBHTAHriwBazwn/3yds6bNbg2KseV+5EBALmw386ZWBx1UHAQXhXieBohFfvbi2RRFAkzp0eNT2HRiqmJ+2QRkBg/09hUyAtE0VRACVD6i7wECfTAWnT92n0iF7fn+fpTKdd+hftWqVXbjpEgvH+4pzz4sAcYv6jQ4Dk+dGCEAAAhCAwOwQQAiYHc48BQIQgMCMCSjwV+DqBfrKCHjflZtcCYF3vfVP/8Zu+vtr7KQ1x8/4ebUGCGIaWOuz7RYAlI2g+QRtI/jO//vfbCwTsmjvgEWS/U4McHX/ygCYaYtA/6KrBMz1MgG8j77yf3zVVj12+4z3rlQIyHcQaEYIUJDd26OsgIT7Mxb1pbdXEwNmIAJosXv3D9nhodGq61bnAKX/y6NCxoDqLKDuDyphkEig7gKeUCFPgL3DCcvmOn8Sv3r1ajt/41ssnZvORkAEmPGrywAQgAAEIDBLBBACZgk0j4EABCAwEwJ/2LXPeQXoT3UM0Ne3b7/PdRDwMgQ0/g033uwyBDrpFVDLNLDe+srN+1opJ1CQFY+FXRmBPt+sJ4Hm8LH/8mP7xe/2uPKASM+gyw4IR2L5gu42XdU6BAQRAfT4Y391p638wZdnPBNPCMiXBzQvBGgCKgnQiXtvbzMwANwAACAASURBVMKS8di0e38HhIDdew/Z8Mh4zXVLDJAfgGesKLPHvt6EEwL6ehJubhIA9g1XNw2cMdCyAc5bv95OPu1sm5jKty50vCIhWzZAJkC7WTMeBCAAAQh0hgBCQGe4MioEIACBthJQ2r/KAgb6e13XABkHyhdg05/9J7vvls+7rgG69LOfPbazRBxo60QKgx0YnrKJqUzDob3Ufd3YqgGgBAB5EmisVgWAgd68geBXtjxu3/zhUy7wjyb7LdI7aJFkn4XC7TF0KxUBCv/KlVTR12V2fOZle/0L37e9P7zb0sOHG/KtdYN0jWmzwNaEAI3R25t0p/HKCoiohMK7/GLADLMBNOTzf9hrU1PpwOvNlwUknT9AJJq0Fw/02mS6PXvYaBLnnrveTjxlXYkxpnwulvbFTPPiggAEIAABCCwEAggBC2GXmCMEIND1BOQPcOGGM+3s09ea0v83f/E6GxjoK2YJfOXT1zpGyghQW8H3Xrmpo8zqlQe0ywBwpl0JJBz4TQ1lIPizZ160G25/xLFR8B/tXZIXA+LJlrICqp3+y15f33fxcRMigG5fmT1oG9PbLTM5Yft/eq+NPP14S/uo5ycioYJhoFnPGX9kPScd2fRYygqQ+KTT9wpn/uIimx624gO/f2GPS/kPeilDwIkA8QHbPzowK34AmtsfbbzAXrXmlBIRoC8ZtSV91bsWBF0P90EAAhCAAARmmwBCwGwT53kQgAAEWiAgP4CbCsH+tZ+4qaRdoH42PDzqhAGVBVTrMNDCIxt+ZPfBiZKAqF31//5xFLxLdGimLWG1LISR4SF79jc7befTT9vNv5t2dw/H4hbtW2rRnkELResHc/U7AhQ89KWCeMF/kyKAgC/JDttb0z8rsh/93U6XHZCdmmy4H/4bNBuVBQyccKL1n7vJwgNLLTb5dFNjOLEkZNbXmz95l19AuIUSCqXzK3W/3vWb373c1NwkUFh8lY2lB5v6XKs3JxIJu+Qdl1n/0iNL3sWB3pgrVeGCAAQgAAEILDQCCAELbceYLwQg0JUEZBR4zhknOQFg7Zrj7YGHfmEPf+8fiyzkHaBSgbvv/Ym7R6LAZ677QLFkoBPQDo1MuSC9EwKAygia9RGoVYbw22d22JPbt9noyLDDsOXActudmq4lVweBaP/SvHmgL/29cSvA6TvygW7O+9+mMwG8/enNjdvbUv+zZLtSw4ecGDDx0nOBtzE2sNSWr9tofSeeVvxMK0KAPhyLRW1wIG/Kp78HvSQAZLJZV1JQTwhIpcP27EsTFs4ctJBN19zXek7OopaLH2WZ0LRTf9A5tXKf2gNe/u/eZdlIT0nrR0SAVmjyGQhAAAIQmC8EEALmy04wDwhAAAJ1CCjlX2UBMgnUpX9fdtH5ThzwLmUGnHP6WvuLKzfZf7n9PrtTXgJfvK5jYoBMA3VSLyGglcDdm7dM/JTCLyPAVnwEggoA3vO2DQ/ajjFfECm/gN4leTFAJQIusb9WpUDpyfZ0gOsTAUymdq2/zldOPVj1w4efeNgO/vzHDbMD+k88zZat22gSA/xXdPKZQIF2+cNdC7++pPX39zqn/kan+/q8RIB0JuMM/ipKCsoeMJEK2+/29pvlMhbO7LdI5mBNeBIB0vHjzEKzk4ovEeCSS68oaQ+oyakUQCUBXBCAAAQgAIGFSgAhYKHuHPOGAAS6msDQyFj+5L+/t9gq8OQLrrKnfrS5yEVigX5+/Yff0zFWw2MpF7y3ckkA8Ez8ZkMA8Ob4/ETCfnh4ecmU5RcQG1hu0b4lFlIXAU8OcHF/9bR2f7b7dOA/MxFAT/vjqZ9an01URarsgN1bbrOp/bsrfh4dWGJHX/qXFQKAd2N06lkL5VKtbNV0VkBf0mLR+gGwsgAkAmTSWVMtfyMh4NBYzF4+1DM9r1zKIqmXLZwr7SKQiSyzbHRVS/Nv5UOrVq2yCzddUiICaM+X9cdN5oBcEIAABCAAgYVMACFgIe8ec4cABLqSgNoIbntspwvyf7Z9h11/zZ+7TAEJAf4OAl5XAb840G5g9UwDaz1rpp0Ems0AKJ/HSCZid+yrDCjDsYTFBldYpGfAQqFwPvyvIgSUfCufCFC4Zi4CaKC3pLbZ0txI3a3a95N7bejJvOlhOJ6wJaee57IA6l0zEQKcV0DBpb83mW/XV365UoBMXgRQtkg2m7VkImbxWP3T+wMjcds9pEyM0iuUHbNIeo+FcpOWiay0bLRUvGn3u+wf7+STT7bzN77FRiemDQy15JVLkhaN0Bmgk+wZGwIQgAAEZocAQsDscOYpEIAABNpCQF4AV3zgk3bHzZ9yKf8Pbn3UvvbPd9od3/x7U2mABAGvfEAP1Pfed+Ume/OGM9vy/PJBlPq9+9BEoFR4L4BXOYEEBBkBNnPNVADwP+v2vSttNFt5sh3p6bfYwAqLJPKGgtXEgOqd89ojAuiZG1OP2srcoYZoMvffbYcjZgNnV5YBVPvwTIQAjed1EOjvT1q8LCtAQb8TANKFbIBM1rXSU9vBRhkEuw8n7cBovPZ6cxmz0OydwJ+3fr2ddMrZNpWe9itQ+cuKwbhFI74Wig13iBsgAAEIQAAC85cAQsD83RtmBgEIQKCCwLbtO+zbd9xvXrvAHc88Z5/72i22+UvX211bttrXNt9l99/6heLnlD2g64NXXd4xmp5pYK0H+AWA4bF0SYAVZFLtFAC85z14aJm9MFl5Ci2b/FjfMosOLLdwoYtA/vC7zilwC90B6q17XfpX9qrsrpq3hA8fsuT//LElfvWEZQaX2NhFl1r62Fc1RDlTIUAP8GcFKNDPewF4pQD5TABlBUgYSCRipuyBaLR+EP/CgR4bmZidmv9GkM49d72deMq6im4YiACNyPFzCEAAAhBYaAQQAhbajjFfCEAAAgUCEgGu/eRXbO3q4+ydF29wp/5v/dO/cRkA771yk7vrxq9+x7UU9P7dCXiTqYztH5oqGVonqDL/U2s1uf/LA6CZFoAarBMCgDfJX4312c+Gq7eec34BS1bm/QKKJQI1hIA2iwCa3+syv3Nf1a7Eo484ESA8WdpOcOKMc2xi/UbLJauIG4WB2iEEKKhXK0GZB6r+35UCpCUAZNzf8yJAzukmEgF6enSKXl8IeHZPn02mZ+/Ev9bvwEUXX2xHvnJNyXsai4Rs2QCZAJ347wZjQgACEIDA3BJACJhb/jwdAhCAQMsEvn37fc4nYLC/1z771e/Yh6663M4+7US74uq/s/e+6yIbHhmzB7Y+ajd9+lo7afVxLT8nyAd3H5wodhBQBwAF8fNRAPDWciAVtf9+YGXNpYXjSYsvPXK6RKAiK8DfJSAIoeD3rM68YKdnnin5gLIA+r/7LQsPHS5+v1yaUHbA6KV/YplVR1V9WDuEAA3c25uw/r4e1y3C8wTwsgAk9kgIiMciriygpyfh2gfWu55+ecCyubmru08kEnbFlX9iR6zMvw9e1opEgBWDCVfiwAUBCEAAAhBYbAQQAhbbjrIeCECgKwlIFJCBoEoGZBL44E+2u3/LQ+CYI1fY4ECf3fT313SsleDEVL5VXG8yamMT6XmXAVDtpbhlz5E2lasdpMb6l7nMAGUI5HUALyAM7geQy2UtpE9XMder9aIek91rb0g/6X7slQHEn3qiZnVCeZg6dsEmmzzz3Irho1O/d8Z7M72UFaCMAAXILgsgm7VsIRMgW2if0JOMOxFA7QbDdYSATDZkv941MNMptfx5tQe89J3vslwk37VAnSyW9sec50Usih9Ay2D5IAQgAAEIzHsCCAHzfouYIAQgAIFKAgr2h0fHXVmArms+cZP7u7ICdKm9oEwF5R0gU0EJBd+6/b4S/4B2c52pAKDTZL+HwG+f2WFPbt9moyPD7Z6qG2/LgeW2O5WoObYEgPiyoy3amw9UXSxf0iWgzrRUMpBNWy6btXA03pQQsDJ70Damt1v8l4/nywB8WQD5iVQ+t/xbqVce77wDskuWFm+OTD1f0ZKvVbAK8iUEyAvAiQC5fCaAPAMkFOSzAeLWk6h/op5Kh+03e/pbncaMPicR4JJLryhpD6gBB3pjrqSFCwIQgAAEILCYCSAELObdZW0QgMCiJaCT/r/97Ddc/f/Q6Lidc/paO/v0te7f+lOmgp/76ndcNwHvuuHGm93PO2Uc2Mg0sHwzapkIdloA8Obx2Ei/PT5a/zRa3QMSK46xUCSomZ0EgKzl0inLZVIWikQtFFO7veCny0ce+p1dfN/nLPbbp2u/vzWy1Us6GiQSNr5+YzE7oJ1CgPwB8kLAtADgTVYmgcoIkBiQVEZAnWyI4fGovXgw36FhNq9Vq1bZ2y69wlLZUm8CRIDZ3AWeBQEIQAACc0kAIWAu6fNsCEAAAjMgoFP/nc8858oA5AUgMeCBh35h11/z50WvgIe/94/FJ3jmgv6uAjN4fMVHq5kGVht/rgUAb067puJ278EVDRHEZRw4sMJCDdL7nQCQSVkuPWXZTNqNG44lLBxLWqhBnbw3iTMf2mxnbv2XhnMq3hAwO0DeAeHwnrZlBNSbYDIpISBhvT1xSybiNblNpUP2/P4+S2WCiyTBwdS+86yzzrIzztlQ0b4SEaAddBkDAhCAAAQWCgGEgIWyU8wTAhCAQBUCP3tsp934lX8tnvyrhaC+VBKgDgIqFbj84g3FT6pc4LoPv8fOOeOkjvD0TAPnswDgzW0qG7I79q2q6xOge3Wqn1jxiqJxYPna5AOQy6TzAoDLBFBJQMaVBDghIC4hoL4r/nG/3mrrf/BVGzi8u/l9CZgdMPa2Cy13TGfT8KWVJBLKBshnBCg7wHkklF2HxmK2+3By1k0Cz1u/3k58/TpLZ3IlM1raHzeZXHJBAAIQgAAEuoUAQkC37DTrhAAEFiUBlQjs+M3zRW8A/fvOLVudaaAEga9tvqvEF0DlASod8IsD7QRzeDRloxP503Dvmi8ZANXWOZKJOK+A0Wz9mvBIz4Allh9dEtCrHn5aAJjK/73wpQwAZQKE44WMgEj18fsP7bLTt2621z55rwuXZ+RPHyA7YOr0U2xq3WvNOhT0KmtCGQESASQGJOLxCux7hxO2b7i2N0M730f/WOeeu95OPGVdSXtACRcrlyQtGpkR+U5NmXEhAAEIQAACHSOAENAxtAwMAQhAYPYJKPBXO8H3XrnJPVxZAe+8eEPRF+Cqj96YbzN4+tqOTC6dydqeQ3ln+vksAPgXH0gMCIUsvmSVqZNATv+TyRQyACQAqBxAWQB5IUCXsgFC8WSxNCAcrfQYWLvtdnvdz263gcO7nBFhXgjIn5+3HJYGyA7IDgzYxIVnW/YVR7T9HZAQIJNAdQuQoWAiPr1ulQK8eKDXJtOze/Ku9oAbN77JjjpuTYkIoPaHKwbjFo10vjRBZTz6veSCAAQgAAEIzBcCCAHzZSeYBwQgAIEZEvA6Bcgg0As61F3gqr/+nDMJ1CVPgcv/+I1OGOhUYDIykbZkLOyCLn8XAPf8px6zXzyydYYrbf/HVSaw5eAKO5iubQoow8DE8qPcw7PpqaIhoPMDkDCQ1VfWQtHotABQ8AgIxzRuPkrvO7TL3vKvH80LAIXvuj/bJQboIUGyA05dY5NvPMtClm0bUBkIFlsH9iQsHstnQkykwk4EmG0/AHUGePs7LrNwcolrCehd7RQB/rBrn117w5ftve+6qGqmjX4vN737Y3bhG8+yD/7lZR1r4dm2TWQgCEAAAhDoCgIIAV2xzSwSAhDoBgJqEaigw99CUMG+vvf1zXc5Q8GPf/g9tu2xne7vm794XUeCEgVc+4cmbSpdGmB6GQIPPPCgPfH49nm3JRIDtg0P2m8nap/cqkRAngH5LIBUIfjPmD/K1M89k0DPI0AZAr2Hd9uJ2263E392uy/oz8fsOpOeDTGgXCPIHLHcJt72R5YbqEzhb2WDnBBQKAtQeUAsGrW5KgVwIsBlV1ooVrqf7RQBPEb6HdPvnzw7yjNubvzqd0xiwYUbznSlOjL1lE9Hp4S4VvaNz0AAAhCAQPcRQAjovj1nxRCAwCIloLT/z1x/tQtKv/4vd7uA/46bP+UCjms+cZMrEXjzhjPd6iUMKDhx93fg8psG+ksEDo2kXKbAj3/wPXvx+d914MkzH1JiwI6xvuoDhUIFIUDBf/WTdBkD5oWAvFGg/nz1jh/ayQ/9i/Uf3j2dBVDMAPDKAnx/djg7wC8I5BIJmzznVEud+poZZweorWC+NECtA5O2b6TXDo+3R2RoZmclAlxy6RUWjpfuYyyicoCEa33YicsT3Xb+5nln2Clh4Jobvmz33fYPxcBf3h1rVx/nvrggAAEIQAACc0UAIWCuyPNcCEAAAm0mcN4lf2Vr1xzvgg95ACgQ8S6JBEpL9roFuDKCq//OPnPdBzrSQUBeAaMTGetLRqqWCExNTtoPvn+nHTywr80U2jPcYyP99vjoQGuDhWQUKBEgYQPjQ/b6bf/VXv3UD0rLAHxeANMlAWWCQIfFgPJQePL002xqw6kWyuU9Hlq5ItGwEwHiiR7bN7bc0tnZ9QPQnI899li78KK3W6rs2eoKoO4As3F5ngDKBrjr+w+5chwZdJYH/zL3lCB32cUbyBCYjY3hGRCAAAQgUCSAEMDLAAEIQGCRENCp/+UXnW83fu0WtyJ/6r8CkuGRsZIMgGrfazeKA0OTNpGqfnI+MjzkxIDRkeF2P7Yt47UsBoRCTgh4/fb/bq//+R01fACmjQFdWUCt7IDC98OFov+WzrEbGAh6tfNTJ59mYxdfauH0HotkDrbEMBqNWCi+wsYzSyznCh5m9zrrrLPsjHM22PhkpuTBA70xG+ip3xmiUzNVoP/Srn2uzeenr7/aTlpzvCvXef9Hb7SB/l4nDjzw0C/s//z4f+iYiWen1sa4EIAABCCwcAkgBCzcvWPmEIAABCoIqAZZl7ICdALptQmUaeCmP/tPdt8tny/6Aux45jn7xI03m8wFO3VNTGXswPBUzeHnuxjwm/Ee+8nQ0ibxhOz1T3zPTtn+b9O1/2XBvgL7olGgTwQoegUU759BF4EAHQT8BnqeEKDFhrJjFkm9bCErbQXZCEQudqSlw83yajRqsJ+rPeDaU9dZOuNzBTSzuRABFOjrtN/fplMeAirXUWtPzzdAf9e1bfsOV87jz+IJtmruggAEIAABCLRGACGgNW58CgIQgMC8JKCAf2Cgz3UHuOFz37T7b/1CcZ433Hiz+7vnC6CTyqs+8lm7/7Z/6OhahsfTNjyWqisG3HP3rZaaqi0YdHSCDQY/kIravQdX2FSu8Ql3Xzhtp/eP2PEvbLfMnbdaTNkBVToCVMsCKDUM7KwAoCX7RQD92y8EOCS5jIUz+wNlB+Qsaun4cWah2l0XOrmHEgFOPGVdSXtAPW8uRAA91+sU8PD3/rG4bAX/uiQO6PfO7xswW7+LndwDxoYABCAAgYVFACFgYe0Xs4UABCAQmIB8AcqzAtRKUO7lbz7/DPv2Hfc7B/P3Xrkp8Jit3thIDDi4f6/dc/dtrQ7f8c+NZCK25cByG83WTi9/TXLMiQD9kYxldu+y0c3/ZMlQyOIhrytAPriXT11J28BCx4Dp7+fva+lqMgvA/4wKIaDww1Bm2CJptTqsXuKRDfdbJrpqTkSARCJhmy662JatOq5CBJAfgHwB5uryTv3fd+Um5wMgIU7mnTILVEmA191D81Mmj0p3rv/we+ZqujwXAhCAAAS6jABCQJdtOMuFAAS6h4BSkz/71e+UZAW42uSPfNZlDfhFAj8Vz+is3aRUIqBSgVrXb5/ZYQ8/9EC7H9u28WqJAcoCOGdgyI5LlprsHf7c/+4C+t5QyBJODAg5L4CS9P9CtkD+e/koviURoM6H/D8qzwIIIgS4e3IpVyoQzo2X8MxEVlg2ekTbGDczkDoDvO3tl1q0Z2lJdoMYqzNAPNo4g6OZ57Vyr4J+ter0gnz5Abz13R+zzV/+uL3iqDw3L3vgK5/5CB4BrUDmMxCAAAQg0BIBhICWsPEhCEAAAguDwFv/9G+KJ49qaaYvnUaqdOBD739nSQ2zygpUTqA/FbjIvMxrN9iO1WazOdt7eLLi5NY/9hPbt9mT27e143EdGUNiwIOHltnBdD4F/rS+YZcFUO0a/ud/suyeXaYzaYkByg6I+LMB2tEVoIFqUCICuIC+NpZaGQH+T4TTB1y5gOQMZQHkIi12Vpjh7kgEePtlV1oo1lsyUiScbw8YFeh5eqm7R3nJgNducJ5OmWlBAAIQgMAiJIAQsAg3lSVBAAIQ8AgoHVmnkgrof7Z9h0tNfsXRK12wX24eqFICpSur9aACE9Uxe/e3i6jaCu4fmqopBiiQe/qXP7dHHv6f7Xpk28eZyoZs2/BgsQyg1gNGvrPZMi/83v1YBQV9oZD1hENOGMh3CZiBD4AGbTbWrSMCaLggQoBbTK7g9zBHfgCrVq2yCzddYuF4XxURIG7RyNxnAtR76fQ7KTFOJTl3b9lqd27ZWtLho+0vLANCAAIQgAAEqhBACOC1gAAEILBICSgbQB4AH/zLy1zkWW4MWN4+8OQLrrKnfrS5SOPrm+9yLuftdjKXGKDMAH+augQA1XP3JCKu9dsPf/hDe/pXjy/onRl/YItN/fzh4hoSygwIh6wnFLJoaAY+AB0QAZoSAuZwV1avXm0bLnirpbKltf/5TID5LwIInUoB1EHgru8/ZOeccZL7/ZQ4xwUBCEAAAhCYTQIIAbNJm2dBAAIQmEUC5bX+SknWCb/8AXQqqfR/tRlUCzNlDEg4uOnvr3F9zr1L3/vMdR9wAUs7L39bQQX/6vE+lcqaTAUz2fzR9Y9/8D178fnftfOxszrW5M8ftokHtpQ8syeUzwzIlwk0e6TfXBaARnckG2QCeBMMnBEwqxSnH3be+vV28mln28RUqWlhLBKyZQMLQwSYI3Q8FgIQgAAEIFBBACGAlwICEOhaAgqG1b/bf13+x28scfNeTHC89oHHHHVE0bxM35Np2QevutyJA7q89oL6u9fyrBNu5mOTGYtHQxUCQDEwnZy0H3z/Tjt4YN+C3Ib087+30VumMyy0CM88sC8sMWDaIDDQApvQDZRvkJMCEFAE0PMn1v+RTbxhY6CpzPZNag+49tR1ls6ULigZj9jSvpiF1XKBCwIQgAAEIACBwAQQAgKj4kYIQGAxElBbL12qib/2Eze5NHjVyC/GS74Aah/ogv3CKf/X/vlO+9njT9tNn77WhodHK3wDJJR8/V/ubnt5gMdXYsChkamauEeGh5wYMDoyvOC2JHv4kA3/05cq5q0K9v5QyCQGxIKWCASIc6dvyecC1OsQUA3mfBUCLrjgTXbc6tdX+Er0JaO2pC9v2sgFAQhAAAIQgEBzBBACmuPF3RCAwCIkoBT6K67+O7tww5mLvo+36v7Vs1yn/hI/VKvsNxK8+96fOPMylRAM9vfOSn/zRm0FF7IYMPSlGy03OVHxWyPzwAH5IhT8Amr+WjXRFSA/RmsigD4534SARCJhb9l0kb12zRqXNXJ4LFUUNwZ6Y66chAsCEIAABCAAgdYIIAS0xo1PQQACi4jADZ/9hv1h9/6OnXrPN1RK95dR2fDouCuDUFmABAH1O1dGhPfztQWvAGULSBTo1BWkraDEgHvuvtVSU7WzBzo1v5mM67UQrDZG3CcGhMv9AuoIAFV/pM8XDAGazQTw5jafhAC1B7z8373LspEeF/z390StPxm1Q6NTFovmPSW4IAABCEAAAhBonQBCQOvs+CQEILAICCgAVnp8u9vkzVc0noGgTADfd+Um18JMl9dO0OsaoPv0Pb9xYCfX1KitoJ59cP9eu+fu2zo5jbaPPfrfbrX0MztrjiufgMFw2JkHFgP8GiJATW2gDSKAJjhfhACJAG+/7EoLxUrFp2gkZCuXJF3rRS4IQAACEIAABGZGACFgZvz4NAQgsIAJKDVeLfW+8pmPLFpfAG97FNSr1l+eCDr1v2vLVpf2v/mL17nWZeoe4P79pevnbEertRUsn8xvn9lhDz/0wJzNsdkHT2z9kU3+5Ed1P6YuAoPhkMX9Ea4v2K0rAGjkGWYCeJObD0LAqlWr7MJNl1g43lfCTGiW9cdN5oBcEIAABCAAAQjMnABCwMwZMgIEILBACVzxgU+6mb/3XRcVV6AUeNXML5bLEwC2PbbTPv7h95SsTdkQ3/ruvU4IcAJBQRSot3ZlCqjtoDoNdOLytxWsNf4T27fZk9u3deLxbR8z9eudNnbnrXXHVaAvv4ABv19AIfqvffjt/0nzxoDVJjR62Z9YavWJbWcQdMCTTz7Z1p33RkvnStP+I+GQrRhMmDICuCAAAQhAAAIQaA8BhID2cGQUCEBgARKQcV75pdZ6l1+8YQGupnLKO555zq795FecD0CtNTUT2IuXjARVtC3xoFPeAcPjaRseS9Xdg4UiBmR277KRzf/U8H1SJ4El4ZDrJlD0C3DtBatd7RUBsoNLbORP3mfZJUsbzrNTN5y3fr2dfNrZNjGVLXlEXgSIWzQiQlwQgAAEIAABCLSLAEJAu0gyDgQgAIFFTODBrY/at26/r1g6IFFAWQadKiWoJwYkY2Eb7IvZAw88aE88vn3eUx/5zmbLvPD7hvNUI7wlkXwngaJjQIkYUCoL5FQT4MoCKi/3M9dDoP4pevqVx9vYxZfOqQhw7rnr7cRT1lW0B0QEaPjKcAMEIAABCECgZQIIAS2j44MQgAAEFjcBZRR4ZoFe9oQ6DHiXDAfrZRvMlE55W0EFhnKLj8fCJqFgfDJjP/7B9+zF538300d1/PNj37vLUr98rOFzekIhW+r3C6gqBBTC/BoiQCaXs7SZxUJm4TpCwHzwBPijjRfYq9acUiECxCIhWzZAJkDDF4YbIAABCEAAAi0SQAhoERwfgwAEILCYCcg88KqP3mj33fJ5VwZQnhGgtXuGg/ffwLx1OwAAFzBJREFU+oWOoPC3FZQA0JeM2uhE2okA3jU1OWk/+P6ddvDAvo7MoZ2DBjEO1Pm9jAMH65QIFLwBK6am76dzOUvlzFROX0sIUCnA+JsumlM/gEQiYVdc+ScW7VlaIQJI6FneH7dwGE+Adr5/jAUBCEAAAhDwE0AI4H2AAAQgAIEKAgryv/3de23tmuPtM9df7X6uDIByQ0F97zPXfcDOOeOkjlBUJwGlt0+mMk4AyGQrj8FHhoecGDA6MtyRObRz0CBigHzxl7sSgUJdvC8rIFenFGAqZzaZyzleEgHiVTIC5oMfgNoDXvrOd5lFe2Q3UXL1JiK2tD/eTuSMBQEIQAACEIBAFQIIAbwWEIAABCBQQUClADJOVEtBL9DX93b85nn7yqevLd5/zSdusndevKGjnRYULO46OF4RNPonvZDEgMmfPWwTD26p+9YlLeTEgJjXUlCH43VKASZzZlO5nGUKJQFqRVguBEyeeY5NrN9ouWRyzt54iQCXXHpFRXtATWigN+ZKP7ggAAEIQAACEOg8AYSAzjPmCRCAAAQWHIEbv/ode/P5Z9hLu/c7McBL/1cGgAJ/eQWoNeEVV/+d3XfbP5jaLnbyGpvM2KGRqbqPkBhwz923Wmqq/n2dnGfQsaeefMwmHthiucmJmh9ReYA6CRS7CJTdKacAlQEoC0DZAOmcWbRQElAuBIxfsMkmzzo36PQ6ct+qVavsbZdeYamsch5KL0SAjiBnUAhAAAIQgEBNAggBvBwQgAAEIFBB4IYbb3bB/iuOOsKVBKxdfZydc/paJw7c8LlvOhFAwf+H3v/OjmYD+CcWpK3gwf177Z67b1sQO6rWgqO3bK4pBqgwYHk4ZH3hytZ52VzOJs1sMpsXAVKWc8aAXkmAhAD9PTS4dM5bA2ozTj75ZDt/41tsdEI5C/NHBNB7rPd55zPP5ctgrvuA88TgggAEIAABCCx2AggBi32HWR8EIACBFgjIKPCmT1/rTAKVEfDSrn1F40ANNzQy1vEsgGrTPjyacoaB9a7fPrPDHn7ogRZWPfsfyR4+ZGovmBs6VPXhqpZfEQmbAntdniHgRCELQOUAqcInYz4hQCUFkWOPt4mLL5vT1oCa2nnr19tJp5xtU+lsxRrlByBfgLm49A4ro8XLcJFB5t/eeHMx+2Uu5sQzIQABCEAAArNFACFgtkjzHAhAAALzmIBORu++9ye27bGdtvlL19tb3/0xs1DIZQF88C8vs6//y9129ulr7fKLN8z5KvYNTdpUqjKo9E/sie3b7Mnt2+Z8rkEm0EgM6AuFbFlYFoDmsgAmXBZAXgDwztf1M1XXe5kA4fMvsMwbNgZ5fEfvOffc9XbiKesqTB6la6xckrSo2hvM0SWxS5ku13/4PcUZ6Ht63ztlfjlHS+WxEIAABCAAgQoCCAG8FBCAAAS6nIA6BOhLQb4X6OvfZ592YjFNWkKBrvmQNu1vK1hv6xaSGJCbmLCRWzZbds+uyv9HbWZLC0LARMEUsFpORExCwJKlFr3wYrM1a+f8rd606WI76rg1FSJAJByyFYNxi0YqSx5ma9Le6f8dN3+qmNmiDIFN7/7YrHhezNY6eQ4EIAABCECgFgGEAN4NCEAAAl1KQMGQ0v7dqf9Vly8oCmoruH9oqmo7Qf9Cfv7wQ/b0rx5fEGuTGDB2z12WfmZnxXyVPK8vZQHUaB7gRIDkn11loSVL53S96gzw9ndcZtGepfNSBBAcdbvQe//eKzcVWel3Qb8TyojhggAEIAABCCx2AggBi32HWR8EIACBGgTafcqv8VRC8MBDv5gV47V0Jmd7D0/UbSuopW998B577vfPLpj3YOKeu2zyyceamm983XnWo0yAOb4kAlz6zndZLtJTKWbMg0wAb1Iqfdn85Y87M0xdO3/zvF31kc+aMgTmQ9bLHG8jj4cABCAAgS4ggBDQBZvMEiEAAQjMBgF1F3jflZvcKevXN99ld27Zapu/eF1HA6t6bQW9FPSR0XG7/bv/1Q4e2DcbGNryjImtP7LJn/wo0FjJN19sibPPC3RvJ2+SCHDJpVdYON5X8ZhYROUACQuH584TwD8peQFcuOFM965KwLrqrz9XfHc7yYixIQABCEAAAvOFAELAfNkJ5gEBCEBgARP4w6597kT1/tv+obgKtSDU9Znrr+7oyqq1FRzoiVpPImL62fhkxkaGh+wH37/TRkeGOzqXdg7eSAxQa8D+91xl4TkuBdCajz32WLvwordbKlvZAUBdAdQdYD5dXtvA4eFR2/nbF+xDV12+4Mpj5hNP5gIBCEAAAguPAELAwtszZgwBCEBgzgkokPr2Hfe7ebz3irfawECfrX/7B0taDHonrZ3OCtAcvLaCXhaAugpIBMhkpyvqF6IYMPXkYzZ+z10V+x1ds9aVAswHEeCss86yM87Z4ASX8mugN2YSZebrJQFroL93TlphzlcmzAsCEIAABLqDAEJAd+wzq4QABCDQNgLfvv0++9bt97lUatVWq+Xg/bd+wZQBoJprv/HgjV/9jg2PjHU8K0CLm5jKuHZ0XhZAtQVLDLjn7lstNTXVNh6dHij9/O9t7L/darnJCfeoxPkXWHLDBZ1+bKDxz1u/3k58/TqTX8NCEwECLZCbIAABCEAAAouUAELAIt1YlgUBCECgUwRktHbTp6+1k9Yc7x5x3iV/ZXd88+9Nrn2qtfZnAGzbvsMZCM6WE/u+w5M2lc7WXfrB/Xvtnrtv6xSejoyb2b3Lxh/YYol151nstXPfGlCLPPfc9XbiKeuqdm6Y75kAHdkkBoUABCAAAQgsIAIIAQtos5gqBCAAgflA4IoPfNJu+sxHio7rEgYkBAz29zqTwAe2Pmr//KXr3b/v2rLV/fsrn752VqYetK3gb5/ZYQ8/9MCszGmxPSSRSNimiy62ZauOqyoCyA9AvgBcEIAABCAAAQjMXwIIAfN3b5gZBCAAgXlPYGhkzK79xE0lJ/4qB7jr+w/Z2Wec5EoHZsMjwA8qaFvBJ7Zvsye3b5v3jOfTBNUZ4O3vuMzCySUVbRtDIXOdAeLR8HyaMnOBAAQgAAEIQKAKAYQAXgsIQAACEPj/27u70DzLMw7gVz7exCaNtptmOsHtIEKl6Oycn+u04KyFuqVoPZggFuZOqmM7W1XYgVDWwQ4m+zhZBwXBKduwHvixrRWkHhTn5tzAD5wIomystXVN65pkreN+JVlM0yapvZPrTX4PlELf572e6/ld6cH7z/Pe92kLlPUCShhQVl2feJRF2MpigVetuuS0a3+SN55qW8GJdYUBM1duhgCDG6Ot0XPCmz5apLG7uUaDgwABAgQIEMgvIAjIPyMdEiBAII1A+XBfjgsvOK/5d9mPvYQAV16+ovk1gPJnrtYDmA5lqm0Fp3rPi3v3xOuvvDxduUX9egkB1n/9tmjv6j1JCNAVnR2eBFjUPyRungABAgRaSkAQ0FLj0iwBAgTmR6AEAGXRv3JsvmtwPAhoLhS4/cHma+XJgC33fGP8tfnp9ONXHdtWcLpentv1ZLzz9lvTnbYoXx8YGIjVa26K0eMnfu9/bLtGIcCi/NFw0wQIECDQwgKCgBYentYJECAwFwLl8f+y4N+GdaubfyYeK9dsis+ef27zqYDJr81FbzO5xv5DwzEyeuqdBEaGh2PX04/HwQP7Z1Jy0ZyzatUX44prvhL/GT52wj03OtpieZ8nARbND4MbJUCAAIEFJSAIWFDjdDMECBA4swLl0f+rLl8RmyetATB2lbI94HytAzDTO53pTgKHhw41w4Ajh4dmWnpBn3fdddfFxSuviLL44uTjrK6OWNbbiPZ2awIs6B8CN0eAAAECC1ZAELBgR+vGCBAgQGBMYDZhwLPP7IyhoUOLGu/6G9bEpZetikZnexwYGvnYNoG9Z3XGOb2NRe3j5gkQIECAQKsLCAJafYL6J0CAAIEZCZTfbP/r/aMnPbd83/1TfV1x4OD78atHHo7RkZEZ1V1IJ3V3d8fam9fF8v6Lmh/++5Z0xpLujnjv0EdhQF9Po/lvDgIECBAgQKC1BQQBrT0/3RMgQIDALASOHP1vlAUEJx/lw+45PY0or5fdBg6+ty+eeuKxWVRu/VPLzgAbbr09jncsiQ8nfBug2PQtacTw6LEoTwM4CBAgQIAAgdYXEAS0/gzdAQECBAjMQmDitoLlKYDyG+6uRvv4b73HSr35xquxd8/uWVRu3VNLCHDL4MZoa/RMeRPlSYmyLoCDAAECBAgQWBgCgoCFMUd3QYAAgZQCO595PoYOfxCD61bH2Uun/pA5H42XpwLKb7jLB9yyIn4JB6Y6/vrSC/G3l16Yjxbn7Jr9/f1x49r10d7Ve8I129oili8VAszZMFyIAAECBAjMkYAgYI6gXYYAAQKLSeDQ4Q/itru/39xxoG9pT+ze86f4ydbvxIqBi9IwHDv2Yfz7g5E4OnLqrQUXchgwMDAQq9fcFKPHT/xtf3la4tNnd0dnh50B0vzQaoQAAQIECJwhAUHAGYJUhgABAgT+L/Dwb34f7/5zf2y5947mP/7xL6/Ftx94KH67/cG48ILzUlDNdCeB0uyLe/fE66+8nKLvM9XENddeGyu/cOWUQchHIUBXdHa0n6nLqUOAAAECBAgkEhAEJBqGVggQILBQBCYHAeW+fr5jZzMc2Lrl7jS3OZsw4LldT8Y7b7+VpvdP0sjVV18bKy77UpSdFCYfQoBPIuu9BAgQIECgNQQEAa0xJ10SIECgpQReeOnVeOCHv4w/PPqj8b7Hvi6w9XvfjKtWXZLmfqbbVnCs0ZHh4dj19ONx8MD+NL2fTiPX37AmPn/xpc3tACcfjY62WN7nSYDTcfUeAgQIECDQSgKCgFaall4JECDQQgIPbNseF55/bmzetGG86/JUQDkm/luGWzrZtoKTezs8dKgZBhw5PJSh7Vn10N3dHeu/NhhLl31myhCg7AqwrLcR7e3WBJgVbOWTxxbcvHPj2uaVSqBW/m9tvmswLrn4c5WvrjwBAgQILFQBQcBCnaz7IkCAwDwLvPuPfc0FA3c8dN/4IoFTfWVgntscv/zEbQVP1VMrhgFle8ANt94exzuWxIcnPggQPd0dsWxpV5ZR6GOCwNiTNPds2hAb1q2On+3Y2VxzY8ePt3AiQIAAAQKnLSAIOG06byRAgACB6QSeff7Pcf8PfhF33n5z8+mAh3/9u9h637dS7R4w8R7KtoLl6YDpjhIGPPXEozE6MjLdqfP+egkBbhncGG2Nqbdv7OtpRN+SznnvUwMnFygf/O/ftj3uu/eO5v+nTItumhsBAgQItKaAIKA156ZrAgQItIzAa39/O8rjzeW487ab0uwacDLAfe8fjdEpFtGbfP7B9/bFU088lnoO/f39cePa9dHe1Ttln0KA1OP7WHPbfvpIlCdqyk4cY18TaJ3udUqAAAEC2QQEAdkmoh8CBAgQmFeB2ewk8OYbr8bePbvntd+TXXzlypXx5Ru+GkeOHpvylHN6G9F7licBUg5viqY2fXdb8ysBJQQY25azVXrXJwECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BMQBOSbiY4IECBAgAABAgQIECBAgEA1AUFANVqFCRAgQIAAAQIECBAgQIBAPgFBQL6Z6IgAAQIECBAgQIAAAQIECFQTEARUo1WYAAECBAgQIECAAAECBAjkExAE5JuJjggQIECAAAECBAgQIECAQDUBQUA1WoUJECBAgAABAgQIECBAgEA+AUFAvpnoiAABAgQIECBAgAABAgQIVBMQBFSjVZgAAQIECBAgQIAAAQIECOQTEATkm4mOCBAgQIAAAQIECBAgQIBANQFBQDVahQkQIECAAAECBAgQIECAQD4BQUC+meiIAAECBAgQIECAAAECBAhUExAEVKNVmAABAgQIECBAgAABAgQI5BP4HyDogZmp4qkVAAAAAElFTkSuQmCC", "text/html": [ - "