Skip to content

Commit

Permalink
Merge pull request #75 from adtzlr/change-ntrax-to-minimum-of-args
Browse files Browse the repository at this point in the history
Change `ntrax` to Minimum of all input arguments
  • Loading branch information
adtzlr authored Feb 15, 2023
2 parents 5a8c77c + cf8d98c commit cdb8746
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

__version__ = "0.8.5"
__version__ = "0.9.0"
22 changes: 14 additions & 8 deletions src/tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,14 @@ def __add__(self, B):
δx = δ(A) + δ(B)
Δx = Δ(A) + Δ(B)
Δδx = Δδ(A) + Δδ(B)
ntrax = min(A.ntrax, B.ntrax)
else:
x = f(A) + B
δx = δ(A)
Δx = Δ(A)
Δδx = Δδ(A)
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)
ntrax = A.ntrax
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

def __sub__(self, B):
A = self
Expand All @@ -190,12 +192,14 @@ def __sub__(self, B):
δx = δ(A) - δ(B)
Δx = Δ(A) - Δ(B)
Δδx = Δδ(A) - Δδ(B)
ntrax = min(A.ntrax, B.ntrax)
else:
x = f(A) - B
δx = δ(A)
Δx = Δ(A)
Δδx = Δδ(A)
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)
ntrax = A.ntrax
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

def __rsub__(self, B):
return -self.__sub__(B)
Expand All @@ -207,12 +211,14 @@ def __mul__(self, B):
δx = δ(A) * f(B) + f(A) * δ(B)
Δx = Δ(A) * f(B) + f(A) * Δ(B)
Δδx = Δ(A) * δ(B) + δ(A) * Δ(B) + Δδ(A) * f(B) + f(A) * Δδ(B)
ntrax = min(A.ntrax, B.ntrax)
else:
x = f(A) * B
δx = δ(A) * B
Δx = Δ(A) * B
Δδx = Δδ(A) * B
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)
ntrax = A.ntrax
return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

def __truediv__(self, B):
A = self
Expand Down Expand Up @@ -403,7 +409,7 @@ def einsum3(subscripts, *operands):
+ _einsum(f(A), δ(B), Δ(C))
+ _einsum(f(A), Δ(B), δ(C))
)
ntrax = A.ntrax
ntrax = min(A.ntrax, B.ntrax, C.ntrax)
elif (
isinstance(A, Tensor)
and not isinstance(B, Tensor)
Expand Down Expand Up @@ -444,7 +450,7 @@ def einsum3(subscripts, *operands):
+ _einsum(δ(A), Δ(B), C)
+ _einsum(Δ(A), δ(B), C)
)
ntrax = A.ntrax
ntrax = min(A.ntrax, B.ntrax)
elif isinstance(A, Tensor) and not isinstance(B, Tensor) and isinstance(C, Tensor):
x = _einsum(f(A), B, f(C))
δx = _einsum(δ(A), B, f(C)) + _einsum(f(A), B, δ(C))
Expand All @@ -455,7 +461,7 @@ def einsum3(subscripts, *operands):
+ _einsum(δ(A), B, Δ(C))
+ _einsum(Δ(A), B, δ(C))
)
ntrax = A.ntrax
ntrax = min(A.ntrax, C.ntrax)
elif not isinstance(A, Tensor) and isinstance(B, Tensor) and isinstance(C, Tensor):
x = _einsum(A, f(B), f(C))
δx = _einsum(A, δ(B), f(C)) + _einsum(A, f(B), δ(C))
Expand All @@ -466,7 +472,7 @@ def einsum3(subscripts, *operands):
+ _einsum(A, δ(B), Δ(C))
+ _einsum(A, Δ(B), δ(C))
)
ntrax = B.ntrax
ntrax = min(B.ntrax, C.ntrax)
else:
return _einsum(*operands)

Expand All @@ -488,7 +494,7 @@ def einsum2(subscripts, *operands):
+ _einsum(δ(A), Δ(B))
+ _einsum(Δ(A), δ(B))
)
ntrax = A.ntrax
ntrax = min(A.ntrax, B.ntrax)
elif isinstance(A, Tensor) and not isinstance(B, Tensor):
x = _einsum(f(A), B)
δx = _einsum(δ(A), B)
Expand Down
8 changes: 4 additions & 4 deletions src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def array(object, dtype=None):
δx=np.array([δ(o) for o in object], dtype=dtype),
Δx=np.array([Δ(o) for o in object], dtype=dtype),
Δδx=np.array([Δδ(o) for o in object], dtype=dtype),
ntrax=object[0].ntrax,
ntrax=min([o.ntrax for o in object]),
)
else:
return np.array(object, dtype=dtype)
Expand Down Expand Up @@ -286,7 +286,7 @@ def hstack(tup):
δx=np.hstack([δ(A) for A in tup]),
Δx=np.hstack([Δ(A) for A in tup]),
Δδx=np.hstack([Δδ(A) for A in tup]),
ntrax=tup[0].ntrax,
ntrax=min([A.ntrax for A in tup]),
)
else:
return np.hstack(tup)
Expand All @@ -301,7 +301,7 @@ def vstack(tup):
δx=np.vstack([δ(A) for A in tup]),
Δx=np.vstack([Δ(A) for A in tup]),
Δδx=np.vstack([Δδ(A) for A in tup]),
ntrax=tup[0].ntrax,
ntrax=min([A.ntrax for A in tup]),
)
else:
return np.vstack(tup)
Expand All @@ -316,7 +316,7 @@ def stack(arrays, axis=0):
δx=np.stack([δ(A) for A in arrays], axis=axis),
Δx=np.stack([Δ(A) for A in arrays], axis=axis),
Δδx=np.stack([Δδ(A) for A in arrays], axis=axis),
ntrax=arrays[0].ntrax,
ntrax=min([A.ntrax for A in arrays]),
)
else:
return np.stack(arrays, axis=0)
Expand Down

0 comments on commit cdb8746

Please sign in to comment.