Skip to content

Commit

Permalink
Merge pull request #76 from adtzlr/add-dual2real
Browse files Browse the repository at this point in the history
Add `Tensor.dual2real(like=None)`
  • Loading branch information
adtzlr authored Feb 15, 2023
2 parents cdb8746 + dbb06d2 commit 9f4a001
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ def reshape(self, *shape, order="C"):
def squeeze(self, axis=None):
return squeeze(self, axis=axis)

def dual2real(self, like):
return dual2real(self, like=like)

__radd__ = __add__
__rmul__ = __mul__
__array_ufunc__ = None
Expand All @@ -331,6 +334,16 @@ def broadcast_to(A, shape):
return _broadcast_to(A)


def dual2real(A, like=None):
"""Return a new Tensor with old-dual data as new-real values,
with `ntrax` derived by `like`."""

ndual = like.ndual - len(like.shape)
ntrax = A.ntrax - ndual

return Tensor(x=A.δx, δx=A.Δδx, Δx=A.Δδx, ndual=min(ndual, ntrax), ntrax=ntrax)


def ravel(A, order="C"):
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Expand Down
1 change: 1 addition & 0 deletions src/tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cosh,
diagonal,
dot,
dual2real,
einsum,
exp,
hstack,
Expand Down
1 change: 1 addition & 0 deletions src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Δ,
Δδ,
broadcast_to,
dual2real,
einsum,
f,
matmul,
Expand Down
36 changes: 36 additions & 0 deletions tests/test_dual2real.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

import tensortrax as tr
import tensortrax.math as tm


def test_dual2real():
np.random.seed(34563)
x = (np.random.rand(3, 3) - 0.5) / 10 + np.eye(3)

# init a Tensor with `hessian=True`
F = tr.Tensor(x)
F.init(hessian=True)

# perform some math operations
C = F.T() @ F
J = tm.linalg.det(F)
W = tm.trace(J**(-2 / 3) * C) - 3
eta = 1 - 1 / 3 * tm.tanh(W / 8)

# set old dual data as new real values (i.e. obtain the gradient)
P = W.dual2real(like=F)
tm.dual2real(W, like=F)

# perform some more math with a derived Tensor involved
Q = eta * P

# take the gradient
A = tr.δ(Q)

assert P.shape == (3, 3)
assert Q.shape == (3, 3)
assert A.shape == (3, 3, 3, 3)

if __name__ == "__main__":
test_dual2real()

0 comments on commit 9f4a001

Please sign in to comment.