Skip to content

Commit

Permalink
Merge pull request #83 from adtzlr/external-for-arrays
Browse files Browse the repository at this point in the history
`math.external()` for non-tensor `x` arguments
  • Loading branch information
adtzlr authored Apr 19, 2023
2 parents a4e332a + f342490 commit bc91f19
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
24 changes: 14 additions & 10 deletions src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,10 @@ def external(x, function, gradient, hessian, indices="ij", *args, **kwargs):
"""

# pre-evaluate the scalar-valued function along with its gradient and hessian
func = function(f(x), *args, **kwargs)
grad = gradient(f(x), *args, **kwargs)
hess = hessian(f(x), *args, **kwargs)
if isinstance(x, Tensor):
func = function(f(x), *args, **kwargs)
grad = gradient(f(x), *args, **kwargs)
hess = hessian(f(x), *args, **kwargs)

def gvp(g, v, ntrax):
"Evaluate the gradient-vector product."
Expand All @@ -353,10 +354,13 @@ def hvp(h, v, u, ntrax):

return einsum(f"{ij}{kl}...,{ij}...,{kl}...->...", h, v, u)

return Tensor(
x=func,
δx=gvp(grad, δ(x), x.ntrax),
Δx=gvp(grad, Δ(x), x.ntrax),
Δδx=hvp(hess, δ(x), Δ(x), x.ntrax) + gvp(grad, Δδ(x), x.ntrax),
ntrax=x.ntrax,
)
if isinstance(x, Tensor):
return Tensor(
x=func,
δx=gvp(grad, δ(x), x.ntrax),
Δx=gvp(grad, Δ(x), x.ntrax),
Δδx=hvp(hess, δ(x), Δ(x), x.ntrax) + gvp(grad, Δδ(x), x.ntrax),
ntrax=x.ntrax,
)
else:
return function(x, *args, **kwargs)
22 changes: 16 additions & 6 deletions tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import tensortrax.math as tm


def W(F, mu=1):
C = F.T() @ F
def psi(F, mu=1):
C = tm.dot(tm.transpose(F), F)
I1 = tm.trace(C)
return mu * (I1 - 3) / 2

Expand All @@ -21,15 +21,15 @@ def neo_hooke_ext(F):
J = tm.linalg.det(F)
return tm.external(
x=J ** (-1 / 3) * F,
function=tr.function(W, ntrax=F.ntrax),
gradient=tr.gradient(W, ntrax=F.ntrax),
hessian=tr.hessian(W, ntrax=F.ntrax),
function=tr.function(psi, ntrax=F.ntrax),
gradient=tr.gradient(psi, ntrax=F.ntrax),
hessian=tr.hessian(psi, ntrax=F.ntrax),
indices="ij",
)


def neo_hooke(F, mu=1):
C = F.T() @ F
C = tm.dot(tm.transpose(F), F)
I1 = tm.trace(C)
J = tm.linalg.det(F)
return mu * (J ** (-2 / 3) * I1 - 3) / 2
Expand All @@ -55,6 +55,16 @@ def test_external():
assert np.allclose(*dWdF)
assert np.allclose(*d2WdF2)

W = tm.external(
x=F,
function=neo_hooke,
gradient=None,
hessian=None,
indices=None,
)

assert W.shape == (1, 2100)


if __name__ == "__main__":
test_external()

0 comments on commit bc91f19

Please sign in to comment.