Skip to content

Commit

Permalink
Merge pull request #18 from adtzlr/fix-set-items
Browse files Browse the repository at this point in the history
Fix `Tensor.__setitems__()`
  • Loading branch information
adtzlr authored Dec 10, 2022
2 parents 9ee75bc + b330a7d commit 110309b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tensortrax"
version = "0.1.7"
version = "0.1.8"
description = "Math on (Hyper-Dual) Tensors with Trailing Axes"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = tensortrax
version = 0.1.7
version = 0.1.8
author = Andreas Dutzler
author_email = a.dutzler@gmail.com
description = Math on (Hyper-Dual) Tensors with Trailing Axes
Expand Down
12 changes: 12 additions & 0 deletions tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def __getitem__(self, key):
def __setitem__(self, key, value):
if isinstance(value, Tensor):
self.x[key] = f(value)
if self.δx[key].shape != δ(value).shape:
self.δx = transpose(
np.resize(self.δx, (*self.trax[::-1], *self.shape[::-1])).T
)
if self.Δx[key].shape != Δ(value).shape:
self.Δx = transpose(
np.resize(self.Δx, (*self.trax[::-1], *self.shape[::-1])).T
)
if self.Δδx[key].shape != Δδ(value).shape:
self.Δδx = transpose(
np.resize(self.Δδx, (*self.trax[::-1], *self.shape[::-1])).T
)
self.δx[key] = δ(value)
self.Δx[key] = Δ(value)
self.Δδx[key] = Δδ(value)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def test_math():
T[:, 0] = F[:, 0]
T[:, 0] = T[:, 0]

x = np.ones((3, 3, 100))
t = tr.Tensor(x, x, x, x, ntrax=1)
u = tr.Tensor(x, ntrax=1)

u[0] = t[0]


if __name__ == "__main__":
test_math()

0 comments on commit 110309b

Please sign in to comment.