Skip to content

Commit

Permalink
Add tests for CLIP
Browse files Browse the repository at this point in the history
  • Loading branch information
sgligorijevicTT committed Feb 26, 2025
1 parent a8769d7 commit 8fa564f
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 0 deletions.
Empty file.
Empty file.
58 changes: 58 additions & 0 deletions tests/jax/models/clip/base_patch16/test_clip_base_patch16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0


from typing import Callable

import pytest
from infra import RunMode
from utils import compile_fail, record_model_test_properties

from ..tester import FlaxCLIPTester

MODEL_PATH = "openai/clip-vit-base-patch16"
MODEL_NAME = "clip-base-patch16"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.skip(
reason=compile_fail(
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.'
)
)
def test_clip_base_patch16_inference(
inference_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.push
@pytest.mark.skip(reason="Support for training not implemented")
def test_clip_base_patch16_training(
training_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
Empty file.
56 changes: 56 additions & 0 deletions tests/jax/models/clip/base_patch32/test_clip_base_patch32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0


from typing import Callable

import pytest
from infra import RunMode
from utils import compile_fail, record_model_test_properties

from ..tester import FlaxCLIPTester

MODEL_PATH = "openai/clip-vit-base-patch32"
MODEL_NAME = "clip-base-patch32"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.nightly
@pytest.mark.skip(
reason=compile_fail(
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.'
)
)
def test_clip_base_patch32_inference(
inference_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_clip_base_patch32_training(
training_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
Empty file.
56 changes: 56 additions & 0 deletions tests/jax/models/clip/large_patch14/test_clip_large_patch14.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0


from typing import Callable

import pytest
from infra import RunMode
from utils import compile_fail, record_model_test_properties

from ..tester import FlaxCLIPTester

MODEL_PATH = "openai/clip-vit-large-patch14"
MODEL_NAME = "clip-large-patch14"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.nightly
@pytest.mark.skip(
reason=compile_fail(
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.'
)
)
def test_clip_large_patch14_inference(
inference_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_clip_large_patch14_training(
training_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0


from typing import Callable

import pytest
from infra import RunMode
from utils import compile_fail, record_model_test_properties

from ..tester import FlaxCLIPTester

MODEL_PATH = "openai/clip-vit-large-patch14-336"
MODEL_NAME = "clip-large-patch14-336"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxCLIPTester:
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.nightly
@pytest.mark.skip(
reason=compile_fail(
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.'
)
)
def test_clip_large_patch14_336_inference(
inference_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_clip_large_patch14_336_training(
training_tester: FlaxCLIPTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
47 changes: 47 additions & 0 deletions tests/jax/models/clip/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict

import jax
from infra import ComparisonConfig, ModelTester, RunMode
from transformers import CLIPProcessor, FlaxCLIPModel, FlaxPreTrainedModel


class FlaxCLIPTester(ModelTester):
"""Tester for CLIP family of models on image classification tasks."""

def __init__(
self,
model_name: str,
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
self._model_name = model_name
super().__init__(comparison_config, run_mode)

# @override
def _get_model(self) -> FlaxPreTrainedModel:
return FlaxCLIPModel.from_pretrained(self._model_name)

# @override
def _get_input_activations(self) -> Dict:
image = jax.random.uniform(jax.random.PRNGKey(42), (1, 3, 224, 224))
preprocessor = CLIPProcessor.from_pretrained(self._model_name, do_rescale=False)
inputs = preprocessor(
text=["a photo of a cat", "a photo of a dog"],
images=image,
return_tensors="np",
)
print(inputs)
input()
return inputs

# @override
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
assert hasattr(self._model, "params")
return {
"params": self._model.params,
**self._get_input_activations(),
}

0 comments on commit 8fa564f

Please sign in to comment.