Fix MLPMixer test #183

Merged · 2 commits · Jan 24, 2025
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ transformers
 fsspec
 einops
 torch
+ml_collections
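The new requirement backs the ml_collections.ConfigDict that the test below now builds its patch config from. A minimal sketch of the behavior being relied on (illustrative values, not the test's own constants):

    import ml_collections

    # ConfigDict exposes keys as attributes, so `patch.size` yields the tuple
    # itself rather than requiring a dict-style lookup.
    patch = ml_collections.ConfigDict({"size": (16, 16)})
    assert patch.size == (16, 16)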
20 changes: 10 additions & 10 deletions tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -7,7 +7,7 @@
 import flax.traverse_util
 import fsspec
 import jax
-import jax.numpy as jnp
+import ml_collections
 import numpy
 import pytest
 from flax import linen as nn
@@ -29,7 +29,7 @@ class MlpMixerTester(ModelTester):
 
     # @override
     def _get_model(self) -> nn.Module:
-        patch = jnp.ones((patch_size, patch_size))
+        patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)})
         return MlpMixer(
             patches=patch,
             num_classes=num_classes,
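Some context for the line changed above, assuming the vendored MlpMixer follows the reference Flax implementation (google-research/vision_transformer), whose stem is built as nn.Conv(hidden_dim, kernel_size=self.patches.size, strides=self.patches.size, name="stem"): on a jnp array, .size is the total element count, so the old jnp.ones patch handed the stem a scalar kernel size and produced the 1-D (256, 3, 768) kernel quoted in the workaround removed below. A sketch of the difference:

    import jax.numpy as jnp
    import ml_collections

    patch_size = 16  # assumed; consistent with the (16, 16, 3, 768) kernel in the old error

    old_patch = jnp.ones((patch_size, patch_size))
    print(old_patch.size)  # 256 -> stem kernel of shape (256, 3, 768)

    new_patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)})
    print(new_patch.size)  # (16, 16) -> stem kernel of shape (16, 16, 3, 768)

With the tuple restored, the pretrained weights load without reshaping, which is why the kernel workaround below could be deleted.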
@@ -56,19 +56,14 @@ def _get_forward_method_name(self) -> str:
 
     # @override
     def _get_input_activations(self) -> jax.Array:
         key = jax.random.PRNGKey(42)
-        random_image = jax.random.normal(key, (1, 196, 196, 3))
+        random_image = jax.random.normal(key, (1, 224, 224, 3))
         return random_image
 
     # @override
     def _get_forward_method_args(self) -> Sequence[Any]:
         ins = self._get_input_activations()
         weights = self._retrieve_pretrained_weights()
 
-        # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)"
-        kernel = weights["params"]["stem"]["kernel"]
-        kernel = kernel.reshape(-1, 3, hidden_dim)
-        weights["params"]["stem"]["kernel"] = kernel
-
         # Alternatively, weights could be randomly initialized like this:
         # weights = self._model.init(jax.random.PRNGKey(42), ins)

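On the input-shape change: the pretrained Mixer weights fix the token-mixing MLPs to a 196-token sequence, and with 16x16 patches (assumed, matching the kernel shape above) that corresponds to a 224x224 image; the old (1, 196, 196, 3) shape looks like the token count mistaken for the resolution. The arithmetic:

    image_size, patch_size = 224, 16  # patch_size assumed
    num_tokens = (image_size // patch_size) ** 2
    assert num_tokens == 196  # sequence length the pretrained token-mixing MLPs expect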
@@ -93,8 +88,13 @@ def training_tester() -> MlpMixerTester:
 
 
 @pytest.mark.skip(
-    reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
-)
+    reason=(
+        "Statically allocated circular buffers in program 16 clash with L1 buffers "
+        "on core range [(x=0,y=0) - (x=6,y=0)]. L1 buffer allocated at 475136 and "
+        "static circular buffer region ends at 951136 "
+        "(https://github.com/tenstorrent/tt-xla/issues/187)"
+    )
+) # segfault
 def test_mlpmixer(inference_tester: MlpMixerTester):
     inference_tester.test()

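A note on the skip itself: since the failure is a process-level segfault (per the trailing comment), the test must not execute at all. A hypothetical alternative, if reporting the known failure were preferred over silently skipping (names illustrative, not part of this PR):

    import pytest

    # xfail with run=False also keeps a crashing test out of the session
    # while still surfacing it in the report as an expected failure.
    @pytest.mark.xfail(
        run=False,
        reason="segfault; tracked in https://github.com/tenstorrent/tt-xla/issues/187",
    )
    def test_mlpmixer_crashing(inference_tester):  # hypothetical variant
        inference_tester.test()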