Setting the max line length of Python files to 80 #278

Open. Wants to merge 1 commit into base: main.

1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -4,6 +4,7 @@ repos:
     hooks:
       - id: black
         language_version: python3
+        args: [--line-length=80]
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v18.1.7
     hooks:
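
The only functional change above is the new `args: [--line-length=80]` passed to the black hook; every Python diff below is black rewrapping lines that no longer fit in 80 columns. As a minimal illustrative sketch (not part of this PR, and assuming black is installed), the same formatting step can be reproduced through black's Python API; the sample statement is the one rewrapped in examples/simple_regression.py:

# Illustrative only: reproduce the hook's 80-column formatting with black's
# Python API (black.format_str / black.Mode).
import black

# The over-80-column statement from examples/simple_regression.py.
SAMPLE = (
    "path = os.path.join("
    "os.path.dirname(__file__), "
    '"../build/src/tt/pjrt_plugin_tt.so")\n'
)

# line_length=80 mirrors the hook's new `args: [--line-length=80]`; black's
# default of 88 would leave this 83-character line untouched.
print(black.format_str(SAMPLE, mode=black.Mode(line_length=80)))

The repository-wide rewrapping in the rest of this diff is presumably the result of running the updated hook over every file, e.g. `pre-commit run black --all-files`.
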
16 changes: 12 additions & 4 deletions examples/simple_regression.py
@@ -16,14 +16,20 @@
 # program will execute on tt device if not specified otherwise.
 def initialize():
     backend = "tt"
-    path = os.path.join(os.path.dirname(__file__), "../build/src/tt/pjrt_plugin_tt.so")
+    path = os.path.join(
+        os.path.dirname(__file__), "../build/src/tt/pjrt_plugin_tt.so"
+    )
     if not os.path.exists(path):
-        raise FileNotFoundError(f"Could not find tt_pjrt C API plugin at {path}")
+        raise FileNotFoundError(
+            f"Could not find tt_pjrt C API plugin at {path}"
+        )

     print("Loading tt_pjrt C API plugin", file=sys.stderr)
     xb.discover_pjrt_plugins()

-    plugin = xb.register_plugin("tt", priority=500, library_path=path, options=None)
+    plugin = xb.register_plugin(
+        "tt", priority=500, library_path=path, options=None
+    )
     print("Loaded", file=sys.stderr)
     jax.config.update("jax_platforms", "tt,cpu")

@@ -33,7 +39,9 @@ def random_input_tensor(shape, key=42, on_device=False):
     def random_input(shape, key):
         return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)

-    jitted_tensor_creator = jax.jit(random_input, static_argnums=[0, 1], backend="cpu")
+    jitted_tensor_creator = jax.jit(
+        random_input, static_argnums=[0, 1], backend="cpu"
+    )
     tensor = jitted_tensor_creator(shape, key)
     if on_device:
         tensor = jax.device_put(tensor, jax.devices()[0])
12 changes: 9 additions & 3 deletions tests/infra/base_tester.py
@@ -53,9 +53,13 @@ def _compare(
         if self._comparison_config.equal.enabled:
             compare_equal(device_output, golden_output)
         if self._comparison_config.atol.enabled:
-            compare_atol(device_output, golden_output, self._comparison_config.atol)
+            compare_atol(
+                device_output, golden_output, self._comparison_config.atol
+            )
         if self._comparison_config.pcc.enabled:
-            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
+            compare_pcc(
+                device_output, golden_output, self._comparison_config.pcc
+            )
         if self._comparison_config.allclose.enabled:
             compare_allclose(
                 device_output, golden_output, self._comparison_config.allclose
@@ -68,6 +72,8 @@ def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
         Tensors need to be in same data format in order to compare them.
         """
         return [
-            tensor.astype("float32") if tensor.dtype.str != "float32" else tensor
+            tensor.astype("float32")
+            if tensor.dtype.str != "float32"
+            else tensor
             for tensor in tensors
         ]
4 changes: 3 additions & 1 deletion tests/infra/comparison.py
@@ -123,7 +123,9 @@ def compare_pcc(

 @run_on_cpu
 def compare_allclose(
-    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
+    device_output: Tensor,
+    golden_output: Tensor,
+    allclose_config: AllcloseConfig,
 ) -> None:
     assert isinstance(device_output, jax.Array) and isinstance(
         golden_output, jax.Array
8 changes: 6 additions & 2 deletions tests/infra/device_runner.py
@@ -68,7 +68,9 @@ def _run_on_device(
         workload: Workload, device_type: DeviceType, device_num: int = 0
     ) -> Tensor:
         """Runs `workload` on device identified by `device_type`."""
-        device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
+        device_workload = DeviceRunner._put_on_device(
+            workload, device_type, device_num
+        )
         device = device_connector.connect_device(device_type, device_num)

         with jax.default_device(device):
@@ -106,7 +108,9 @@ def _safely_put_workload_on_device(
         To avoid that, we try to `jax.device_put` arg or kwarg, and if it doesn't
         succeed, we leave it as is.
         """
-        fn_params = list(inspect.signature(workload.executable).parameters.keys())
+        fn_params = list(
+            inspect.signature(workload.executable).parameters.keys()
+        )

         args_on_device = []
         for i, arg in enumerate(workload.args):
4 changes: 3 additions & 1 deletion tests/infra/model_tester.py
@@ -61,7 +61,9 @@ def _init_model_hooks(self) -> None:
         kwargs = self._get_forward_method_kwargs()

         if len(args) == 0 and len(kwargs) == 0:
-            raise ValueError(f"Forward method args or kwargs or both must be provided")
+            raise ValueError(
+                f"Forward method args or kwargs or both must be provided"
+            )

         forward_method_name = self._get_forward_method_name()

3 changes: 2 additions & 1 deletion tests/infra/op_tester.py
@@ -49,7 +49,8 @@ def test_with_random_inputs(
         TT device and CPU and comparing the results.
         """
        inputs = [
-            random_tensor(shape, minval=minval, maxval=maxval) for shape in input_shapes
+            random_tensor(shape, minval=minval, maxval=maxval)
+            for shape in input_shapes
         ]
         workload = Workload(f, inputs)
         self.test(workload)
4 changes: 3 additions & 1 deletion tests/jax/graphs/test_MLP_regression.py
@@ -23,7 +23,9 @@ def comparison_config() -> ComparisonConfig:
         [(784, 64), (32, 64), (64, 10), (32, 10), (32, 784), (32, 10)]
     ],  # 32 samples, 784 features (28x28), 10 output classes
 )
-def test_nn_with_relu(W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig):
+def test_nn_with_relu(
+    W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig
+):
     def simple_nn(W1, b1, W2, b2, X, y):
         def forward(W1, b1, W2, b2, X):
             hidden = jnp.dot(X, W1) + b1
8 changes: 6 additions & 2 deletions tests/jax/graphs/test_linear_transformation.py
@@ -19,8 +19,12 @@
         [(64, 32), (32, 64), (1, 64)],
     ],
 )
-def test_linear_transformation(x_shape: tuple, y_shape: tuple, bias_shape: tuple):
-    def linear_transformation(x: jax.Array, y: jax.Array, bias: jax.Array) -> jax.Array:
+def test_linear_transformation(
+    x_shape: tuple, y_shape: tuple, bias_shape: tuple
+):
+    def linear_transformation(
+        x: jax.Array, y: jax.Array, bias: jax.Array
+    ) -> jax.Array:
         return jnp.matmul(x, y) + bias

     run_graph_test_with_random_inputs(
4 changes: 3 additions & 1 deletion tests/jax/graphs/test_simple_regression.py
@@ -15,7 +15,9 @@
 def test_simple_regression(weights, bias, X, y):
     def simple_regression(weights, bias, X, y):
         def loss(weights, bias, X, y):
-            predict = X.dot(weights) + bias if bias is not None else X.dot(weights)
+            predict = (
+                X.dot(weights) + bias if bias is not None else X.dot(weights)
+            )
             return ((predict - y) ** 2).sum()

         # Compute gradient and update weights.
6 changes: 5 additions & 1 deletion tests/jax/models/albert/v2/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxAlbertForMaskedLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxAlbertForMaskedLM,
+    FlaxPreTrainedModel,
+)


 class AlbertV2Tester(ModelTester):
4 changes: 3 additions & 1 deletion tests/jax/models/beit/base/test_beit_base.py
@@ -31,7 +31,9 @@ def training_tester() -> FlaxBeitForImageClassificationTester:
 # ----- Tests -----


-@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'"))
+@pytest.mark.xfail(
+    reason=compile_fail("failed to legalize operation 'ttir.gather'")
+)
 def test_flax_beit_base_inference(
     inference_tester: FlaxBeitForImageClassificationTester,
     record_tt_xla_property: Callable,
4 changes: 3 additions & 1 deletion tests/jax/models/beit/large/test_beit_large.py
@@ -30,7 +30,9 @@ def training_tester() -> FlaxBeitForImageClassificationTester:
 # ----- Tests -----


-@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'"))
+@pytest.mark.xfail(
+    reason=compile_fail("failed to legalize operation 'ttir.gather'")
+)
 def test_flax_beit_large_inference(
     inference_tester: FlaxBeitForImageClassificationTester,
     record_tt_xla_property: Callable,
10 changes: 8 additions & 2 deletions tests/jax/models/bloom/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxBloomForCausalLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxBloomForCausalLM,
+    FlaxPreTrainedModel,
+)


 class BloomTester(ModelTester):
@@ -23,7 +27,9 @@ def __init__(

     # @override
     def _get_model(self) -> FlaxPreTrainedModel:
-        return FlaxBloomForCausalLM.from_pretrained(self._model_name, from_pt=True)
+        return FlaxBloomForCausalLM.from_pretrained(
+            self._model_name, from_pt=True
+        )

     # @override
     def _get_input_activations(self) -> Sequence[jax.Array]:
6 changes: 5 additions & 1 deletion tests/jax/models/distilbert/test_distilbert.py
@@ -7,7 +7,11 @@
 import jax
 import pytest
 from infra import ModelTester, RunMode
-from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM, FlaxPreTrainedModel
+from transformers import (
+    AutoTokenizer,
+    FlaxDistilBertForMaskedLM,
+    FlaxPreTrainedModel,
+)
 from utils import record_model_test_properties, runtime_fail

 MODEL_PATH = "distilbert/distilbert-base-uncased"
@@ -89,5 +89,7 @@ def test_example_model_inference(
 @pytest.mark.push
 @pytest.mark.nightly
 @pytest.mark.skip(reason="Support for training not implemented")
-def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester):
+def test_example_model_training(
+    training_tester: ExampleModelMixedArgsAndKwargsTester,
+):
     training_tester.test()
7 changes: 6 additions & 1 deletion tests/jax/models/example_model/model.py
@@ -23,7 +23,12 @@ def __init__(self) -> None:
         self.b1 = random_tensor(b1_shape, minval=-0.01, maxval=0.01)

     def __call__(
-        self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array
+        self,
+        act: jax.Array,
+        w0: jax.Array,
+        b0: jax.Array,
+        w1: jax.Array,
+        b1: jax.Array,
     ) -> jax.Array:
         # Note how activations, weights and biases are directly passed to the forward
         # method as inputs, `self` is not accessed. Otherwise they would be embedded
@@ -75,7 +75,9 @@ def training_tester() -> ExampleModelOnlyKwargsTester:

 @pytest.mark.push
 @pytest.mark.nightly
-def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester):
+def test_example_model_inference(
+    inference_tester: ExampleModelOnlyKwargsTester,
+):
     inference_tester.test()


10 changes: 8 additions & 2 deletions tests/jax/models/llama/tester.py
@@ -6,7 +6,11 @@

 import jax
 from infra import ComparisonConfig, ModelTester, RunMode
-from transformers import FlaxLlamaForCausalLM, FlaxPreTrainedModel, LlamaTokenizer
+from transformers import (
+    FlaxLlamaForCausalLM,
+    FlaxPreTrainedModel,
+    LlamaTokenizer,
+)


 class LLamaTester(ModelTester):
@@ -25,7 +29,9 @@ def __init__(

     # @override
     def _get_model(self) -> FlaxPreTrainedModel:
-        return FlaxLlamaForCausalLM.from_pretrained(self._model_name, from_pt=True)
+        return FlaxLlamaForCausalLM.from_pretrained(
+            self._model_name, from_pt=True
+        )

     # @override
     def _get_input_activations(self) -> Sequence[jax.Array]:
5 changes: 4 additions & 1 deletion tests/jax/models/mlpmixer/model_implementation.py
@@ -62,7 +62,10 @@ class MlpMixer(nn.Module):
     @nn.compact
     def __call__(self, inputs: jax.Array) -> jax.Array:
         x = nn.Conv(
-            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
+            self.hidden_dim,
+            self.patches.size,
+            strides=self.patches.size,
+            name="stem",
         )(
             inputs
         )  # Patch embedding
4 changes: 3 additions & 1 deletion tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -44,7 +44,9 @@ def _get_model(self) -> nn.Module:
     def _retrieve_pretrained_weights() -> Dict:
         # TODO(stefan): Discuss how weights should be handled org wide
         link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
-        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+        with fsspec.open(
+            "filecache::" + link, cache_storage="/tmp/files/"
+        ) as f:
             weights = numpy.load(f, encoding="bytes")
             state_dict = {k: v for k, v in weights.items()}
             pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
4 changes: 3 additions & 1 deletion tests/jax/models/mnist/mlp/test_mnist_mlp.py
@@ -96,6 +96,8 @@ def test_mnist_mlp_training(
     training_tester: MNISTMLPTester,
     record_tt_xla_property: Callable,
 ):
-    record_model_test_properties(record_tt_xla_property, MNISTMLPModel.__qualname__)
+    record_model_test_properties(
+        record_tt_xla_property, MNISTMLPModel.__qualname__
+    )

     training_tester.test()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import jax
from infra import ComparisonConfig, ModelTester, RunMode
from transformers import AutoTokenizer, FlaxPreTrainedModel, FlaxRobertaForMaskedLM
from transformers import (
AutoTokenizer,
FlaxPreTrainedModel,
FlaxRobertaForMaskedLM,
)


class FlaxRobertaForMaskedLMTester(ModelTester):
Expand Down