Skip to content

Commit

Permalink
Merge pull request #91 from adtzlr/add-if-else
Browse files Browse the repository at this point in the history
Add `math.if_else(cond, true, false)`
  • Loading branch information
adtzlr authored May 24, 2023
2 parents ac6fde2 + 7c922b5 commit 482ae85
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

__version__ = "0.14.0"
__version__ = "0.15.0"
2 changes: 2 additions & 0 deletions src/tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
exp,
external,
hstack,
if_else,
log,
log10,
matmul,
Expand Down Expand Up @@ -78,4 +79,5 @@
"trace",
"transpose",
"vstack",
"if_else",
]
51 changes: 50 additions & 1 deletion src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from .._tensor import Tensor, Δ, Δδ, einsum, f, matmul, δ
from .._tensor import Tensor, Δ, Δδ, broadcast_to, einsum, f, matmul, δ

dot = matmul

Expand Down Expand Up @@ -370,3 +370,52 @@ def hvp(h, v, u, ntrax):
)
else:
return function(x, *args, **kwargs)


def if_else(cond, true, false):
"Mask-based Condition for arrays and tensors."

mask = np.asarray(cond)
out = true.copy()

if isinstance(true, np.ndarray) and isinstance(false, np.ndarray):
out = true.copy()
out[..., mask] = true[..., mask]
out[..., ~mask] = false[..., ~mask]

elif isinstance(true, Tensor) and isinstance(false, Tensor):
shape = np.maximum.reduce(
[
true.x.shape,
true.δx.shape,
true.Δx.shape,
true.Δδx.shape,
false.x.shape,
false.δx.shape,
false.Δx.shape,
false.Δδx.shape,
]
)

out = broadcast_to(true, shape=shape).copy()

mask = np.broadcast_to(mask, shape)
true = broadcast_to(true, shape=shape)
false = broadcast_to(false, shape=shape)

out.x[..., mask] = true.x[..., mask]
out.δx[..., mask] = true.δx[..., mask]
out.Δx[..., mask] = true.Δx[..., mask]
out.Δδx[..., mask] = true.Δδx[..., mask]

out.x[..., ~mask] = false.x[..., ~mask]
out.δx[..., ~mask] = false.δx[..., ~mask]
out.Δx[..., ~mask] = false.Δx[..., ~mask]
out.Δδx[..., ~mask] = false.Δδx[..., ~mask]

else:
raise NotImplementedError(
"`true` and `false` must be both arrays or both tensors."
)

return out
20 changes: 20 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,25 @@ def test_logical():
assert np.all(B < A)


def test_condition():
F = np.tile((np.eye(3) + np.arange(-2, 7).reshape(3, 3) / 10).reshape(3, 3, 1), 10)
T = tr.Tensor(F, ntrax=1)

G = np.tile((np.eye(3) - np.arange(-7, 2).reshape(3, 3) / 10).reshape(3, 3, 1), 10)
V = tr.Tensor(G, ntrax=1)

Y = tm.if_else(F >= G, 2 * F, G / 2)
Z = tm.if_else(T >= V, 2 * T, V / 2)

np.allclose(Y, Z.x)

with pytest.raises(NotImplementedError):
tm.if_else(F >= T, 2 * F, V / 2)

with pytest.raises(NotImplementedError):
tm.if_else(T >= G, 2 * T, G / 2)


if __name__ == "__main__":
test_math()
test_einsum()
Expand All @@ -222,3 +241,4 @@ def test_logical():
test_eigh()
test_triu()
test_logical()
test_condition()

0 comments on commit 482ae85

Please sign in to comment.