Only do reshapes in tensordot when needed
ricardoV94 committed Feb 17, 2025
1 parent 65b96c1 commit 2a539bc
Showing 2 changed files with 86 additions and 34 deletions.
pytensor/tensor/math.py: 81 changes (49 additions & 32 deletions)
@@ -2158,62 +2158,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
 
     # The operation is only valid if the original dimensions match in length
     # The ravelling of the dimensions to coerce the operation into a single dot
     # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims), when doing a non-scalar outer product (axes = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
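Editor's note: a minimal NumPy sketch, not part of the commit, illustrating the condition behind a_needs_reshape / b_needs_reshape above. The helper needs_reshape is hypothetical and simply mirrors that expression: an operand can skip the reshape exactly when it is a scalar, or has at most one non-summed axis and exactly one summed axis, since a transpose followed by a plain dot then already produces the right shape.

import numpy as np

def needs_reshape(ndim, summed_axes):
    # Mirrors: (ndim != 0) and (len(non_summed_axes) > 1 or len(summed_axes) != 1)
    non_summed_axes = [ax for ax in range(ndim) if ax not in summed_axes]
    return (ndim != 0) and (len(non_summed_axes) > 1 or len(summed_axes) != 1)

a = np.random.rand(3, 4)  # matrix
b = np.random.rand(4)     # vector

# Matrix-vector contraction over the last axis of each: neither operand needs a
# reshape, so the whole tensordot is just a transpose followed by a plain dot.
assert not needs_reshape(a.ndim, [1]) and not needs_reshape(b.ndim, [0])
np.testing.assert_allclose(np.tensordot(a, b, axes=[[1], [0]]), a @ b)

# A 3D operand contracted over a single axis still has two non-summed axes,
# which must be merged into one dimension (a reshape) before the dot.
c = np.random.rand(2, 3, 4)
assert needs_reshape(c.ndim, [2])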
tests/tensor/test_math.py: 39 changes (37 additions & 2 deletions)
@@ -19,7 +19,7 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
-from pytensor.graph.basic import Variable, ancestors, applys_between
+from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
@@ -2278,7 +2278,7 @@ def test_type_shape(self):

         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
 
@@ -2323,6 +2323,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail):
         else:
             assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
 
+    def test_eager_simplification(self):
+        # Test that cases where tensordot isn't needed, it returns a simple graph
+        scl = tensor(shape=())
+        vec = tensor(shape=(None,))
+        mat = tensor(shape=(None, None))
+
+        # scalar product
+        out = tensordot(scl, scl, axes=[[], []])
+        assert equal_computations([out], [scl * scl])
+
+        # vector-vector product
+        out = tensordot(vec, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, vec)])
+
+        # matrix-vector product
+        out = tensordot(mat, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, vec)])
+
+        out = tensordot(mat, vec, axes=[[-2], [-1]])
+        assert equal_computations([out], [dot(mat.T, vec)])
+
+        # vector-matrix product
+        out = tensordot(vec, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(vec, mat)])
+
+        out = tensordot(vec, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, mat.T)])
+
+        # matrix-matrix product
+        out = tensordot(mat, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(mat, mat)])
+
+        out = tensordot(mat, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, mat.T)])
+
 
 def test_smallest():
     x = dvector()
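Editor's note: a rough usage sketch, not taken from this PR, of what the new test_eager_simplification asserts at the graph level: after this change, a matrix-vector contraction builds a plain dot with no Reshape node. The import paths and the ancestors-based check are assumptions about the public pytensor API rather than code from the commit.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import ancestors
from pytensor.tensor.shape import Reshape

x = pt.matrix("x")
y = pt.vector("y")
out = pt.tensordot(x, y, axes=[[-1], [-1]])

# The graph built by tensordot should contain no Reshape node for this case.
assert not any(
    var.owner is not None and isinstance(var.owner.op, Reshape)
    for var in ancestors([out])
)

# And it still agrees numerically with NumPy.
fn = pytensor.function([x, y], out)
xv, yv = np.random.rand(3, 4), np.random.rand(4)
np.testing.assert_allclose(fn(xv, yv), np.tensordot(xv, yv, axes=[[-1], [-1]]))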
