From 1302f1a9c0b9615496c1fe9a33bc8730a71b76fe Mon Sep 17 00:00:00 2001 From: Stefan Gligorijevic <189116645+sgligorijevicTT@users.noreply.github.com> Date: Fri, 24 Jan 2025 12:59:57 +0100 Subject: [PATCH] Fix MLPMixer test (#183) --- requirements.txt | 1 + tests/jax/models/mlpmixer/test_mlpmixer.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 78abac39..781ab540 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ transformers fsspec einops torch +ml_collections diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py index fd8c1fa9..469e0688 100644 --- a/tests/jax/models/mlpmixer/test_mlpmixer.py +++ b/tests/jax/models/mlpmixer/test_mlpmixer.py @@ -7,7 +7,7 @@ import flax.traverse_util import fsspec import jax -import jax.numpy as jnp +import ml_collections import numpy import pytest from flax import linen as nn @@ -29,7 +29,7 @@ class MlpMixerTester(ModelTester): # @override def _get_model(self) -> nn.Module: - patch = jnp.ones((patch_size, patch_size)) + patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)}) return MlpMixer( patches=patch, num_classes=num_classes, @@ -56,7 +56,7 @@ def _get_forward_method_name(self) -> str: # @override def _get_input_activations(self) -> jax.Array: key = jax.random.PRNGKey(42) - random_image = jax.random.normal(key, (1, 196, 196, 3)) + random_image = jax.random.normal(key, (1, 224, 224, 3)) return random_image # @override @@ -64,11 +64,6 @@ def _get_forward_method_args(self) -> Sequence[Any]: ins = self._get_input_activations() weights = self._retrieve_pretrained_weights() - # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)" - kernel = weights["params"]["stem"]["kernel"] - kernel = kernel.reshape(-1, 3, hidden_dim) - weights["params"]["stem"]["kernel"] = kernel - # Alternatively, weights could be randomly initialized like this: # weights = self._model.init(jax.random.PRNGKey(42), ins) @@ -93,8 +88,13 @@ def training_tester() -> MlpMixerTester: @pytest.mark.skip( - reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal" -) + reason=( + "Statically allocated circular buffers in program 16 clash with L1 buffers " + "on core range [(x=0,y=0) - (x=6,y=0)]. L1 buffer allocated at 475136 and " + "static circular buffer region ends at 951136 " + "(https://github.com/tenstorrent/tt-xla/issues/187)" + ) +) # segfault def test_mlpmixer(inference_tester: MlpMixerTester): inference_tester.test()