Skip to content

Commit

Permalink
Merge pull request #22 from adtzlr/add-reshape
Browse files Browse the repository at this point in the history
Add reshape
  • Loading branch information
adtzlr authored Dec 11, 2022
2 parents 7caffe9 + e071350 commit 9df24dc
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tensortrax"
version = "0.2.1"
version = "0.2.2"
description = "Math on (Hyper-Dual) Tensors with Trailing Axes"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = tensortrax
version = 0.2.1
version = 0.2.2
author = Andreas Dutzler
author_email = a.dutzler@gmail.com
description = Math on (Hyper-Dual) Tensors with Trailing Axes
Expand Down
25 changes: 22 additions & 3 deletions tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def T(self):
def ravel(self, order="C"):
return ravel(self, order=order)

def reshape(self, *shape, order="C"):
return reshape(self, newshape=shape, order=order)

def __matmul__(self, B):
return matmul(self, B)

Expand Down Expand Up @@ -222,6 +225,22 @@ def ravel(A, order="C"):
return np.ravel(A, order=order)


def reshape(A, newshape, order="C"):
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Δtrax = Δ(A).shape[len(A.shape) :]
Δδtrax = Δδ(A).shape[len(A.shape) :]
return Tensor(
x=f(A).reshape(*newshape, *A.trax, order=order),
δx=δ(A).reshape(*newshape, *δtrax, order=order),
Δx=Δ(A).reshape(*newshape, *Δtrax, order=order),
Δδx=Δδ(A).reshape(*newshape, *Δδtrax, order=order),
ntrax=A.ntrax,
)
else:
return np.reshape(A, newshape=newshape, order=order)


def einsum3(subscripts, *operands):
"Einsum with three operands."
A, B, C = operands
Expand Down Expand Up @@ -387,7 +406,7 @@ def transpose(A):


def matmul(A, B):
ik = "abcdefghijklm"[13-len(A.shape):]
kj = "mnopqrstuvwxy"[: len(B.shape)]
ij = (ik + kj).replace("m", "")
ik = "ik"[2 - len(A.shape) :]
kj = "kj"[: len(B.shape)]
ij = (ik + kj).replace("k", "")
return einsum(f"{ik}...,{kj}...->{ij}...", A, B)
1 change: 1 addition & 0 deletions tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
matmul,
diagonal,
ravel,
reshape,
)
from . import _math_array as array
2 changes: 1 addition & 1 deletion tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from .._tensor import Tensor, ravel, einsum, matmul, f, δ, Δ, Δδ
from .._tensor import Tensor, ravel, reshape, einsum, matmul, f, δ, Δ, Δδ
from ._linalg import _linalg_array as array


Expand Down
8 changes: 8 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ def test_math():

u[0] = t[0]

t.reshape(9)
t.reshape(3, 3)

tm.reshape(t, (9,))
tm.reshape(t, (3, 3))

tm.reshape(x, (3, 3, 100))


if __name__ == "__main__":
test_math()

0 comments on commit 9df24dc

Please sign in to comment.