diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c337c650..1a8d2964 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,7 @@ repos:
     hooks:
       - id: black
         language_version: python3
+        args: [--line-length=80]
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v18.1.7
     hooks:
diff --git a/examples/simple_regression.py b/examples/simple_regression.py
index 1501f6de..f27230e6 100644
--- a/examples/simple_regression.py
+++ b/examples/simple_regression.py
@@ -16,14 +16,20 @@
 # program will execute on tt device if not specified otherwise.
 def initialize():
     backend = "tt"
-    path = os.path.join(os.path.dirname(__file__), "../build/src/tt/pjrt_plugin_tt.so")
+    path = os.path.join(
+        os.path.dirname(__file__), "../build/src/tt/pjrt_plugin_tt.so"
+    )

     if not os.path.exists(path):
-        raise FileNotFoundError(f"Could not find tt_pjrt C API plugin at {path}")
+        raise FileNotFoundError(
+            f"Could not find tt_pjrt C API plugin at {path}"
+        )

     print("Loading tt_pjrt C API plugin", file=sys.stderr)
     xb.discover_pjrt_plugins()

-    plugin = xb.register_plugin("tt", priority=500, library_path=path, options=None)
+    plugin = xb.register_plugin(
+        "tt", priority=500, library_path=path, options=None
+    )
     print("Loaded", file=sys.stderr)
     jax.config.update("jax_platforms", "tt,cpu")
@@ -33,7 +39,9 @@ def random_input_tensor(shape, key=42, on_device=False):
     def random_input(shape, key):
         return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)

-    jitted_tensor_creator = jax.jit(random_input, static_argnums=[0, 1], backend="cpu")
+    jitted_tensor_creator = jax.jit(
+        random_input, static_argnums=[0, 1], backend="cpu"
+    )
     tensor = jitted_tensor_creator(shape, key)
     if on_device:
         tensor = jax.device_put(tensor, jax.devices()[0])
diff --git a/tests/infra/base_tester.py b/tests/infra/base_tester.py
index f5f66d93..fb07596e 100644
--- a/tests/infra/base_tester.py
+++ b/tests/infra/base_tester.py
@@ -53,9 +53,13 @@ def _compare(
         if self._comparison_config.equal.enabled:
             compare_equal(device_output, golden_output)
         if self._comparison_config.atol.enabled:
-            compare_atol(device_output, golden_output, self._comparison_config.atol)
+            compare_atol(
+                device_output, golden_output, self._comparison_config.atol
+            )
         if self._comparison_config.pcc.enabled:
-            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
+            compare_pcc(
+                device_output, golden_output, self._comparison_config.pcc
+            )
         if self._comparison_config.allclose.enabled:
             compare_allclose(
                 device_output, golden_output, self._comparison_config.allclose
@@ -68,6 +72,8 @@ def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
         Tensors need to be in same data format in order to compare them.
         """
         return [
-            tensor.astype("float32") if tensor.dtype.str != "float32" else tensor
+            tensor.astype("float32")
+            if tensor.dtype.str != "float32"
+            else tensor
             for tensor in tensors
         ]
diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py
index c8a799b7..d6ed0300 100644
--- a/tests/infra/comparison.py
+++ b/tests/infra/comparison.py
@@ -123,7 +123,9 @@ def compare_pcc(

 @run_on_cpu
 def compare_allclose(
-    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
+    device_output: Tensor,
+    golden_output: Tensor,
+    allclose_config: AllcloseConfig,
 ) -> None:
     assert isinstance(device_output, jax.Array) and isinstance(
         golden_output, jax.Array
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 08551893..8307f751 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -68,7 +68,9 @@ def _run_on_device(
         workload: Workload, device_type: DeviceType, device_num: int = 0
     ) -> Tensor:
         """Runs `workload` on device identified by `device_type`."""
-        device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
+        device_workload = DeviceRunner._put_on_device(
+            workload, device_type, device_num
+        )
         device = device_connector.connect_device(device_type, device_num)

         with jax.default_device(device):
@@ -106,7 +108,9 @@ def _safely_put_workload_on_device(
         To avoid that, we try to `jax.device_put` arg or kwarg, and if it
         doesn't succeed, we leave it as is.
         """
-        fn_params = list(inspect.signature(workload.executable).parameters.keys())
+        fn_params = list(
+            inspect.signature(workload.executable).parameters.keys()
+        )

         args_on_device = []
         for i, arg in enumerate(workload.args):
diff --git a/tests/infra/model_tester.py b/tests/infra/model_tester.py
index da01ab0c..230039db 100644
--- a/tests/infra/model_tester.py
+++ b/tests/infra/model_tester.py
@@ -61,7 +61,9 @@ def _init_model_hooks(self) -> None:
         kwargs = self._get_forward_method_kwargs()

         if len(args) == 0 and len(kwargs) == 0:
-            raise ValueError(f"Forward method args or kwargs or both must be provided")
+            raise ValueError(
+                f"Forward method args or kwargs or both must be provided"
+            )

         forward_method_name = self._get_forward_method_name()

diff --git a/tests/infra/op_tester.py b/tests/infra/op_tester.py
index 8dbe32cb..229b8fbd 100644
--- a/tests/infra/op_tester.py
+++ b/tests/infra/op_tester.py
@@ -49,7 +49,8 @@ def test_with_random_inputs(
         TT device and CPU and comparing the results.
         """
         inputs = [
-            random_tensor(shape, minval=minval, maxval=maxval) for shape in input_shapes
+            random_tensor(shape, minval=minval, maxval=maxval)
+            for shape in input_shapes
         ]
         workload = Workload(f, inputs)
         self.test(workload)
diff --git a/tests/jax/graphs/test_MLP_regression.py b/tests/jax/graphs/test_MLP_regression.py
index 0a23054c..fcea436e 100644
--- a/tests/jax/graphs/test_MLP_regression.py
+++ b/tests/jax/graphs/test_MLP_regression.py
@@ -23,7 +23,9 @@ def comparison_config() -> ComparisonConfig:
         [(784, 64), (32, 64), (64, 10), (32, 10), (32, 784), (32, 10)]
     ],  # 32 samples, 784 features (28x28), 10 output classes
 )
-def test_nn_with_relu(W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig):
+def test_nn_with_relu(
+    W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig
+):
     def simple_nn(W1, b1, W2, b2, X, y):
         def forward(W1, b1, W2, b2, X):
             hidden = jnp.dot(X, W1) + b1
diff --git a/tests/jax/graphs/test_linear_transformation.py b/tests/jax/graphs/test_linear_transformation.py
index 8208eaab..5500b82e 100644
--- a/tests/jax/graphs/test_linear_transformation.py
+++ b/tests/jax/graphs/test_linear_transformation.py
@@ -19,8 +19,12 @@
         [(64, 32), (32, 64), (1, 64)],
     ],
 )
-def test_linear_transformation(x_shape: tuple, y_shape: tuple, bias_shape: tuple):
-    def linear_transformation(x: jax.Array, y: jax.Array, bias: jax.Array) -> jax.Array:
+def test_linear_transformation(
+    x_shape: tuple, y_shape: tuple, bias_shape: tuple
+):
+    def linear_transformation(
+        x: jax.Array, y: jax.Array, bias: jax.Array
+    ) -> jax.Array:
         return jnp.matmul(x, y) + bias

     run_graph_test_with_random_inputs(
diff --git a/tests/jax/graphs/test_simple_regression.py b/tests/jax/graphs/test_simple_regression.py
index 57512c97..4b9313ee 100644
--- a/tests/jax/graphs/test_simple_regression.py
+++ b/tests/jax/graphs/test_simple_regression.py
@@ -15,7 +15,9 @@ def test_simple_regression(weights, bias, X, y):
     def simple_regression(weights, bias, X, y):
         def loss(weights, bias, X, y):
-            predict = X.dot(weights) + bias if bias is not None else X.dot(weights)
+            predict = (
+                X.dot(weights) + bias if bias is not None else X.dot(weights)
+            )
             return ((predict - y) ** 2).sum()

         # Compute gradient and update weights.

diff --git a/tests/jax/models/albert/v2/tester.py b/tests/jax/models/albert/v2/tester.py
index 6069a920..5cef407b 100644
--- a/tests/jax/models/albert/v2/tester.py
+++ b/tests/jax/models/albert/v2/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxAlbertForMaskedLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxAlbertForMaskedLM,
+    FlaxPreTrainedModel,
+)


 class AlbertV2Tester(ModelTester):
diff --git a/tests/jax/models/beit/base/test_beit_base.py b/tests/jax/models/beit/base/test_beit_base.py
index 8c82289d..9014e0b7 100644
--- a/tests/jax/models/beit/base/test_beit_base.py
+++ b/tests/jax/models/beit/base/test_beit_base.py
@@ -31,7 +31,9 @@ def training_tester() -> FlaxBeitForImageClassificationTester:
 # ----- Tests -----


-@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'"))
+@pytest.mark.xfail(
+    reason=compile_fail("failed to legalize operation 'ttir.gather'")
+)
 def test_flax_beit_base_inference(
     inference_tester: FlaxBeitForImageClassificationTester,
     record_tt_xla_property: Callable,
diff --git a/tests/jax/models/beit/large/test_beit_large.py b/tests/jax/models/beit/large/test_beit_large.py
index 8efe8271..2b1efca4 100644
--- a/tests/jax/models/beit/large/test_beit_large.py
+++ b/tests/jax/models/beit/large/test_beit_large.py
@@ -30,7 +30,9 @@ def training_tester() -> FlaxBeitForImageClassificationTester:
 # ----- Tests -----


-@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'"))
+@pytest.mark.xfail(
+    reason=compile_fail("failed to legalize operation 'ttir.gather'")
+)
 def test_flax_beit_large_inference(
     inference_tester: FlaxBeitForImageClassificationTester,
     record_tt_xla_property: Callable,
diff --git a/tests/jax/models/bloom/tester.py b/tests/jax/models/bloom/tester.py
index 75f295bd..bace9160 100644
--- a/tests/jax/models/bloom/tester.py
+++ b/tests/jax/models/bloom/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxBloomForCausalLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxBloomForCausalLM,
+    FlaxPreTrainedModel,
+)


 class BloomTester(ModelTester):
@@ -23,7 +27,9 @@ def __init__(

     # @override
     def _get_model(self) -> FlaxPreTrainedModel:
-        return FlaxBloomForCausalLM.from_pretrained(self._model_name, from_pt=True)
+        return FlaxBloomForCausalLM.from_pretrained(
+            self._model_name, from_pt=True
+        )

     # @override
     def _get_input_activations(self) -> Sequence[jax.Array]:
diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py
index 7d74317c..95ca004a 100644
--- a/tests/jax/models/distilbert/test_distilbert.py
+++ b/tests/jax/models/distilbert/test_distilbert.py
@@ -7,7 +7,11 @@
 import jax
 import pytest
 from infra import ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxDistilBertForMaskedLM,
+    FlaxPreTrainedModel,
+)
 from utils import record_model_test_properties, runtime_fail

 MODEL_PATH = "distilbert/distilbert-base-uncased"
diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
index 784bc7d0..e748669d 100644
--- a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
+++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
@@ -89,5 +89,7 @@ def test_example_model_inference(
 @pytest.mark.push
 @pytest.mark.nightly
 @pytest.mark.skip(reason="Support for training not implemented")
-def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester):
+def test_example_model_training(
+    training_tester: ExampleModelMixedArgsAndKwargsTester,
+):
     training_tester.test()
diff --git a/tests/jax/models/example_model/model.py b/tests/jax/models/example_model/model.py
index f3a4f334..9bb33a30 100644
--- a/tests/jax/models/example_model/model.py
+++ b/tests/jax/models/example_model/model.py
@@ -23,7 +23,12 @@ def __init__(self) -> None:
         self.b1 = random_tensor(b1_shape, minval=-0.01, maxval=0.01)

     def __call__(
-        self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array
+        self,
+        act: jax.Array,
+        w0: jax.Array,
+        b0: jax.Array,
+        w1: jax.Array,
+        b1: jax.Array,
     ) -> jax.Array:
         # Note how activations, weights and biases are directly passed to the forward
         # method as inputs, `self` is not accessed. Otherwise they would be embedded
diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
index 08b9af00..6b27f43a 100644
--- a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
+++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
@@ -75,7 +75,9 @@ def training_tester() -> ExampleModelOnlyKwargsTester:

 @pytest.mark.push
 @pytest.mark.nightly
-def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester):
+def test_example_model_inference(
+    inference_tester: ExampleModelOnlyKwargsTester,
+):
     inference_tester.test()


diff --git a/tests/jax/models/llama/tester.py b/tests/jax/models/llama/tester.py
index e194cf3f..7ebbaada 100644
--- a/tests/jax/models/llama/tester.py
+++ b/tests/jax/models/llama/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import FlaxLlamaForCausalLM, FlaxPreTrainedModel, LlamaTokenizer
+from transformers import (
+    FlaxLlamaForCausalLM,
+    FlaxPreTrainedModel,
+    LlamaTokenizer,
+)


 class LLamaTester(ModelTester):
@@ -25,7 +29,9 @@ def __init__(

     # @override
     def _get_model(self) -> FlaxPreTrainedModel:
-        return FlaxLlamaForCausalLM.from_pretrained(self._model_name, from_pt=True)
+        return FlaxLlamaForCausalLM.from_pretrained(
+            self._model_name, from_pt=True
+        )

     # @override
     def _get_input_activations(self) -> Sequence[jax.Array]:
diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py
index 03679e4b..d5201cf7 100644
--- a/tests/jax/models/mlpmixer/model_implementation.py
+++ b/tests/jax/models/mlpmixer/model_implementation.py
@@ -62,7 +62,10 @@ class MlpMixer(nn.Module):
     @nn.compact
     def __call__(self, inputs: jax.Array) -> jax.Array:
         x = nn.Conv(
-            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
+            self.hidden_dim,
+            self.patches.size,
+            strides=self.patches.size,
+            name="stem",
         )(
             inputs
         )  # Patch embedding
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
index 9b7fb491..4c431022 100644
--- a/tests/jax/models/mlpmixer/test_mlpmixer.py
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -44,7 +44,9 @@ def _get_model(self) -> nn.Module:
     def _retrieve_pretrained_weights() -> Dict:
         # TODO(stefan): Discuss how weights should be handled org wide
         link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
-        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+        with fsspec.open(
+            "filecache::" + link, cache_storage="/tmp/files/"
+        ) as f:
             weights = numpy.load(f, encoding="bytes")
         state_dict = {k: v for k, v in weights.items()}
         pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
diff --git a/tests/jax/models/mnist/mlp/test_mnist_mlp.py b/tests/jax/models/mnist/mlp/test_mnist_mlp.py
index 11bf4f8a..65bca0fa 100644
--- a/tests/jax/models/mnist/mlp/test_mnist_mlp.py
+++ b/tests/jax/models/mnist/mlp/test_mnist_mlp.py
@@ -96,6 +96,8 @@ def test_mnist_mlp_training(
     training_tester: MNISTMLPTester,
     record_tt_xla_property: Callable,
 ):
-    record_model_test_properties(record_tt_xla_property, MNISTMLPModel.__qualname__)
+    record_model_test_properties(
+        record_tt_xla_property, MNISTMLPModel.__qualname__
+    )

     training_tester.test()
diff --git a/tests/jax/models/roberta/tester.py b/tests/jax/models/roberta/tester.py
index efda8c61..67eddd82 100644
--- a/tests/jax/models/roberta/tester.py
+++ b/tests/jax/models/roberta/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxPreTrainedModel, FlaxRobertaForMaskedLM
+from transformers import (
+    AutoTokenizer,
+    FlaxPreTrainedModel,
+    FlaxRobertaForMaskedLM,
+)


 class FlaxRobertaForMaskedLMTester(ModelTester):
diff --git a/tests/jax/models/squeezebert/model_implementation.py b/tests/jax/models/squeezebert/model_implementation.py
index 9ff83e5f..4062ab80 100644
--- a/tests/jax/models/squeezebert/model_implementation.py
+++ b/tests/jax/models/squeezebert/model_implementation.py
@@ -50,7 +50,9 @@ def __call__(
         position_embeddings = self.position_embedding(position_ids)
         token_type_embeddings = self.token_type_embedding(token_type_ids)

-        embeddings = word_embeddings + position_embeddings + token_type_embeddings
+        embeddings = (
+            word_embeddings + position_embeddings + token_type_embeddings
+        )
         embeddings = self.layernorm(embeddings)
         embeddings = self.dropout(embeddings, deterministic=deterministic)
         return embeddings
@@ -83,7 +85,9 @@ def setup(self):
             feature_group_count=self.config.post_attention_groups,
         )

-        self.attn_dropout = nn.Dropout(rate=self.config.attention_probs_dropout_prob)
+        self.attn_dropout = nn.Dropout(
+            rate=self.config.attention_probs_dropout_prob
+        )
         self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

         self.layernorm = nn.LayerNorm()
@@ -127,7 +131,9 @@ def __call__(
             attention_probs, deterministic=deterministic
         )

-        context = jnp.einsum("B H s S, B S H d -> B s H d", attention_probs, value)
+        context = jnp.einsum(
+            "B H s S, B S H d -> B s H d", attention_probs, value
+        )
         context = einops.rearrange(context, "b s H d -> b s (H d)")

         output = self.output(context)
@@ -203,7 +209,8 @@ class SqueezeBertEncoder(nn.Module):

     def setup(self):
         self.layers = [
-            SqueezeBertLayer(self.config) for _ in range(self.config.num_hidden_layers)
+            SqueezeBertLayer(self.config)
+            for _ in range(self.config.num_hidden_layers)
         ]

     def __call__(
@@ -308,7 +315,9 @@ def __call__(
         return prediction_scores

     @staticmethod
-    def init_from_pytorch_statedict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+    def init_from_pytorch_statedict(
+        state_dict: Dict[str, Any]
+    ) -> Dict[str, Any]:
         # Key substitutions for remapping huggingface checkpoints to this implementation
         PATTERNS = [
             ("transformer.", "squeezebert."),
@@ -328,7 +337,10 @@ def init_from_pytorch_statedict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
             ("output.conv1d.bias", "mlp.w2.bias"),
             ("output.layernorm", "mlp.layernorm"),
             ("pooler.dense.weight", "pooler.dense.kernel"),
-            ("cls.predictions.transform.dense.weight", "transform_dense.kernel"),
+            (
+                "cls.predictions.transform.dense.weight",
+                "transform_dense.kernel",
+            ),
             ("cls.predictions.transform.dense.bias", "transform_dense.bias"),
             ("cls.predictions.transform.layernorm", "transform_layernorm"),
             ("cls.predictions.decoder.weight", "decoder.kernel"),
@@ -356,7 +368,9 @@ def process_value(k: str, v) -> jnp.ndarray:
                 state_dict[k] = jnp.array(v)

         state_dict = {
-            rewrite_key(k): v for k, v in state_dict.items() if not is_banned_key(k)
+            rewrite_key(k): v
+            for k, v in state_dict.items()
+            if not is_banned_key(k)
         }
         state_dict = {k: process_value(k, v) for k, v in state_dict.items()}
         state_dict = flax.traverse_util.unflatten_dict(state_dict, sep=".")
diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py
index 68945d02..023e9a19 100644
--- a/tests/jax/models/squeezebert/test_squeezebert.py
+++ b/tests/jax/models/squeezebert/test_squeezebert.py
@@ -42,7 +42,8 @@ def _get_input_activations(self) -> Sequence[jax.Array]:

     # @override
     def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
         model_file = hf_hub_download(
-            repo_id="squeezebert/squeezebert-uncased", filename="pytorch_model.bin"
+            repo_id="squeezebert/squeezebert-uncased",
+            filename="pytorch_model.bin",
         )
         state_dict = torch.load(model_file, weights_only=True)
diff --git a/tests/jax/ops/test_abs.py b/tests/jax/ops/test_abs.py
index d5774e14..bfd7225f 100644
--- a/tests/jax/ops/test_abs.py
+++ b/tests/jax/ops/test_abs.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_abs(x_shape: tuple, record_tt_xla_property: Callable):
     def abs(x: jax.Array) -> jax.Array:
         return jnp.abs(x)
diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py
index 8adafc60..103df234 100644
--- a/tests/jax/ops/test_broadcast_in_dim.py
+++ b/tests/jax/ops/test_broadcast_in_dim.py
@@ -14,7 +14,9 @@
 @pytest.mark.push
 @pytest.mark.nightly
 @pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}")
-def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable):
+def test_broadcast_in_dim(
+    input_shapes: tuple, record_tt_xla_property: Callable
+):
     def broadcast(a: jax.Array):
         return jnp.broadcast_to(a, (2, 4))

diff --git a/tests/jax/ops/test_cbrt.py b/tests/jax/ops/test_cbrt.py
index 35ac43ec..e01b28d4 100644
--- a/tests/jax/ops/test_cbrt.py
+++ b/tests/jax/ops/test_cbrt.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_cbrt(x_shape: tuple, record_tt_xla_property: Callable):
     def cbrt(x: jax.Array) -> jax.Array:
         return jnp.cbrt(x)
diff --git a/tests/jax/ops/test_compare.py b/tests/jax/ops/test_compare.py
index 4cb34615..ca068903 100644
--- a/tests/jax/ops/test_compare.py
+++ b/tests/jax/ops/test_compare.py
@@ -148,7 +148,9 @@ def greater_equal(x: jax.Array, y: jax.Array) -> jax.Array:
     ],
     ids=lambda val: f"{val}",
 )
-def test_compare_less(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_compare_less(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     @convert_output_to_bfloat16
     def less(x: jax.Array, y: jax.Array) -> jax.Array:
         return x < y
diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py
index 1525691f..b79d8ef2 100644
--- a/tests/jax/ops/test_constant.py
+++ b/tests/jax/ops/test_constant.py
@@ -46,7 +46,9 @@ def module_constant_ones():

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'"))
+@pytest.mark.xfail(
+    reason=compile_fail("failed to legalize operation 'ttir.constant'")
+)
 def test_constant_multi_value(record_tt_xla_property: Callable):
     def module_constant_multi():
         return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py
index 0521cc99..2873ebb3 100644
--- a/tests/jax/ops/test_convert.py
+++ b/tests/jax/ops/test_convert.py
@@ -31,7 +31,9 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
     # If the input tensor is deallocated, the output tensor will lose access
     # to valid data and may contain garbage.
     # See issue #248 for more details.
-    if from_dtype == to_dtype or (from_dtype == jnp.uint32 and to_dtype == jnp.uint64):
+    if from_dtype == to_dtype or (
+        from_dtype == jnp.uint32 and to_dtype == jnp.uint64
+    ):
         pytest.xfail(
             runtime_fail(
                 "Atol comparison failed. Calculated: atol=65535.0. Required: atol=0.16."
diff --git a/tests/jax/ops/test_convolution.py b/tests/jax/ops/test_convolution.py
index 1de7b3b8..f147c9da 100644
--- a/tests/jax/ops/test_convolution.py
+++ b/tests/jax/ops/test_convolution.py
@@ -141,7 +141,12 @@ def conv2d(img: jax.Array, kernel: jax.Array):
     )

     img_shape = (batch_size, input_height, input_width, input_channels)
-    kernel_shape = (output_channels, input_channels, filter_height, filter_width)
+    kernel_shape = (
+        output_channels,
+        input_channels,
+        filter_height,
+        filter_width,
+    )

     # NOTE Some resnet convolutions seem to require bfloat16, ttnn throws in runtime
     # otherwise. On another note, MaxPool2d is also only supported for bfloat16 in ttnn,
diff --git a/tests/jax/ops/test_divide.py b/tests/jax/ops/test_divide.py
index ec0a1a7f..51f3d7f0 100644
--- a/tests/jax/ops/test_divide.py
+++ b/tests/jax/ops/test_divide.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_divide(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_divide(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def divide(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.divide(x, y)

diff --git a/tests/jax/ops/test_dot_general.py b/tests/jax/ops/test_dot_general.py
index e1880a1d..9ff93005 100644
--- a/tests/jax/ops/test_dot_general.py
+++ b/tests/jax/ops/test_dot_general.py
@@ -76,7 +76,9 @@ def test_dot_general_multiple_contract(
     x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
 ):
     def dot_general(x: jax.Array, y: jax.Array) -> jax.Array:
-        return jax.lax.dot_general(x, y, dimension_numbers=(((1, 3), (1, 2)), (0, 0)))
+        return jax.lax.dot_general(
+            x, y, dimension_numbers=(((1, 3), (1, 2)), (0, 0))
+        )

     record_binary_op_test_properties(
         record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general"
diff --git a/tests/jax/ops/test_exponential.py b/tests/jax/ops/test_exponential.py
index 38b2a27d..f78da2e0 100644
--- a/tests/jax/ops/test_exponential.py
+++ b/tests/jax/ops/test_exponential.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_exponential(x_shape: tuple, record_tt_xla_property: Callable):
     def exponential(x: jax.Array) -> jax.Array:
         return jnp.exp(x)
diff --git a/tests/jax/ops/test_exponential_minus_one.py b/tests/jax/ops/test_exponential_minus_one.py
index fb6f9fcd..a8824859 100644
--- a/tests/jax/ops/test_exponential_minus_one.py
+++ b/tests/jax/ops/test_exponential_minus_one.py
@@ -25,7 +25,9 @@ def comparison_config() -> ComparisonConfig:

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_exponential_minus_one(
     x_shape: tuple,
     comparison_config: ComparisonConfig,
diff --git a/tests/jax/ops/test_log_plus_one.py b/tests/jax/ops/test_log_plus_one.py
index d0ed20cc..6cdafa53 100644
--- a/tests/jax/ops/test_log_plus_one.py
+++ b/tests/jax/ops/test_log_plus_one.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_log1p(x_shape: tuple, record_tt_xla_property: Callable):
     def log1p(x: jax.Array) -> jax.Array:
         return jnp.log1p(x)
diff --git a/tests/jax/ops/test_maximum.py b/tests/jax/ops/test_maximum.py
index 868fc820..82ec3182 100644
--- a/tests/jax/ops/test_maximum.py
+++ b/tests/jax/ops/test_maximum.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_maximum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_maximum(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def maximum(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.maximum(x, y)

diff --git a/tests/jax/ops/test_minimum.py b/tests/jax/ops/test_minimum.py
index 1b652e7f..6542e7c3 100644
--- a/tests/jax/ops/test_minimum.py
+++ b/tests/jax/ops/test_minimum.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_minimum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_minimum(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def minimum(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.minimum(x, y)

diff --git a/tests/jax/ops/test_multiply.py b/tests/jax/ops/test_multiply.py
index 52992134..48fcd551 100644
--- a/tests/jax/ops/test_multiply.py
+++ b/tests/jax/ops/test_multiply.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_multiply(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_multiply(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def multiply(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.multiply(x, y)

diff --git a/tests/jax/ops/test_negate.py b/tests/jax/ops/test_negate.py
index 00ea3499..ca55eafd 100644
--- a/tests/jax/ops/test_negate.py
+++ b/tests/jax/ops/test_negate.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_negate(x_shape: tuple, record_tt_xla_property: Callable):
     def negate(x: jax.Array) -> jax.Array:
         return jnp.negative(x)
diff --git a/tests/jax/ops/test_reduce.py b/tests/jax/ops/test_reduce.py
index db1fa636..0d0fecfa 100644
--- a/tests/jax/ops/test_reduce.py
+++ b/tests/jax/ops/test_reduce.py
@@ -24,7 +24,9 @@ def comparison_config() -> ComparisonConfig:
 # TODO axis should be parametrized as well.
 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_reduce_sum(
     x_shape: tuple,
     comparison_config: ComparisonConfig,
@@ -48,7 +50,9 @@ def reduce_sum(x: jax.Array) -> jax.Array:
 # TODO axis should be parametrized as well.
 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_reduce_max(
     x_shape: tuple,
     comparison_config: ComparisonConfig,
diff --git a/tests/jax/ops/test_remainder.py b/tests/jax/ops/test_remainder.py
index 1f1acef2..b28eabbf 100644
--- a/tests/jax/ops/test_remainder.py
+++ b/tests/jax/ops/test_remainder.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_remainder(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_remainder(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def remainder(x: jax.Array, y: jax.Array) -> jax.Array:
         return jlx.rem(x, y)

diff --git a/tests/jax/ops/test_reshape.py b/tests/jax/ops/test_reshape.py
index 6a857749..834a52eb 100644
--- a/tests/jax/ops/test_reshape.py
+++ b/tests/jax/ops/test_reshape.py
@@ -22,7 +22,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_reshape(in_shape: tuple, out_shape: tuple, record_tt_xla_property: Callable):
+def test_reshape(
+    in_shape: tuple, out_shape: tuple, record_tt_xla_property: Callable
+):
     def reshape(x: jax.Array):
         return jnp.reshape(x, out_shape)

diff --git a/tests/jax/ops/test_rsqrt.py b/tests/jax/ops/test_rsqrt.py
index fe37e460..056e14f7 100644
--- a/tests/jax/ops/test_rsqrt.py
+++ b/tests/jax/ops/test_rsqrt.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_rsqrt(x_shape: tuple, record_tt_xla_property: Callable):
     def rsqrt(x: jax.Array) -> jax.Array:
         return jlx.rsqrt(x)
diff --git a/tests/jax/ops/test_sign.py b/tests/jax/ops/test_sign.py
index 83fd8e26..cd5b8b4e 100644
--- a/tests/jax/ops/test_sign.py
+++ b/tests/jax/ops/test_sign.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_sign(x_shape: tuple, record_tt_xla_property: Callable):
     def sign(x: jax.Array) -> jax.Array:
         return jnp.sign(x)
diff --git a/tests/jax/ops/test_slice.py b/tests/jax/ops/test_slice.py
index 7aa81d8e..9d544788 100644
--- a/tests/jax/ops/test_slice.py
+++ b/tests/jax/ops/test_slice.py
@@ -38,7 +38,9 @@
     [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases],
     ids=lambda val: f"{val}",
 )
-def test_slice(begin: int, end: int, dim: int, record_tt_xla_property: Callable):
+def test_slice(
+    begin: int, end: int, dim: int, record_tt_xla_property: Callable
+):
     def module_slice(a):
         if dim == 0:
             return a[begin:end, :, :, :]
diff --git a/tests/jax/ops/test_sqrt.py b/tests/jax/ops/test_sqrt.py
index 8f6a195f..7d9eb334 100644
--- a/tests/jax/ops/test_sqrt.py
+++ b/tests/jax/ops/test_sqrt.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_sqrt(x_shape: tuple, record_tt_xla_property: Callable):
     def sqrt(x: jax.Array) -> jax.Array:
         return jnp.sqrt(x)
diff --git a/tests/jax/ops/test_subtract.py b/tests/jax/ops/test_subtract.py
index 20f79075..fed00076 100644
--- a/tests/jax/ops/test_subtract.py
+++ b/tests/jax/ops/test_subtract.py
@@ -21,7 +21,9 @@
     ],
     ids=lambda val: f"{val}",
 )
-def test_subtract(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
+def test_subtract(
+    x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable
+):
     def subtract(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.subtract(x, y)

diff --git a/tests/jax/ops/test_transpose.py b/tests/jax/ops/test_transpose.py
index be99870a..af0d010a 100644
--- a/tests/jax/ops/test_transpose.py
+++ b/tests/jax/ops/test_transpose.py
@@ -13,7 +13,9 @@

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
+@pytest.mark.parametrize(
+    "x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}"
+)
 def test_transpose(x_shape: tuple, record_tt_xla_property: Callable):
     def transpose(x: jax.Array) -> jax.Array:
         return jnp.transpose(x)
diff --git a/tests/jax/test_device_initialization.py b/tests/jax/test_device_initialization.py
index 5385ffff..e2ef66fd 100644
--- a/tests/jax/test_device_initialization.py
+++ b/tests/jax/test_device_initialization.py
@@ -58,7 +58,9 @@ def test_devices_are_connected():

     tt_devices = jax.devices("tt")

-    assert len(tt_devices) > 0, f"Expected at least one TT device to be connected"
+    assert (
+        len(tt_devices) > 0
+    ), f"Expected at least one TT device to be connected"

     assert is_tt_device(tt_devices[0])

diff --git a/tests/jax/test_scalar_types.py b/tests/jax/test_scalar_types.py
index f7b2e7f5..ff07c227 100644
--- a/tests/jax/test_scalar_types.py
+++ b/tests/jax/test_scalar_types.py
@@ -25,7 +25,9 @@ def add() -> jax.Array:

 @pytest.mark.push
 @pytest.mark.nightly
-@pytest.mark.skip("Fails due to https://github.com/tenstorrent/tt-metal/issues/16701")
+@pytest.mark.skip(
+    "Fails due to https://github.com/tenstorrent/tt-metal/issues/16701"
+)
 def test_scalar_array_add():
     """
     Tests adding scalar and an array.
diff --git a/tests/utils.py b/tests/utils.py
index 73ea98fb..26ff54f1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,7 +34,10 @@ def record_binary_op_test_properties(


 def record_op_test_properties(
-    record_property: Callable, op_kind: str, framework_op_name: str, op_name: str
+    record_property: Callable,
+    op_kind: str,
+    framework_op_name: str,
+    op_name: str,
 ):
     record_property(RecordProperties.OP_KIND.value, op_kind)
     record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name)