From dbb06d2d2354e351c8b2e2855f8bcb9f872c6923 Mon Sep 17 00:00:00 2001
From: Andreas Dutzler
Date: Wed, 15 Feb 2023 23:24:55 +0100
Subject: [PATCH] Add new method `Tensor.dual2real()`

---
 src/tensortrax/_tensor.py           | 13 +++++++++++++
 src/tensortrax/math/__init__.py     |  1 +
 src/tensortrax/math/_math_tensor.py |  1 +
 tests/test_dual2real.py             | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 51 insertions(+)
 create mode 100644 tests/test_dual2real.py

diff --git a/src/tensortrax/_tensor.py b/src/tensortrax/_tensor.py
index 05c1dc8..179de4c 100644
--- a/src/tensortrax/_tensor.py
+++ b/src/tensortrax/_tensor.py
@@ -309,6 +309,9 @@ def reshape(self, *shape, order="C"):
     def squeeze(self, axis=None):
         return squeeze(self, axis=axis)
 
+    def dual2real(self, like):
+        return dual2real(self, like=like)
+
     __radd__ = __add__
     __rmul__ = __mul__
     __array_ufunc__ = None
@@ -331,6 +334,16 @@ def broadcast_to(A, shape):
     return _broadcast_to(A)
 
 
+def dual2real(A, like=None):
+    """Return a new Tensor with old-dual data as new-real values,
+    with `ntrax` derived by `like`."""
+
+    ndual = like.ndual - len(like.shape)
+    ntrax = A.ntrax - ndual
+
+    return Tensor(x=A.δx, δx=A.Δδx, Δx=A.Δδx, ndual=min(ndual, ntrax), ntrax=ntrax)
+
+
 def ravel(A, order="C"):
     if isinstance(A, Tensor):
         δtrax = δ(A).shape[len(A.shape) :]
diff --git a/src/tensortrax/math/__init__.py b/src/tensortrax/math/__init__.py
index 4f382f2..3e506d7 100644
--- a/src/tensortrax/math/__init__.py
+++ b/src/tensortrax/math/__init__.py
@@ -19,6 +19,7 @@
     cosh,
     diagonal,
     dot,
+    dual2real,
     einsum,
     exp,
     hstack,
diff --git a/src/tensortrax/math/_math_tensor.py b/src/tensortrax/math/_math_tensor.py
index 3f81fdd..8b95c75 100644
--- a/src/tensortrax/math/_math_tensor.py
+++ b/src/tensortrax/math/_math_tensor.py
@@ -15,6 +15,7 @@
     Δ,
     Δδ,
     broadcast_to,
+    dual2real,
     einsum,
     f,
     matmul,
diff --git a/tests/test_dual2real.py b/tests/test_dual2real.py
new file mode 100644
index 0000000..8b0ccdd
--- /dev/null
+++ b/tests/test_dual2real.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+import tensortrax as tr
+import tensortrax.math as tm
+
+
+def test_dual2real():
+    np.random.seed(34563)
+    x = (np.random.rand(3, 3) - 0.5) / 10 + np.eye(3)
+
+    # init a Tensor with `hessian=True`
+    F = tr.Tensor(x)
+    F.init(hessian=True)
+
+    # perform some math operations
+    C = F.T() @ F
+    J = tm.linalg.det(F)
+    W = tm.trace(J**(-2 / 3) * C) - 3
+    eta = 1 - 1 / 3 * tm.tanh(W / 8)
+
+    # set old dual data as new real values (i.e. obtain the gradient)
+    P = W.dual2real(like=F)
+    tm.dual2real(W, like=F)
+
+    # perform some more math with a derived Tensor involved
+    Q = eta * P
+
+    # take the gradient
+    A = tr.δ(Q)
+
+    assert P.shape == (3, 3)
+    assert Q.shape == (3, 3)
+    assert A.shape == (3, 3, 3, 3)
+
+if __name__ == "__main__":
+    test_dual2real()
\ No newline at end of file
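
Note (editor's sketch, not part of the patch): the new test above already exercises the intended workflow. For quick reference, here is a condensed, standalone usage sketch of `dual2real`, using only API calls that appear in this patch (`tr.Tensor`, `F.init(hessian=True)`, `tm.linalg.det`, `tm.trace`, `tm.dual2real`, `tr.δ`); the expected shapes follow from the assertions in `tests/test_dual2real.py`.

    import numpy as np

    import tensortrax as tr
    import tensortrax.math as tm

    # deformation-gradient-like input close to the identity
    x = (np.random.rand(3, 3) - 0.5) / 10 + np.eye(3)

    # enable second derivatives (hessian mode)
    F = tr.Tensor(x)
    F.init(hessian=True)

    # scalar (isochoric neo-Hookean-like) function of F, as in the test
    C = F.T() @ F
    J = tm.linalg.det(F)
    W = tm.trace(J ** (-2 / 3) * C) - 3

    # promote the dual data of W to real values: P holds the gradient
    # dW/dF as its real values, while its own dual data carries the
    # second derivatives of W
    P = tm.dual2real(W, like=F)  # equivalent to W.dual2real(like=F)

    # taking the gradient of P therefore yields the hessian of W
    A = tr.δ(P)

    print(P.shape)  # (3, 3), as asserted in the test
    print(A.shape)  # expected (3, 3, 3, 3); the test asserts this shape
                    # for tr.δ(eta * P) with a scalar-valued eta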