Stop using FunctionGraph and tag.test_value in linker tests
Co-authored-by: Adv <adhvaithhundi.221ds003@nitk.edu.in>
AdvH039 authored and ricardoV94 committed Feb 17, 2025
1 parent 8e5e8a4 commit 99333ba
Showing 40 changed files with 1,098 additions and 1,602 deletions.
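At the core of the change is the `compare_jax_and_py` helper in tests/link/jax/test_basic.py: instead of taking a pre-built `FunctionGraph` and pulling numerical values out of `tag.test_value`, it now takes symbolic inputs, symbolic outputs, and explicit numerical test values. A minimal sketch of the migration pattern (illustrative only; the variables `x`, `out`, and `x_test` are hypothetical, and the commented-out lines show the removed idiom):

    import numpy as np
    import pytensor.tensor as pt

    from tests.link.jax.test_basic import compare_jax_and_py

    x = pt.vector("x")
    out = pt.exp(x)
    x_test = np.r_[1.0, 2.0, 3.0]

    # Old idiom (removed by this commit):
    #     x.tag.test_value = x_test
    #     fgraph = FunctionGraph([x], [out])
    #     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # New idiom: symbolic inputs, symbolic outputs, explicit test values
    compare_jax_and_py([x], [out], [x_test])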
51 changes: 24 additions & 27 deletions tests/link/jax/test_basic.py
@@ -7,12 +7,12 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode
-from pytensor.compile.sharedvalue import SharedVariable, shared
+from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import Op, get_test_value
+from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
@@ -34,25 +34,28 @@ def set_pytensor_flags():


def compare_jax_and_py(
-    fgraph: FunctionGraph,
+    graph_inputs: Iterable[Variable],
+    graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable,
*,
assert_fn: Callable | None = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
py_mode=py_mode,
):
"""Function to compare python graph output and jax compiled output for testing equality
"""Function to compare python function output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
this function which then compiles the graphs in both jax and python, runs the calculation
in both and checks if the results are the same
The inputs and outputs are then passed to this function which then compiles the given function in both
jax and python, runs the calculation in both and checks if the results are the same
Parameters
----------
-    fgraph: FunctionGraph
-        PyTensor function Graph object
+    graph_inputs:
+        Symbolic inputs to the graph
+    graph_outputs:
+        Symbolic outputs of the graph
test_inputs: iter
-        Numerical inputs for testing the function graph
+        Numerical inputs for testing the function.
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
@@ -68,8 +71,10 @@
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

-    fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
-    pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
+    if any(inp.owner is not None for inp in graph_inputs):
+        raise ValueError("Inputs must be root variables")
+
+    pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
jax_res = pytensor_jax_fn(*test_inputs)

if must_be_device_array:
@@ -78,10 +83,10 @@
else:
assert isinstance(jax_res, jax.Array)

-    pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
+    pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)

-    if len(fgraph.outputs) > 1:
+    if isinstance(graph_outputs, list | tuple):
for j, p in zip(jax_res, py_res, strict=True):
assert_fn(j, p)
else:
@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals = np.r_[-1, -2, -3]

x = ifelse(np.array(True), true_vals, false_vals)
-    x_fg = FunctionGraph([], [x])

-    compare_jax_and_py(x_fg, [])
+    compare_jax_and_py([], [x], [])

a = dscalar("a")
-    a.tag.test_value = np.array(0.2, dtype=config.floatX)
+    a_test = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
-    x_fg = FunctionGraph([a], [x])  # I.e. False

-    compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
+    compare_jax_and_py([a], [x], [a_test])


def test_jax_checkandraise():
@@ -209,22 +212,16 @@ def test_jax_checkandraise():
function((p,), res, mode=jax_mode)


-def set_test_value(x, v):
-    x.tag.test_value = v
-    return x
-
-
def test_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
-    out_fg = FunctionGraph([x, y, z], [out])

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

-    compare_jax_and_py(out_fg, [xv, yv, zv])
+    compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
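The `ValueError` guard added above relies on the fact that a root variable in PyTensor has no `owner`, while any variable produced by an `Apply` node does. A small standalone sketch of the distinction, with hypothetical variables:

    import pytensor.tensor as pt

    x = pt.scalar("x")  # root variable: created directly, x.owner is None
    y = x + 1           # derived variable: owned by the Apply node of Add

    assert x.owner is None      # accepted as a graph input
    assert y.owner is not None  # rejected with "Inputs must be root variables"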
13 changes: 5 additions & 8 deletions tests/link/jax/test_blas.py
@@ -4,8 +4,6 @@
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
@@ -16,21 +14,20 @@
def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
-    a.tag.test_value = (
+    a_test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
-    b.tag.test_value = (
+    b_test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])

# A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
+    inputs = [a_test_value[:-1], b_test_value]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
+    pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)
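The mismatch check above works because trimming the first batch entry changes `a_test_value` from shape (10, 5, 3) to (9, 5, 3) while `b_test_value` keeps batch size 10, so `BatchedDot` can no longer pair the operands batch-wise. A NumPy-only sketch of the shape arithmetic (placeholder values):

    import numpy as np

    a_test_value = np.zeros((10, 5, 3))
    b_test_value = np.zeros((10, 3, 2))

    # Dropping one batch entry breaks the batch-dimension agreement: 9 != 10
    print(a_test_value[:-1].shape)  # (9, 5, 3)
    print(b_test_value.shape)       # (10, 3, 2)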
4 changes: 1 addition & 3 deletions tests/link/jax/test_blockwise.py
@@ -2,7 +2,6 @@
import pytest

from pytensor import config
-from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):

out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise)
-    fg = FunctionGraph([a, b], [out])
-    fn, _ = compare_jax_and_py(fg, test_values)
+    fn, _ = compare_jax_and_py([a, b], [out], test_values)

# Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
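The jaxpr assertion that follows the `compare_jax_and_py` call greps the string form of the traced program to confirm the jitted function lowers to a bare dot with no extra primitives. A standalone sketch of that inspection technique, using a plain JAX function rather than the compiled PyTensor VM:

    import jax
    import jax.numpy as jnp

    def f(a, b):
        return a @ b

    # make_jaxpr traces f and returns its jaxpr; its string form lists every
    # primitive, so a test can assert that unwanted ops are absent.
    jaxpr = str(jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 4))))
    assert "dot_general" in jaxpr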
7 changes: 2 additions & 5 deletions tests/link/jax/test_einsum.py
@@ -2,7 +2,6 @@
import pytest

import pytensor.tensor as pt
-from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py


@@ -22,8 +21,7 @@ def test_jax_einsum():
}
x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
-    fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
-    compare_jax_and_py(fg, [x, y, z])
+    compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])


def test_ellipsis_einsum():
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
-    fg = FunctionGraph([x_pt, y_pt], [out])
-    compare_jax_and_py(fg, [x, y])
+    compare_jax_and_py([x_pt, y_pt], [out], [x, y])
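In `test_ellipsis_einsum` the subscripts (defined in a collapsed part of the diff) use the ellipsis form, which lets leading batch dimensions pass through unchanged. A NumPy sketch of how such a spec behaves, assuming the hypothetical subscripts "...ij,...jk->...ik":

    import numpy as np

    x = np.ones((2, 5, 4))
    y = np.ones((2, 4, 3))

    # "..." matches the shared leading batch axis; the rest is a matmul
    out = np.einsum("...ij,...jk->...ik", x, y)
    print(out.shape)  # (2, 5, 3)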
56 changes: 23 additions & 33 deletions tests/link/jax/test_elemwise.py
@@ -6,8 +6,6 @@
import pytensor.tensor as pt
from pytensor.compile import get_mode
from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import prod
@@ -26,87 +24,81 @@ def test_jax_Dimshuffle():
a_pt = matrix("a")

x = a_pt.T
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )

x = a_pt.dimshuffle([0, 1, "x"])
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )

a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,))
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])


def test_jax_CAReduce():
a_pt = vector("a")
-    a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)

x = pt_sum(a_pt, axis=None)
-    x_fg = FunctionGraph([a_pt], [x])
-
-    compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])

a_pt = matrix("a")
-    a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

x = pt_sum(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
-
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

x = pt_sum(a_pt, axis=1)
-    x_fg = FunctionGraph([a_pt], [x])
-
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

a_pt = matrix("a")
-    a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

x = prod(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
-
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

x = pt_all(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([x], [out], [x_test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([x], [out], [x_test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
-    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
-    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
-    fgraph = FunctionGraph([dy, sm], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def test_multiple_input_multiply():
x, y, z = vectors("xyz")
out = pt.mul(x, y, z)

-    fg = FunctionGraph(outputs=[out], clone=False)
-    compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
+    compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
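As an aside on `test_jax_Dimshuffle` above: `a_pt.dimshuffle((0,))` is only valid because the second axis is declared with static length 1 (`shape=(None, 1)`), and dimshuffle may drop known length-1 dimensions. A minimal sketch of that behavior:

    import pytensor.tensor as pt

    a = pt.tensor("a", dtype="float64", shape=(None, 1))
    x = a.dimshuffle((0,))  # drops the length-1 second axis
    print(x.type.shape)     # (None,)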
[Diff for the remaining 35 changed files not shown]
