From 095cebceabe0329a1424ac451f70cf8c60b7223b Mon Sep 17 00:00:00 2001 From: Andreas Dutzler Date: Sat, 10 Dec 2022 19:19:59 +0100 Subject: [PATCH 1/2] add diagonal and ravel --- tensortrax/math/__init__.py | 2 ++ tensortrax/math/_math_tensor.py | 30 ++++++++++++++++++++++++++++++ tests/test_math.py | 2 ++ 3 files changed, 34 insertions(+) diff --git a/tensortrax/math/__init__.py b/tensortrax/math/__init__.py index 44dd563..5e8e01d 100644 --- a/tensortrax/math/__init__.py +++ b/tensortrax/math/__init__.py @@ -27,5 +27,7 @@ sqrt, einsum, matmul, + diagonal, + ravel, ) from . import _math_array as array diff --git a/tensortrax/math/_math_tensor.py b/tensortrax/math/_math_tensor.py index 22752ce..63a5f16 100644 --- a/tensortrax/math/_math_tensor.py +++ b/tensortrax/math/_math_tensor.py @@ -170,3 +170,33 @@ def log10(A): ) else: return np.log10(A) + + +def diagonal(A, offset=0, axis1=0, axis2=1): + kwargs = dict(offset=offset, axis1=axis1, axis2=axis2) + if isinstance(A, Tensor): + return Tensor( + x=np.diagonal(f(A), **kwargs).T, + δx=np.diagonal(δ(A), **kwargs).T, + Δx=np.diagonal(Δ(A), **kwargs).T, + Δδx=np.diagonal(Δδ(A), **kwargs).T, + ntrax=A.ntrax, + ) + else: + return np.diagonal(A, **kwargs).T + + +def ravel(A, order="C"): + if isinstance(A, Tensor): + δtrax = δ(A).shape[len(A.shape) :] + Δtrax = Δ(A).shape[len(A.shape) :] + Δδtrax = Δδ(A).shape[len(A.shape) :] + return Tensor( + x=f(A).reshape(A.size, *A.trax, order=order), + δx=δ(A).reshape(A.size, *δtrax, order=order), + Δx=Δ(A).reshape(A.size, *Δtrax, order=order), + Δδx=Δδ(A).reshape(A.size, *Δδtrax, order=order), + ntrax=A.ntrax, + ) + else: + return np.ravel(A, order=order) diff --git a/tests/test_math.py b/tests/test_math.py index 90265f4..cf7d617 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -58,6 +58,8 @@ def test_math(): tm.exp, tm.log, tm.log10, + tm.diagonal, + tm.ravel, ]: assert np.allclose(fun(F), fun(T).x) From b5867c2f4d91e7e491224a33cde6cefd06cd5085 Mon Sep 17 00:00:00 2001 From: Andreas Dutzler Date: Sat, 10 Dec 2022 22:05:28 +0100 Subject: [PATCH 2/2] add `Tensor.ravel()` as method also allow setting of items --- tensortrax/_tensor.py | 41 +++++++++++++++++++++++++++++---- tensortrax/math/_math_tensor.py | 18 +-------------- tests/test_math.py | 5 ++++ 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/tensortrax/_tensor.py b/tensortrax/_tensor.py index bec95c7..47d0011 100644 --- a/tensortrax/_tensor.py +++ b/tensortrax/_tensor.py @@ -149,19 +149,34 @@ def __pow__(self, p): def T(self): return transpose(self) + def ravel(self, order="C"): + return ravel(self, order=order) + def __matmul__(self, B): return matmul(self, B) def __rmatmul__(self, B): return matmul(B, self) - def __getitem__(self, items): - x = f(self)[items] - Δx = Δ(self)[items] - δx = δ(self)[items] - Δδx = Δδ(self)[items] + def __getitem__(self, key): + x = f(self)[key] + Δx = Δ(self)[key] + δx = δ(self)[key] + Δδx = Δδ(self)[key] return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=self.ntrax) + def __setitem__(self, key, value): + if isinstance(value, Tensor): + self.x[key] = f(value) + self.δx[key] = δ(value) + self.Δx[key] = Δ(value) + self.Δδx[key] = Δδ(value) + else: + self.x[key] = value + self.δx[key].fill(0) + self.Δx[key].fill(0) + self.Δδx[key].fill(0) + def __repr__(self): header = "" metadata = [ @@ -176,6 +191,22 @@ def __repr__(self): __array_ufunc__ = None +def ravel(A, order="C"): + if isinstance(A, Tensor): + δtrax = δ(A).shape[len(A.shape) :] + Δtrax = Δ(A).shape[len(A.shape) :] + Δδtrax = Δδ(A).shape[len(A.shape) :] + return Tensor( + x=f(A).reshape(A.size, *A.trax, order=order), + δx=δ(A).reshape(A.size, *δtrax, order=order), + Δx=Δ(A).reshape(A.size, *Δtrax, order=order), + Δδx=Δδ(A).reshape(A.size, *Δδtrax, order=order), + ntrax=A.ntrax, + ) + else: + return np.ravel(A, order=order) + + def einsum3(subscripts, *operands): "Einsum with three operands." A, B, C = operands diff --git a/tensortrax/math/_math_tensor.py b/tensortrax/math/_math_tensor.py index 63a5f16..b28c682 100644 --- a/tensortrax/math/_math_tensor.py +++ b/tensortrax/math/_math_tensor.py @@ -10,7 +10,7 @@ import numpy as np -from .._tensor import Tensor, einsum, matmul, f, δ, Δ, Δδ +from .._tensor import Tensor, ravel, einsum, matmul, f, δ, Δ, Δδ from ._linalg import _linalg_array as array @@ -184,19 +184,3 @@ def diagonal(A, offset=0, axis1=0, axis2=1): ) else: return np.diagonal(A, **kwargs).T - - -def ravel(A, order="C"): - if isinstance(A, Tensor): - δtrax = δ(A).shape[len(A.shape) :] - Δtrax = Δ(A).shape[len(A.shape) :] - Δδtrax = Δδ(A).shape[len(A.shape) :] - return Tensor( - x=f(A).reshape(A.size, *A.trax, order=order), - δx=δ(A).reshape(A.size, *δtrax, order=order), - Δx=Δ(A).reshape(A.size, *Δtrax, order=order), - Δδx=Δδ(A).reshape(A.size, *Δδtrax, order=order), - ntrax=A.ntrax, - ) - else: - return np.ravel(A, order=order) diff --git a/tests/test_math.py b/tests/test_math.py index cf7d617..c6ee31e 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -88,6 +88,11 @@ def test_math(): with pytest.raises(NotImplementedError): tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, T, T) + T.ravel() + T[0] = F[0] + T[:, 0] = F[:, 0] + T[:, 0] = T[:, 0] + if __name__ == "__main__": test_math()