Commit

Merge branch 'main' into umales/mnist_mse
umalesTT authored Jan 27, 2025
2 parents 164f7c3 + a590aa0 commit 9d4f831
Showing 3 changed files with 15 additions and 11 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ transformers
 fsspec
 einops
 torch
+ml_collections
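
The new ml_collections dependency supports the mlpmixer test change below, which now builds its patch configuration as an ml_collections.ConfigDict. A minimal sketch of that API, with the (16, 16) size assumed purely for illustration:

    # Illustrative only; the (16, 16) value is an assumption, not taken from the diff.
    import ml_collections

    patch = ml_collections.ConfigDict({"size": (16, 16)})
    print(patch.size)     # (16, 16) -- attribute-style access
    print(patch["size"])  # (16, 16) -- dict-style access also works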
@@ -26,7 +26,10 @@ def training_tester() -> LLamaTester:
 # ----- Tests -----
 
 
-@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
+# @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
+@pytest.mark.skip(
+    reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)"
+)
 def test_openllama3b_inference(
     inference_tester: LLamaTester,
 ):
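Background on the marker swap above (not part of the diff): an xfail-marked test still executes and merely tolerates failure, while a skip-marked test never runs its body, which is what avoids the CI OOM tracked in issue 186. A hypothetical sketch of the difference:

    import pytest


    @pytest.mark.xfail(reason="body still runs; a failure is reported as XFAIL")
    def test_still_executes():
        assert 1 + 1 == 3  # executed, fails, and is recorded as an expected failure


    @pytest.mark.skip(reason="body never runs, so it cannot exhaust CI memory")
    def test_never_executes():
        data = [0] * 10**12  # would exhaust memory, but is never evaluated
        assert data
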
20 changes: 10 additions & 10 deletions tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -7,7 +7,7 @@
 import flax.traverse_util
 import fsspec
 import jax
-import jax.numpy as jnp
+import ml_collections
 import numpy
 import pytest
 from flax import linen as nn
@@ -29,7 +29,7 @@ class MlpMixerTester(ModelTester):
 
     # @override
     def _get_model(self) -> nn.Module:
-        patch = jnp.ones((patch_size, patch_size))
+        patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)})
         return MlpMixer(
             patches=patch,
             num_classes=num_classes,
@@ -56,19 +56,14 @@ def _get_forward_method_name(self) -> str:
     # @override
     def _get_input_activations(self) -> jax.Array:
         key = jax.random.PRNGKey(42)
-        random_image = jax.random.normal(key, (1, 196, 196, 3))
+        random_image = jax.random.normal(key, (1, 224, 224, 3))
         return random_image
 
     # @override
     def _get_forward_method_args(self) -> Sequence[Any]:
         ins = self._get_input_activations()
         weights = self._retrieve_pretrained_weights()
 
-        # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)"
-        kernel = weights["params"]["stem"]["kernel"]
-        kernel = kernel.reshape(-1, 3, hidden_dim)
-        weights["params"]["stem"]["kernel"] = kernel
-
         # Alternatively, weights could be randomly initialized like this:
         # weights = self._model.init(jax.random.PRNGKey(42), ins)
 
@@ -93,8 +88,13 @@ def training_tester() -> MlpMixerTester:
 
 
 @pytest.mark.skip(
-    reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
-)
+    reason=(
+        "Statically allocated circular buffers in program 16 clash with L1 buffers "
+        "on core range [(x=0,y=0) - (x=6,y=0)]. L1 buffer allocated at 475136 and "
+        "static circular buffer region ends at 951136 "
+        "(https://github.com/tenstorrent/tt-xla/issues/187)"
+    )
+)  # segfault
 def test_mlpmixer(inference_tester: MlpMixerTester):
     inference_tester.test()
 
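For context on the mlpmixer changes above (not part of the diff): the shapes in the removed workaround comment suggest the model's stem is a convolution whose kernel and stride both equal the patch size. With the patch size supplied as a ConfigDict and a 224x224 input, that stem produces a 14x14 grid of 196 tokens and has a kernel parameter of shape (16, 16, 3, 768), which is presumably why the reshape of the pretrained stem kernel could be dropped. A minimal sketch under the assumptions patch_size = 16 and hidden_dim = 768; the Stem module is illustrative, not the repository's MlpMixer:

    # Assumptions: patch_size = 16 and hidden_dim = 768, as implied by the shapes in
    # the removed workaround comment; Stem is a stand-in, not the repo's implementation.
    import jax
    import ml_collections
    from flax import linen as nn


    class Stem(nn.Module):
        """Patch-embedding stem: a conv whose kernel and stride equal the patch size."""

        patches: ml_collections.ConfigDict
        hidden_dim: int = 768

        @nn.compact
        def __call__(self, x):
            # The kernel parameter has shape (16, 16, 3, 768), matching a standard
            # pretrained MlpMixer stem checkpoint without any reshaping.
            return nn.Conv(
                self.hidden_dim,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                name="stem",
            )(x)


    patch = ml_collections.ConfigDict({"size": (16, 16)})
    image = jax.random.normal(jax.random.PRNGKey(42), (1, 224, 224, 3))
    tokens, variables = Stem(patches=patch).init_with_output(jax.random.PRNGKey(0), image)
    print(tokens.shape)  # (1, 14, 14, 768): 224 / 16 = 14 patches per side, 196 tokens total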
