Skip to content

Commit

Permalink
Merge pull request #17 from adtzlr/add-diagonal-ravel
Browse files Browse the repository at this point in the history
Add `math.ravel()` and `math.diagonal()`
  • Loading branch information
adtzlr authored Dec 10, 2022
2 parents 271d542 + b5867c2 commit 20a7aa2
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
41 changes: 36 additions & 5 deletions tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,34 @@ def __pow__(self, p):
def T(self):
return transpose(self)

def ravel(self, order="C"):
return ravel(self, order=order)

def __matmul__(self, B):
return matmul(self, B)

def __rmatmul__(self, B):
return matmul(B, self)

def __getitem__(self, items):
x = f(self)[items]
Δx = Δ(self)[items]
δx = δ(self)[items]
Δδx = Δδ(self)[items]
def __getitem__(self, key):
x = f(self)[key]
Δx = Δ(self)[key]
δx = δ(self)[key]
Δδx = Δδ(self)[key]
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=self.ntrax)

def __setitem__(self, key, value):
if isinstance(value, Tensor):
self.x[key] = f(value)
self.δx[key] = δ(value)
self.Δx[key] = Δ(value)
self.Δδx[key] = Δδ(value)
else:
self.x[key] = value
self.δx[key].fill(0)
self.Δx[key].fill(0)
self.Δδx[key].fill(0)

def __repr__(self):
header = "<tensortrax tensor object>"
metadata = [
Expand All @@ -176,6 +191,22 @@ def __repr__(self):
__array_ufunc__ = None


def ravel(A, order="C"):
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Δtrax = Δ(A).shape[len(A.shape) :]
Δδtrax = Δδ(A).shape[len(A.shape) :]
return Tensor(
x=f(A).reshape(A.size, *A.trax, order=order),
δx=δ(A).reshape(A.size, *δtrax, order=order),
Δx=Δ(A).reshape(A.size, *Δtrax, order=order),
Δδx=Δδ(A).reshape(A.size, *Δδtrax, order=order),
ntrax=A.ntrax,
)
else:
return np.ravel(A, order=order)


def einsum3(subscripts, *operands):
"Einsum with three operands."
A, B, C = operands
Expand Down
2 changes: 2 additions & 0 deletions tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@
sqrt,
einsum,
matmul,
diagonal,
ravel,
)
from . import _math_array as array
16 changes: 15 additions & 1 deletion tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from .._tensor import Tensor, einsum, matmul, f, δ, Δ, Δδ
from .._tensor import Tensor, ravel, einsum, matmul, f, δ, Δ, Δδ
from ._linalg import _linalg_array as array


Expand Down Expand Up @@ -170,3 +170,17 @@ def log10(A):
)
else:
return np.log10(A)


def diagonal(A, offset=0, axis1=0, axis2=1):
kwargs = dict(offset=offset, axis1=axis1, axis2=axis2)
if isinstance(A, Tensor):
return Tensor(
x=np.diagonal(f(A), **kwargs).T,
δx=np.diagonal(δ(A), **kwargs).T,
Δx=np.diagonal(Δ(A), **kwargs).T,
Δδx=np.diagonal(Δδ(A), **kwargs).T,
ntrax=A.ntrax,
)
else:
return np.diagonal(A, **kwargs).T
7 changes: 7 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def test_math():
tm.exp,
tm.log,
tm.log10,
tm.diagonal,
tm.ravel,
]:
assert np.allclose(fun(F), fun(T).x)

Expand Down Expand Up @@ -86,6 +88,11 @@ def test_math():
with pytest.raises(NotImplementedError):
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, T, T)

T.ravel()
T[0] = F[0]
T[:, 0] = F[:, 0]
T[:, 0] = T[:, 0]


if __name__ == "__main__":
test_math()

0 comments on commit 20a7aa2

Please sign in to comment.