Skip to content

Commit

Permalink
fix(lib): t was not correctly typed
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Dec 2, 2024
1 parent fa02262 commit 33b0e6d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
12 changes: 7 additions & 5 deletions differt/src/differt/em/_antenna.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def average_power(self) -> Float[Array, " "]: # TODO: provide default impl.

@abstractmethod
def fields(
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch 3"] | None = None
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None
) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]:
r"""
Compute electric and magnetic fields in vacuum at given position and (optional) time.
Expand All @@ -107,7 +107,7 @@ def fields(
def pointing_vector(
self,
r: Float[Array, "*#batch 3"],
t: Float[Array, "*#batch 3"] | None = None,
t: Float[Array, "*#batch"] | None = None,
) -> Inexact[Array, "*batch 3"]:
r"""
Compute the pointing vector in vacuum at given position and (optional) time.
Expand Down Expand Up @@ -378,7 +378,7 @@ def average_power(self) -> Float[Array, " "]:
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def fields(
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch 3"] | None = None
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None
) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]:
r_hat, r = normalize(r - self.center, keepdims=True)
p = self.moment
Expand All @@ -403,7 +403,9 @@ def fields(
)
b = (factor * k_k / c) * r_x_p * (1 - 1 / j_k_r) * r_inv

exp = jnp.exp(j_k_r - 1j * w * t) if t is not None else jnp.exp(j_k_r)
exp = (
jnp.exp(j_k_r - 1j * w * t[..., None]) if t is not None else jnp.exp(j_k_r)
)

e *= exp
b *= exp
Expand Down Expand Up @@ -456,7 +458,7 @@ class ShortDipole(Dipole):
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def fields(
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch 3"] | None = None
self, r: Float[Array, "*#batch 3"], t: Float[Array, "*#batch"] | None = None
) -> tuple[Inexact[Array, "*batch 3"], Inexact[Array, "*batch 3"]]:
raise NotImplementedError

Expand Down
3 changes: 2 additions & 1 deletion differt/src/differt/em/_fresnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,9 @@ def reflection_coefficients(
>>> reflected_vectors, s_r = normalize(
... rx_positions - reflection_points, keepdims=True
... )
>>> l = jnp.linalg.norm(rx_positions - tx_position, axis=-1)
>>> l = jnp.linalg.norm(rx_positions - tx_position, axis=-1, keepdims=True)
>>> tau = (s + s_r - l) / c # Delay between two paths
>>> tau = tau.squeeze(axis=-1)
We then compute the EM fields at those points, and use the Fresnel
reflection coefficients to compute the reflected fields.
Expand Down

0 comments on commit 33b0e6d

Please sign in to comment.