-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a8769d7
commit 8fa564f
Showing
10 changed files
with
273 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
58 changes: 58 additions & 0 deletions
58
tests/jax/models/clip/base_patch16/test_clip_base_patch16.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from typing import Callable | ||
|
||
import pytest | ||
from infra import RunMode | ||
from utils import compile_fail, record_model_test_properties | ||
|
||
from ..tester import FlaxCLIPTester | ||
|
||
MODEL_PATH = "openai/clip-vit-base-patch16" | ||
MODEL_NAME = "clip-base-patch16" | ||
|
||
|
||
# ----- Fixtures ----- | ||
|
||
|
||
@pytest.fixture | ||
def inference_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH) | ||
|
||
|
||
@pytest.fixture | ||
def training_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING) | ||
|
||
|
||
# ----- Tests ----- | ||
|
||
|
||
@pytest.mark.push | ||
@pytest.mark.nightly | ||
@pytest.mark.skip( | ||
reason=compile_fail( | ||
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' | ||
) | ||
) | ||
def test_clip_base_patch16_inference( | ||
inference_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
inference_tester.test() | ||
|
||
|
||
@pytest.mark.push | ||
@pytest.mark.skip(reason="Support for training not implemented") | ||
def test_clip_base_patch16_training( | ||
training_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
training_tester.test() |
Empty file.
56 changes: 56 additions & 0 deletions
56
tests/jax/models/clip/base_patch32/test_clip_base_patch32.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from typing import Callable | ||
|
||
import pytest | ||
from infra import RunMode | ||
from utils import compile_fail, record_model_test_properties | ||
|
||
from ..tester import FlaxCLIPTester | ||
|
||
MODEL_PATH = "openai/clip-vit-base-patch32" | ||
MODEL_NAME = "clip-base-patch32" | ||
|
||
|
||
# ----- Fixtures ----- | ||
|
||
|
||
@pytest.fixture | ||
def inference_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH) | ||
|
||
|
||
@pytest.fixture | ||
def training_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING) | ||
|
||
|
||
# ----- Tests ----- | ||
|
||
|
||
@pytest.mark.nightly | ||
@pytest.mark.skip( | ||
reason=compile_fail( | ||
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' | ||
) | ||
) | ||
def test_clip_base_patch32_inference( | ||
inference_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
inference_tester.test() | ||
|
||
|
||
@pytest.mark.skip(reason="Support for training not implemented") | ||
def test_clip_base_patch32_training( | ||
training_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
training_tester.test() |
Empty file.
56 changes: 56 additions & 0 deletions
56
tests/jax/models/clip/large_patch14/test_clip_large_patch14.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from typing import Callable | ||
|
||
import pytest | ||
from infra import RunMode | ||
from utils import compile_fail, record_model_test_properties | ||
|
||
from ..tester import FlaxCLIPTester | ||
|
||
MODEL_PATH = "openai/clip-vit-large-patch14" | ||
MODEL_NAME = "clip-large-patch14" | ||
|
||
|
||
# ----- Fixtures ----- | ||
|
||
|
||
@pytest.fixture | ||
def inference_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH) | ||
|
||
|
||
@pytest.fixture | ||
def training_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING) | ||
|
||
|
||
# ----- Tests ----- | ||
|
||
|
||
@pytest.mark.nightly | ||
@pytest.mark.skip( | ||
reason=compile_fail( | ||
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' | ||
) | ||
) | ||
def test_clip_large_patch14_inference( | ||
inference_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
inference_tester.test() | ||
|
||
|
||
@pytest.mark.skip(reason="Support for training not implemented") | ||
def test_clip_large_patch14_training( | ||
training_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
training_tester.test() |
Empty file.
56 changes: 56 additions & 0 deletions
56
tests/jax/models/clip/large_patch14_336/test_clip_large_patch14_336.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from typing import Callable | ||
|
||
import pytest | ||
from infra import RunMode | ||
from utils import compile_fail, record_model_test_properties | ||
|
||
from ..tester import FlaxCLIPTester | ||
|
||
MODEL_PATH = "openai/clip-vit-large-patch14-336" | ||
MODEL_NAME = "clip-large-patch14-336" | ||
|
||
|
||
# ----- Fixtures ----- | ||
|
||
|
||
@pytest.fixture | ||
def inference_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH) | ||
|
||
|
||
@pytest.fixture | ||
def training_tester() -> FlaxCLIPTester: | ||
return FlaxCLIPTester(MODEL_PATH, RunMode.TRAINING) | ||
|
||
|
||
# ----- Tests ----- | ||
|
||
|
||
@pytest.mark.nightly | ||
@pytest.mark.skip( | ||
reason=compile_fail( | ||
'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' | ||
) | ||
) | ||
def test_clip_large_patch14_336_inference( | ||
inference_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
inference_tester.test() | ||
|
||
|
||
@pytest.mark.skip(reason="Support for training not implemented") | ||
def test_clip_large_patch14_336_training( | ||
training_tester: FlaxCLIPTester, | ||
record_tt_xla_property: Callable, | ||
): | ||
record_model_test_properties(record_tt_xla_property, MODEL_NAME) | ||
|
||
training_tester.test() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Dict | ||
|
||
import jax | ||
from infra import ComparisonConfig, ModelTester, RunMode | ||
from transformers import CLIPProcessor, FlaxCLIPModel, FlaxPreTrainedModel | ||
|
||
|
||
class FlaxCLIPTester(ModelTester): | ||
"""Tester for CLIP family of models on image classification tasks.""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
comparison_config: ComparisonConfig = ComparisonConfig(), | ||
run_mode: RunMode = RunMode.INFERENCE, | ||
) -> None: | ||
self._model_name = model_name | ||
super().__init__(comparison_config, run_mode) | ||
|
||
# @override | ||
def _get_model(self) -> FlaxPreTrainedModel: | ||
return FlaxCLIPModel.from_pretrained(self._model_name) | ||
|
||
# @override | ||
def _get_input_activations(self) -> Dict: | ||
image = jax.random.uniform(jax.random.PRNGKey(42), (1, 3, 224, 224)) | ||
preprocessor = CLIPProcessor.from_pretrained(self._model_name, do_rescale=False) | ||
inputs = preprocessor( | ||
text=["a photo of a cat", "a photo of a dog"], | ||
images=image, | ||
return_tensors="np", | ||
) | ||
print(inputs) | ||
input() | ||
return inputs | ||
|
||
# @override | ||
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: | ||
assert hasattr(self._model, "params") | ||
return { | ||
"params": self._model.params, | ||
**self._get_input_activations(), | ||
} |