From 164f7c3b800636a7c022aa124422bf3fbef76c4a Mon Sep 17 00:00:00 2001
From: umalesTT
Date: Fri, 24 Jan 2025 13:44:12 +0000
Subject: [PATCH 1/2] Make allclose for pcc tunable for different tests. Add
 test for MLP training with MSE loss.

---
 tests/infra/comparison.py               | 15 ++++---
 tests/jax/graphs/test_MLP_regression.py | 53 +++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 8 deletions(-)
 create mode 100644 tests/jax/graphs/test_MLP_regression.py

diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py
index 40800989..67b63f26 100644
--- a/tests/infra/comparison.py
+++ b/tests/infra/comparison.py
@@ -34,17 +34,18 @@ class AtolConfig(ConfigBase):
     required_atol: float = 1.6e-1
 
 
-@dataclass
-class PccConfig(ConfigBase):
-    required_pcc: float = 0.99
-
-
 @dataclass
 class AllcloseConfig(ConfigBase):
     rtol: float = 1e-2
     atol: float = 1e-2
 
 
+@dataclass
+class PccConfig(ConfigBase):
+    required_pcc: float = 0.99
+    allclose: AllcloseConfig = AllcloseConfig()
+
+
 @dataclass
 class ComparisonConfig:
     equal: EqualConfig = EqualConfig(False)
@@ -106,9 +107,7 @@ def compare_pcc(
 
     # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
     try:
-        compare_allclose(
-            device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
-        )
+        compare_allclose(device_output, golden_output, pcc_config.allclose)
     except AssertionError:
         pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
         pcc = jnp.min(pcc)
diff --git a/tests/jax/graphs/test_MLP_regression.py b/tests/jax/graphs/test_MLP_regression.py
new file mode 100644
index 00000000..68f4f1a5
--- /dev/null
+++ b/tests/jax/graphs/test_MLP_regression.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import jax
+import jax.numpy as jnp
+import pytest
+from infra import ComparisonConfig, run_graph_test_with_random_inputs
+
+
+@pytest.fixture
+def comparison_config() -> ComparisonConfig:
+    config = ComparisonConfig()
+    config.pcc.allclose.atol = 0.03
+    config.pcc.allclose.rtol = 0.03
+    return config
+
+
+@pytest.mark.parametrize(
+    ["W1", "b1", "W2", "b2", "X", "y"],
+    [
+        [(784, 64), (32, 64), (64, 10), (32, 10), (32, 784), (32, 10)]
+    ],  # 32 samples, 784 features (28x28), 10 output classes
+)
+def test_nn_with_relu(W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig):
+    def simple_nn(W1, b1, W2, b2, X, y):
+        def forward(W1, b1, W2, b2, X):
+            hidden = jnp.dot(X, W1) + b1
+            hidden = jnp.maximum(0, hidden)
+            output = jnp.dot(hidden, W2) + b2
+            return output
+
+        def loss(W1, b1, W2, b2, X, y):
+            output = forward(W1, b1, W2, b2, X)
+            return jnp.mean((output - y) ** 2)
+
+        @jax.jit
+        def update_params(W1, b1, W2, b2, X, y, lr=0.01):
+            grads = jax.grad(loss, argnums=(0, 1, 2, 3))(W1, b1, W2, b2, X, y)
+            W1 -= lr * grads[0]
+            b1 -= lr * grads[1]
+            W2 -= lr * grads[2]
+            b2 -= lr * grads[3]
+            return W1, b1, W2, b2, grads
+
+        for _ in range(50):
+            W1, b1, W2, b2, grads = update_params(W1, b1, W2, b2, X, y, lr=0.01)
+
+        final_loss = loss(W1, b1, W2, b2, X, y)
+        return final_loss
+
+    run_graph_test_with_random_inputs(
+        simple_nn, [W1, b1, W2, b2, X, y], comparison_config=comparison_config
+    )

From 05f58e5e5589f1f4da86925b071f83bb6f2f130f Mon Sep 17 00:00:00 2001
From: umalesTT
Date: Fri, 31 Jan 2025 12:56:57 +0000
Subject: [PATCH 2/2] Add comments explaining why we need a tunable pcc config.
---
 tests/infra/comparison.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py
index 67b63f26..c8a799b7 100644
--- a/tests/infra/comparison.py
+++ b/tests/infra/comparison.py
@@ -40,6 +40,9 @@ class AllcloseConfig(ConfigBase):
     atol: float = 1e-2
 
 
+# When tensors are too close, pcc comes out as NaN. Each test can therefore tune
+# allclose.rtol and allclose.atol separately: if the outputs already match within
+# those tolerances, pcc is never calculated and the test passes without it.
 @dataclass
 class PccConfig(ConfigBase):
     required_pcc: float = 0.99
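
Note on the comment above: jnp.corrcoef normalizes by the standard deviation of
each input, so "too close" tensors, in particular (near-)constant outputs with
zero variance, turn the correlation into 0/0 and the whole matrix into NaN. A
minimal standalone sketch of both cases (not part of the patch; the example
tensors are illustrative only):

import jax.numpy as jnp

# Zero-variance inputs: Pearson correlation divides by a zero standard
# deviation, so every entry of the correlation matrix is NaN.
golden = jnp.full((8,), 0.5)
device = jnp.full((8,), 0.5)
print(jnp.corrcoef(device.flatten(), golden.flatten()))  # [[nan nan] [nan nan]]

# Well-behaved inputs correlate as expected.
golden = jnp.arange(8.0)
device = golden + 0.01
print(jnp.min(jnp.corrcoef(device.flatten(), golden.flatten())))  # ~1.0

With the patch applied, a test such as test_nn_with_relu can loosen the
pre-check through its fixture (config.pcc.allclose.atol = 0.03), so outputs
that already agree within tolerance never reach the NaN-prone pcc computation.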