Merge pull request #15 from adtzlr/implement-wrt
Implement with-respect-to argument (`wrt=0` or `wrt="x"`)
adtzlr authored Dec 9, 2022
2 parents 21f5b7f + 5a0ba16 commit d6f4263
Showing 9 changed files with 126 additions and 44 deletions.
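
In short, `function`, `gradient`, `hessian` and the vector-product helpers now accept a `wrt` argument that selects which function argument is differentiated, either by position (`wrt=0`) or by keyword name (`wrt="y"`). A minimal sketch, mirroring the added `tests/test_scalar.py` below (variable names are illustrative only):

```python
import numpy as np
import tensortrax as tr
import tensortrax.math as tm


def fun(x, y):
    return x**2 / y + x * tm.log(y)


np.random.seed(6574)
x = np.random.rand(100)
y = np.random.rand(100)

# differentiate w.r.t. the first positional argument (x)
d2fdx2, dfdx, fx = tr.hessian(fun, wrt=0, ntrax=1)(x, y)
assert np.allclose(dfdx, 2 * x / y + np.log(y))

# differentiate w.r.t. the keyword argument "y"
d2fdy2, dfdy, fy = tr.hessian(fun, wrt="y", ntrax=1)(x=x, y=y)
assert np.allclose(dfdy, -(x**2) / y**2 + x / y)
```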
24 changes: 16 additions & 8 deletions README.md
@@ -34,14 +34,14 @@ Let's define a scalar-valued function which operates on a tensor.
import tensortrax as tr
import tensortrax.math as tm

def fun(F):
def fun(F, mu=1):
    C = F.T() @ F
    I1 = tm.trace(C)
    J = tm.linalg.det(F)
    return J ** (-2 / 3) * I1 - 3
    return mu / 2 * (J ** (-2 / 3) * I1 - 3)
```

The hessian of the scalar-valued function w.r.t. the function argument is evaluated by variational calculus (Forward Mode AD implemented as Hyper-Dual Tensors). The function is called once for each component of the hessian (symmetry is taken care of). The function and the gradient are evaluated with no additional computational cost.
The hessian of the scalar-valued function w.r.t. the chosen function argument (here, `wrt=0` or `wrt="F"`) is evaluated by variational calculus (Forward Mode AD implemented as Hyper-Dual Tensors). The function is called once for each component of the hessian (symmetry is taken care of). The function and the gradient are evaluated with no additional computational cost.

```python
import numpy as np
@@ -52,9 +52,9 @@ F = np.random.rand(3, 3, 8, 50) / 10
for a in range(3):
    F[a, a] += 1

# W = tr.function(fun, ntrax=2)(F)
# dWdF, W = tr.gradient(fun, ntrax=2)(F)
d2WdF2, dWdF, W = tr.hessian(fun, ntrax=2)(F)
# W = tr.function(fun, wrt=0, ntrax=2)(F)
# dWdF, W = tr.gradient(fun, wrt=0, ntrax=2)(F)
d2WdF2, dWdF, W = tr.hessian(fun, wrt="F", ntrax=2)(F=F)
```

# Theory
@@ -120,6 +120,7 @@ Once again, each component $A_{ijkl}$ of the fourth-order hessian is numerically
Each Tensor has four attributes: the (real) tensor array and the (hyper-dual) variational arrays. To obtain the above-mentioned $12$-component of the gradient and the $1223$-component of the hessian, a tensor has to be created with the appropriate small changes of the tensor components (dual arrays).

```python
import tensortrax as tr
from tensortrax import Tensor, f, δ, Δ, Δδ
from tensortrax.math import trace

@@ -151,8 +152,15 @@ A_1223 = Δδ(I1_C)
To obtain full gradients and hessians in one function call, `tensortrax` provides helpers (decorators) which handle the multiple function calls. Optionally, the function calls are executed in parallel (threaded).

```python
gradient(lambda F: trace(F.T() @ F), parallel=False)(x)
hessian(lambda F: trace(F.T() @ F), parallel=False)(x)
grad, func = tr.gradient(lambda F: trace(F.T() @ F), wrt=0, parallel=False)(x)
hess, grad, func = tr.hessian(lambda F: trace(F.T() @ F), wrt=0, parallel=False)(x)
```

Evaluate the gradient-vector-product as well as the hessian-vector-product:

```python
gvp = tr.gradient_vector_product(lambda F: trace(F.T() @ F), parallel=False)(x, δx=x)
hvp = tr.hessian_vector_product(lambda F: trace(F.T() @ F), parallel=False)(x, δx=x, Δx=x)
```

# Extensions
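As a cross-check of the vector-product helpers shown in the README above, the directional derivatives should match contractions of the full gradient and hessian. A hedged sketch (it assumes, as suggested by the README example, that `tr.gradient` and `tr.hessian` return arrays shaped like the tensor followed by the trailing axes):

```python
import numpy as np
import tensortrax as tr
from tensortrax.math import trace

np.random.seed(125161)
F = np.eye(3).reshape(3, 3, 1, 1) + np.random.rand(3, 3, 8, 50) / 10
δF = np.random.rand(3, 3, 8, 50) / 10
ΔF = np.random.rand(3, 3, 8, 50) / 10


def W(F):
    return trace(F.T() @ F)


gvp = tr.gradient_vector_product(W, wrt=0, ntrax=2)(F, δx=δF)
hvp = tr.hessian_vector_product(W, wrt=0, ntrax=2)(F, δx=δF, Δx=ΔF)

d2WdF2, dWdF, _ = tr.hessian(W, wrt=0, ntrax=2)(F)

# the directional derivatives equal the contracted full gradient and hessian
assert np.allclose(gvp, np.einsum("ij...,ij...->...", dWdF, δF))
assert np.allclose(hvp, np.einsum("ij...,ijkl...,kl...->...", δF, d2WdF2, ΔF))
```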
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "tensortrax"
version = "0.1.4"
version = "0.1.5"
description = "Math on (Hyper-Dual) Tensors with Trailing Axes"
readme = "README.md"
requires-python = ">=3.7"
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = tensortrax
version = 0.1.4
version = 0.1.5
author = Andreas Dutzler
author_email = a.dutzler@gmail.com
description = Math on (Hyper-Dual) Tensors with Trailing Axes
80 changes: 57 additions & 23 deletions tensortrax/_evaluate.py
@@ -9,45 +9,76 @@
"""

from threading import Thread
from copy import copy
import numpy as np

from ._tensor import Tensor, f, δ, Δδ


def function(fun, ntrax=0, parallel=False):
def add_tensor(args, kwargs, wrt, δx, Δx, ntrax):
"Modify the arguments and replace the w.r.t.-argument by a tensor."

kwargs_out = copy(kwargs)
args_out = list(args)

if isinstance(wrt, str):
kwargs_out[wrt] = Tensor(x=kwargs[wrt], δx=δx, Δx=Δx, ntrax=ntrax)

elif isinstance(wrt, int):
args_out[wrt] = Tensor(x=args[wrt], δx=δx, Δx=Δx, ntrax=ntrax)

return args_out, kwargs_out


def arg_to_tensor(args, kwargs, wrt):
"Return the argument which will be replaced by a tensor."

if isinstance(wrt, str):
x = kwargs[wrt]
elif isinstance(wrt, int):
x = args[wrt]
else:
raise TypeError(f"w.r.t. {wrt} not supported.")

return x


def function(fun, wrt=0, ntrax=0, parallel=False):
"Evaluate a scalar-valued function."

def evaluate_function(x, *args, **kwargs):
return fun(Tensor(x, ntrax=ntrax), *args, **kwargs).x
def evaluate_function(*args, **kwargs):
args, kwargs = add_tensor(args, kwargs, wrt, None, None, ntrax)
return fun(*args, **kwargs).x

return evaluate_function


def gradient(fun, ntrax=0, parallel=False):
def gradient(fun, wrt=0, ntrax=0, parallel=False):
"Evaluate the gradient of a scalar-valued function."

def evaluate_gradient(x, *args, **kwargs):
def evaluate_gradient(*args, **kwargs):

x = arg_to_tensor(args, kwargs, wrt)
t = Tensor(x, ntrax=ntrax)
indices = range(t.size)

fx = np.zeros((1, *t.trax))
dfdx = np.zeros((t.size, *t.trax))
δx = Δx = np.eye(t.size)

def kernel(a, x, δx, Δx, args, kwargs):
t = Tensor(x, δx=δx[a], Δx=Δx[a], ntrax=ntrax)
func = fun(t, *args, **kwargs)
def kernel(a, wrt, δx, Δx, ntrax, args, kwargs):
args, kwargs = add_tensor(args, kwargs, wrt, δx[a], Δx[a], ntrax)
func = fun(*args, **kwargs)
fx[:] = f(func)
dfdx[a] = δ(func)

if not parallel:
for a in indices:
kernel(a, x, δx, Δx, args, kwargs)
kernel(a, wrt, δx, Δx, ntrax, args, kwargs)

else:
threads = [
Thread(target=kernel, args=(a, x, δx, Δx, args, kwargs))
Thread(target=kernel, args=(a, wrt, δx, Δx, ntrax, args, kwargs))
for a in indices
]

@@ -62,11 +93,12 @@ def kernel(a, x, δx, Δx, args, kwargs):
    return evaluate_gradient


def hessian(fun, ntrax=0, parallel=False):
def hessian(fun, wrt=0, ntrax=0, parallel=False):
    "Evaluate the hessian of a scalar-valued function."

    def evaluate_hessian(x, *args, **kwargs):
    def evaluate_hessian(*args, **kwargs):

        x = arg_to_tensor(args, kwargs, wrt)
        t = Tensor(x, ntrax=ntrax)
        indices = np.array(np.triu_indices(t.size)).T

@@ -75,20 +107,20 @@ def evaluate_hessian(x, *args, **kwargs):
        d2fdx2 = np.zeros((t.size, t.size, *t.trax))
        δx = Δx = np.eye(t.size)

        def kernel(a, b, x, δx, Δx, args, kwargs):
            t = Tensor(x, δx=δx[a], Δx=Δx[b], ntrax=ntrax)
            func = fun(t, *args, **kwargs)
        def kernel(a, b, wrt, δx, Δx, ntrax, args, kwargs):
            args, kwargs = add_tensor(args, kwargs, wrt, δx[a], Δx[b], ntrax)
            func = fun(*args, **kwargs)
            fx[:] = f(func)
            dfdx[a] = δ(func)
            d2fdx2[a, b] = d2fdx2[b, a] = Δδ(func)

        if not parallel:
            for a, b in indices:
                kernel(a, b, x, δx, Δx, args, kwargs)
                kernel(a, b, wrt, δx, Δx, ntrax, args, kwargs)

        else:
            threads = [
                Thread(target=kernel, args=(a, b, x, δx, Δx, args, kwargs))
                Thread(target=kernel, args=(a, b, wrt, δx, Δx, ntrax, args, kwargs))
                for a, b in indices
            ]

@@ -107,19 +139,21 @@ def kernel(a, b, x, δx, Δx, args, kwargs):
    return evaluate_hessian


def gradient_vector_product(fun, ntrax=0, parallel=False):
def gradient_vector_product(fun, wrt=0, ntrax=0, parallel=False):
    "Evaluate the gradient-vector-product of a function."

    def evaluate_gradient_vector_product(x, δx, *args, **kwargs):
        return fun(Tensor(x, δx, ntrax=ntrax), *args, **kwargs).δx
    def evaluate_gradient_vector_product(*args, δx, **kwargs):
        args, kwargs = add_tensor(args, kwargs, wrt, δx, None, ntrax)
        return fun(*args, **kwargs).δx

    return evaluate_gradient_vector_product


def hessian_vector_product(fun, ntrax=0, parallel=False):
def hessian_vector_product(fun, wrt=0, ntrax=0, parallel=False):
    "Evaluate the hessian-vector-product of a function."

    def evaluate_hessian_vector_product(x, δx, Δx, *args, **kwargs):
        return fun(Tensor(x, δx, Δx, ntrax=ntrax), *args, **kwargs).Δδx
    def evaluate_hessian_vector_product(*args, δx, Δx, **kwargs):
        args, kwargs = add_tensor(args, kwargs, wrt, δx, Δx, ntrax)
        return fun(*args, **kwargs).Δδx

    return evaluate_hessian_vector_product
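Note that `evaluate_hessian` above only visits the upper triangle of component pairs (`np.triu_indices`) and mirrors each result into the lower triangle (`d2fdx2[a, b] = d2fdx2[b, a]`), which is where the symmetry saving mentioned in the README comes from. A tiny illustration of the resulting call count for a 3×3 tensor argument:

```python
import numpy as np

size = 9  # a 3x3 tensor argument has 9 components
indices = np.array(np.triu_indices(size)).T

print(len(indices))  # 45 kernel calls instead of 9 * 9 = 81
```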
4 changes: 4 additions & 0 deletions tensortrax/_tensor.py
@@ -134,6 +134,10 @@ def __truediv__(self, B):
            x=f(A) / B, δx=δ(A) / B, Δx=Δ(A) / B, Δδx=Δδ(A) / B, ntrax=A.ntrax
        )

    def __rtruediv__(self, B):
        A = self
        return B * A**-1

    def __pow__(self, p):
        A = self
        x = f(A) ** p
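The new `__rtruediv__` implements right-division by rewriting `B / A` as `B * A**-1`, which is what lets `tests/test_math.py` replace the old `TypeError` expectation with a plain `isinstance` check (see below). A hedged usage sketch, assuming a `Tensor` created directly from a NumPy array with `ntrax=0` and the real part stored in `.x`:

```python
import numpy as np
import tensortrax as tr

F = np.eye(3) + np.random.rand(3, 3) / 10
T = tr.Tensor(F, ntrax=0)

# array / Tensor now dispatches to Tensor.__rtruediv__, i.e. F * T**-1
R = F / T

assert isinstance(R, tr.Tensor)
assert np.allclose(R.x, np.ones((3, 3)))  # elementwise F / f(T) = F / F
```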
13 changes: 7 additions & 6 deletions tensortrax/math/_math_tensor.py
@@ -148,9 +148,9 @@ def log(A):
        x = np.log(f(A))
        return Tensor(
            x=x,
            δx=1 / x * δ(A),
            Δx=1 / x * Δ(A),
            Δδx=-1 / x**2 * δ(A) * Δ(A) + 1 / x * Δδ(A),
            δx=1 / f(A) * δ(A),
            Δx=1 / f(A) * Δ(A),
            Δδx=-1 / f(A) ** 2 * δ(A) * Δ(A) + 1 / f(A) * Δδ(A),
            ntrax=A.ntrax,
        )
    else:
@@ -162,9 +162,10 @@ def log10(A):
        x = np.log10(f(A))
        return Tensor(
            x=x,
            δx=1 / (np.log(10) * x) * δ(A),
            Δx=1 / (np.log(10) * x) * Δ(A),
            Δδx=-1 / (np.log(10) * x**2) * δ(A) * Δ(A) + 1 / (np.log(10) * x) * Δδ(A),
            δx=1 / (np.log(10) * f(A)) * δ(A),
            Δx=1 / (np.log(10) * f(A)) * Δ(A),
            Δδx=-1 / (np.log(10) * f(A) ** 2) * δ(A) * Δ(A)
            + 1 / (np.log(10) * f(A)) * Δδ(A),
            ntrax=A.ntrax,
        )
    else:
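The change above fixes the variations of `log` (and `log10`): they must be divided by the function argument `f(A)`, not by the value `x = log(f(A))`. A quick finite-difference check of the corrected rule (standalone, not part of the commit):

```python
import numpy as np

a, h = 1.7, 1e-6

# d/da log(a) = 1 / a (the corrected factor), not 1 / log(a)
fd = (np.log(a + h) - np.log(a - h)) / (2 * h)
assert np.isclose(fd, 1 / a)
assert not np.isclose(fd, 1 / np.log(a))

# d/da log10(a) = 1 / (log(10) * a)
fd10 = (np.log10(a + h) - np.log10(a - h)) / (2 * h)
assert np.isclose(fd10, 1 / (np.log(10) * a))
```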
8 changes: 5 additions & 3 deletions tests/test_hessian_vector_product.py
@@ -24,9 +24,11 @@ def test_hvp():
    for parallel in [False, True]:

        for fun in [neo_hooke, ogden]:
            δfun = tr.gradient_vector_product(fun, ntrax=2, parallel=parallel)(F, δF)
            Δδfun = tr.hessian_vector_product(fun, ntrax=2, parallel=parallel)(
                F, δF, ΔF
            δfun = tr.gradient_vector_product(fun, wrt="F", ntrax=2, parallel=parallel)(
                F=F, δx=δF
            )
            Δδfun = tr.hessian_vector_product(fun, wrt="F", ntrax=2, parallel=parallel)(
                F=F, δx=δF, Δx=ΔF
            )

            assert δfun.shape == (1, 1)
3 changes: 1 addition & 2 deletions tests/test_math.py
@@ -23,8 +23,7 @@ def test_math():
    assert isinstance(T * T, tr.Tensor)

    assert isinstance(T / F, tr.Tensor)
    with pytest.raises(TypeError):
        F / T
    assert isinstance(F / T, tr.Tensor)
    with pytest.raises(NotImplementedError):
        T / T

34 changes: 34 additions & 0 deletions tests/test_scalar.py
@@ -0,0 +1,34 @@
import tensortrax as tr
import tensortrax.math as tm
import numpy as np
import pytest


def fun(x, y):
    return x**2 / y + x * tm.log(y)


def test_scalar():

    np.random.seed(6574)
    x = np.random.rand(100)

    np.random.seed(54234)
    y = np.random.rand(100)

    with pytest.raises(TypeError):
        tr.hessian(fun, wrt=[1, 2])(x, y)

    h, g, f = tr.hessian(fun, wrt=0, ntrax=1)(x, y)

    assert np.allclose(g, 2 * x / y + np.log(y))
    assert np.allclose(h, 2 / y)

    h, g, f = tr.hessian(fun, wrt="y", ntrax=1)(x=x, y=y)

    assert np.allclose(g, -(x**2) / y**2 + x / y)
    assert np.allclose(h, 2 * x**2 / y**3 - x / y**2)


if __name__ == "__main__":
    test_scalar()
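
For reference, the values asserted above follow directly from $f(x, y) = x^2 / y + x \log y$:

$$
\frac{\partial f}{\partial x} = \frac{2x}{y} + \log y, \quad
\frac{\partial^2 f}{\partial x^2} = \frac{2}{y}, \quad
\frac{\partial f}{\partial y} = -\frac{x^2}{y^2} + \frac{x}{y}, \quad
\frac{\partial^2 f}{\partial y^2} = \frac{2x^2}{y^3} - \frac{x}{y^2}.
$$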
