Skip to content

Commit

Permalink
Merge branch 'main' into umales/mnist_mse
Browse files Browse the repository at this point in the history
  • Loading branch information
umalesTT authored Jan 31, 2025
2 parents 9d4f831 + c1cb43a commit 423318c
Show file tree
Hide file tree
Showing 32 changed files with 369 additions and 79 deletions.
26 changes: 12 additions & 14 deletions .github/get-docker-tag.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,24 @@

# Exit immediately if a command exits with a non-zero status
set -e

# Execute this in a separate bash process
(
# Read tt-mlir version from third_party/CMakeLists.txt and clone third_party/tt-mlir
MLIR_DOCKER_TAG=$(
# Read tt-mlir version from third_party/CMakeLists.txt
# clone tt-mlir version to tmp/third_party/tt-mlir
# Get the MLIR docker tag
TT_MLIR_PATH=tmp/third_party/tt-mlir
TT_MLIR_VERSION=$(grep -oP 'set\(TT_MLIR_VERSION "\K[^"]+' third_party/CMakeLists.txt)
if [ ! -d "third_party/tt-mlir" ]; then
git clone https://github.com/tenstorrent/tt-mlir.git third_party/tt-mlir --quiet
if [ ! -d $TT_MLIR_PATH ]; then
git clone https://github.com/tenstorrent/tt-mlir.git $TT_MLIR_PATH --quiet
fi
cd third_party/tt-mlir
cd $TT_MLIR_PATH
git fetch --quiet
git checkout $TT_MLIR_VERSION --quiet
if [ -f ".github/get-docker-tag.sh" ]; then
MLIR_DOCKER_TAG=$(.github/get-docker-tag.sh)
.github/get-docker-tag.sh
else
MLIR_DOCKER_TAG="default-tag"
echo "default-tag"
fi
cd ../..
)

DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci"
DOCKERFILE_HASH=$( (echo $MLIR_DOCKER_TAG; sha256sum $DOCKERFILE_HASH_FILES) | sha256sum | cut -d ' ' -f 1)
echo dt-$DOCKERFILE_HASH
DOCKERFILE_HASH=$( (cat .github/Dockerfile.base .github/Dockerfile.ci | sha256sum) | cut -d ' ' -f 1)
COMBINED_HASH=$( (echo $DOCKERFILE_HASH $MLIR_DOCKER_TAG | sha256sum) | cut -d ' ' -f 1)
echo dt-$COMBINED_HASH
12 changes: 12 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
### Ticket
Link to Github Issue

### Problem description
Provide context for the problem.

### What's changed
Describe the approach used to solve the problem.
Summarize the changes made and its impact.

### Checklist
- [ ] New/Existing tests provide coverage for changes
33 changes: 32 additions & 1 deletion .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,25 @@ on:
description: 'Git SHA of commit in tenstorrent/tt-mlir'
required: false
type: string
test_mark:
description: 'Test mark to run'
required: true
default: 'push'
type: choice
options:
- push
- nightly
workflow_call:
inputs:
mlir_override:
description: 'Git SHA of commit in tenstorrent/tt-mlir'
required: false
type: string
test_mark:
description: 'Test mark to run'
required: false
default: 'push'
type: string

permissions:
packages: write
Expand Down Expand Up @@ -251,7 +264,25 @@ jobs:
group_suite: true

- name: Prepare code coverage report
if: matrix.build.runs-on == 'n300' && (success() || failure())
run: |
lcov --directory build --capture --output-file coverage.info
lcov --extract coverage.info '**/src/*' --output-file coverage.info
lcov --extract coverage.info '**/tt-xla/src/*' --output-file coverage.info
sed -i 's|SF:/__w/tt-xla/tt-xla/src/|SF:src/|' coverage.info
lcov --list coverage.info
- name: Upload coverage reports to Codecov
if: matrix.build.runs-on == 'n300' && (success() || failure())
uses: codecov/codecov-action@v5
with:
files: coverage.info
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload test results to Codecov
if: success() || failure()
uses: codecov/test-results-action@v1
with:
files: ${{ steps.strings.outputs.test_report_path }}
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}
13 changes: 13 additions & 0 deletions .github/workflows/on-nightly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: On nightly

on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *'

jobs:
build-and-test:
uses: ./.github/workflows/build-and-test.yml
secrets: inherit
with:
test_mark: 'nightly'
1 change: 1 addition & 0 deletions .github/workflows/produce_data.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- "On PR"
- "On push"
- "Build and Test"
- "On nightly"
types:
- completed

Expand Down
5 changes: 5 additions & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ target_include_directories(TTPJRTCommon PUBLIC
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include/ttmlir/Target/Common
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/include
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/runtime/include
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/shardy/
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/
${TTMLIR_TOOLCHAIN_DIR}/include
${TTMLIR_TOOLCHAIN_DIR}/src/shardy
${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo
)

Expand All @@ -63,6 +65,8 @@ ChloOps
Version
VhloOps
VhloTypes
SdyDialect
SdyRegister
StablehloOps
StablehloRegister
StablehloReferenceToken
Expand Down Expand Up @@ -108,6 +112,7 @@ target_link_libraries(TTPJRTCommon PUBLIC
TTPJRTCommonDylibPlatform
TTMLIRStatic
TTMLIRTosaToTTIR
TTMLIRTTIRToLinalg
MLIRTTIRPipelines
TTMLIRStableHLOToTTIR
${STABLEHLO_LIBS}
Expand Down
3 changes: 0 additions & 3 deletions tests/jax/graphs/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
[(64, 64), 1],
],
)
@pytest.mark.xfail(
reason="Atol comparison failed. Calculated: atol=inf. Required: atol=0.16"
)
def test_softmax(x_shape: tuple, axis: int):
def softmax(x: jax.Array) -> jax.Array:
return jax.nn.softmax(x, axis=axis)
Expand Down
5 changes: 3 additions & 2 deletions tests/jax/models/albert/v2/base/test_albert_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from ..tester import AlbertV2Tester


MODEL_PATH = "albert/albert-base-v2"


Expand All @@ -27,7 +26,9 @@ def training_tester() -> AlbertV2Tester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_albert_v2_base_inference(
inference_tester: AlbertV2Tester,
):
Expand Down
4 changes: 3 additions & 1 deletion tests/jax/models/albert/v2/large/test_albert_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_albert_v2_large_inference(
inference_tester: AlbertV2Tester,
):
Expand Down
4 changes: 3 additions & 1 deletion tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_albert_v2_xlarge_inference(
inference_tester: AlbertV2Tester,
):
Expand Down
4 changes: 3 additions & 1 deletion tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_albert_v2_xxlarge_inference(
inference_tester: AlbertV2Tester,
):
Expand Down
Empty file.
Empty file.
42 changes: 42 additions & 0 deletions tests/jax/models/bart/base/test_bart_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-base"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_bart_base_inference(
inference_tester: FlaxBartForCausalLMTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_base_training(
training_tester: FlaxBartForCausalLMTester,
):
training_tester.test()
Empty file.
42 changes: 42 additions & 0 deletions tests/jax/models/bart/large/test_bart_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-large"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(
reason="Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)"
)
def test_flax_bart_large_inference(
inference_tester: FlaxBartForCausalLMTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_large_training(
training_tester: FlaxBartForCausalLMTester,
):
training_tester.test()
43 changes: 43 additions & 0 deletions tests/jax/models/bart/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Sequence

import jax
from flax import linen as nn
from infra import ComparisonConfig, ModelTester, RunMode
from transformers import AutoTokenizer, FlaxBartForCausalLM


class FlaxBartForCausalLMTester(ModelTester):
"""Tester for BART model variants with a language modeling head on top."""

# TODO(mrakita): Add tests for other variants.

def __init__(
self,
model_name: str,
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
self._model_name = model_name
super().__init__(comparison_config, run_mode)

# @override
def _get_model(self) -> nn.Module:
return FlaxBartForCausalLM.from_pretrained(self._model_name)

# @override
def _get_input_activations(self) -> Sequence[jax.Array]:
tokenizer = AutoTokenizer.from_pretrained(self._model_name)
inputs = tokenizer("Hello", return_tensors="np")
return inputs["input_ids"]

# @override
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
assert hasattr(self._model, "params")
return {
"params": self._model.params,
"input_ids": self._get_input_activations(),
}
4 changes: 3 additions & 1 deletion tests/jax/models/distilbert/test_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_distilbert_inference(
inference_tester: FlaxDistilBertForMaskedLMTester,
):
Expand Down
8 changes: 5 additions & 3 deletions tests/jax/models/gpt2/test_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Sequence, Dict
from typing import Dict, Sequence

import jax
import pytest
Expand Down Expand Up @@ -51,8 +51,10 @@ def training_tester() -> GPT2Tester:
# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
def test_gp2_inference(
@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_gpt2_inference(
inference_tester: GPT2Tester,
):
inference_tester.test()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def training_tester() -> LLamaTester:
# ----- Tests -----


# @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
# @pytest.mark.xfail(reason="failed to legalize operation 'ttir.gather'")
@pytest.mark.skip(
reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/models/mlpmixer/test_mlpmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def training_tester() -> MlpMixerTester:
"(https://github.com/tenstorrent/tt-xla/issues/187)"
)
) # segfault
def test_mlpmixer(inference_tester: MlpMixerTester):
def test_mlpmixer_inference(inference_tester: MlpMixerTester):
inference_tester.test()


Expand Down
Loading

0 comments on commit 423318c

Please sign in to comment.