Skip to content

Commit

Permalink
Merge branch 'main' into sdjukic/device-debug-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjukicTT authored Feb 4, 2025
2 parents 6edbe10 + abe6da2 commit abc0c17
Show file tree
Hide file tree
Showing 112 changed files with 2,709 additions and 377 deletions.
26 changes: 12 additions & 14 deletions .github/get-docker-tag.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,24 @@

# Exit immediately if a command exits with a non-zero status
set -e

# Execute this in a separate bash process
(
# Read tt-mlir version from third_party/CMakeLists.txt and clone third_party/tt-mlir
MLIR_DOCKER_TAG=$(
# Read tt-mlir version from third_party/CMakeLists.txt
# clone tt-mlir version to tmp/third_party/tt-mlir
# Get the MLIR docker tag
TT_MLIR_PATH=tmp/third_party/tt-mlir
TT_MLIR_VERSION=$(grep -oP 'set\(TT_MLIR_VERSION "\K[^"]+' third_party/CMakeLists.txt)
if [ ! -d "third_party/tt-mlir" ]; then
git clone https://github.com/tenstorrent/tt-mlir.git third_party/tt-mlir --quiet
if [ ! -d $TT_MLIR_PATH ]; then
git clone https://github.com/tenstorrent/tt-mlir.git $TT_MLIR_PATH --quiet
fi
cd third_party/tt-mlir
cd $TT_MLIR_PATH
git fetch --quiet
git checkout $TT_MLIR_VERSION --quiet
if [ -f ".github/get-docker-tag.sh" ]; then
MLIR_DOCKER_TAG=$(.github/get-docker-tag.sh)
.github/get-docker-tag.sh
else
MLIR_DOCKER_TAG="default-tag"
echo "default-tag"
fi
cd ../..
)

DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci"
DOCKERFILE_HASH=$( (echo $MLIR_DOCKER_TAG; sha256sum $DOCKERFILE_HASH_FILES) | sha256sum | cut -d ' ' -f 1)
echo dt-$DOCKERFILE_HASH
DOCKERFILE_HASH=$( (cat .github/Dockerfile.base .github/Dockerfile.ci | sha256sum) | cut -d ' ' -f 1)
COMBINED_HASH=$( (echo $DOCKERFILE_HASH $MLIR_DOCKER_TAG | sha256sum) | cut -d ' ' -f 1)
echo dt-$COMBINED_HASH
12 changes: 12 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
### Ticket
Link to GitHub Issue

### Problem description
Provide context for the problem.

### What's changed
Describe the approach used to solve the problem.
Summarize the changes made and their impact.

### Checklist
- [ ] New/Existing tests provide coverage for changes
76 changes: 74 additions & 2 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,35 @@ name: Build and Test

on:
workflow_dispatch:
inputs:
mlir_override:
description: 'Git SHA of commit in tenstorrent/tt-mlir'
required: false
type: string
test_mark:
description: 'Test mark to run'
required: true
default: 'push'
type: choice
options:
- push
- nightly
workflow_call:
inputs:
mlir_override:
description: 'Git SHA of commit in tenstorrent/tt-mlir'
required: false
type: string
test_mark:
description: 'Test mark to run'
required: false
default: 'push'
type: string

permissions:
packages: write
checks: write
pull-requests: write

jobs:
# build-ttxla:
Expand Down Expand Up @@ -144,6 +172,13 @@ jobs:
submodules: recursive
lfs: true

- name: Override tt-mlir SHA mlir_override is set
if: ${{ inputs.mlir_override }}
shell: bash
run: |
# Update the CMakeLists.txt file with the new SHA
sed -i "s/set(TT_MLIR_VERSION \".*\")/set(TT_MLIR_VERSION \"${{ inputs.mlir_override }}\")/" third_party/CMakeLists.txt
- name: Set reusable strings
id: strings
shell: bash
Expand Down Expand Up @@ -187,6 +222,21 @@ jobs:
cmake --build ${{ steps.strings.outputs.build-output-dir }}
cmake --install ${{ steps.strings.outputs.build-output-dir }}
- name: Verify tt-mlir SHA override
if: ${{ inputs.mlir_override }}
continue-on-error: true
shell: bash
run: |
cd third_party/tt-mlir
branch_name=$(git rev-parse --abbrev-ref HEAD)
commit_sha=$(git rev-parse HEAD)
commit_title=$(git log -1 --pretty=%s)
echo "Branch name: $branch_name"
echo "Commit SHA: $commit_sha"
echo "Commit title: $commit_title"
echo "::notice::Using tt-mlir: $branch_name, commit: $commit_sha, title: $commit_title"
cd ../..
- name: Run tests
shell: bash
run: |
Expand All @@ -203,14 +253,36 @@ jobs:
path: ${{ steps.strings.outputs.test_report_path }}

- name: Show Test Report
uses: mikepenz/action-junit-report@v4
uses: mikepenz/action-junit-report@v5
if: success() || failure()
with:
report_paths: ${{ steps.strings.outputs.test_report_path }}
check_name: TT-XLA Tests
comment: true
updateComment: true
detailed_summary: true
group_suite: true

- name: Prepare code coverage report
if: matrix.build.runs-on == 'n300' && (success() || failure())
run: |
lcov --directory build --capture --output-file coverage.info
lcov --extract coverage.info '**/src/*' --output-file coverage.info
lcov --extract coverage.info '**/tt-xla/src/*' --output-file coverage.info
sed -i 's|SF:/__w/tt-xla/tt-xla/src/|SF:src/|' coverage.info
lcov --list coverage.info
- name: Upload coverage reports to Codecov
if: matrix.build.runs-on == 'n300' && (success() || failure())
uses: codecov/codecov-action@v5
with:
files: coverage.info
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload test results to Codecov
if: success() || failure()
uses: codecov/test-results-action@v1
with:
files: ${{ steps.strings.outputs.test_report_path }}
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}
13 changes: 13 additions & 0 deletions .github/workflows/on-nightly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: On nightly

on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *'

jobs:
build-and-test:
uses: ./.github/workflows/build-and-test.yml
secrets: inherit
with:
test_mark: 'nightly'
7 changes: 7 additions & 0 deletions .github/workflows/on-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ name: On PR

on:
workflow_dispatch:
inputs:
mlir_override:
description: 'Git SHA of commit in tenstorrent/tt-mlir'
required: false
type: string
pull_request:
branches: [ "main" ]

Expand All @@ -20,3 +25,5 @@ jobs:
needs: [pre-commit, spdx]
uses: ./.github/workflows/build-and-test.yml
secrets: inherit
with:
mlir_override: ${{ inputs.mlir_override }}
1 change: 1 addition & 0 deletions .github/workflows/produce_data.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- "On PR"
- "On push"
- "Build and Test"
- "On nightly"
types:
- completed

Expand Down
5 changes: 0 additions & 5 deletions inc/common/pjrt_implementation/client_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ class ClientInstance {
return cached_platform_version_;
}

// Checks if the output on the i-th index is a scalar.
bool isOutputScalar(const size_t index) const {
return module_builder_->isOutputScalar(index);
}

// Compiles.
// See TODOs in PJRT_Client_Compile.
PJRT_Error *
Expand Down
2 changes: 1 addition & 1 deletion inc/common/pjrt_implementation/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ class DeviceDescription {

} // namespace tt::pjrt

#endif
#endif
44 changes: 29 additions & 15 deletions inc/common/pjrt_implementation/executable_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

#include "xla/pjrt/c/pjrt_c_api.h"

// tt-mlir includes
#include "tt/runtime/types.h"

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_EXECUTABLE_IMAGE_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_EXECUTABLE_IMAGE_H_

Expand All @@ -22,10 +25,10 @@ namespace tt::pjrt {
class ExecutableImage {

public:
ExecutableImage(std::shared_ptr<void> binary, std::string code,
size_t arg_count, size_t result_count)
: ref_count(1), binary(std::move(binary)), code(code),
arg_count(arg_count), result_count(result_count) {}
ExecutableImage(const tt::runtime::Binary &binary, std::string code,
const std::vector<bool> &is_output_scalar,
size_t num_addressable_devices);

operator PJRT_Executable *() {
return reinterpret_cast<PJRT_Executable *>(this);
}
Expand All @@ -34,33 +37,44 @@ class ExecutableImage {
}
static void BindApi(PJRT_Api *api);

void AddRef() { ref_count.fetch_add(1); }
void AddRef() { m_ref_count.fetch_add(1); }
// Releases one reference and destroys the object once the last reference is
// dropped. fetch_sub returns the value held *before* the decrement, so the
// final reference is the one for which that previous value is 1; comparing
// against 0 would skip the delete on the last release and leak the object.
void DecRef() {
  if (m_ref_count.fetch_sub(1) == 1) {
    delete this;
  }
}

const size_t get_arg_count() const { return arg_count; }
const size_t get_arg_count() const { return m_arg_count; }

const size_t get_result_count() const { return m_result_count; }

const size_t get_result_count() const { return result_count; }
const tt::runtime::Binary &get_binary() const { return m_binary; }

std::shared_ptr<void> get_binary() { return binary; }
const std::string &get_code() const { return m_code; }

const std::string &get_code() const { return code; }
// Checks if the output on the i-th index is a scalar.
bool isOutputScalar(size_t index) const;

const size_t get_num_addressable_devices() const {
return num_addressable_devices;
}

private:
// The reference count. Must be disposed when reaching zero.
std::atomic<int> ref_count;
std::atomic<int> m_ref_count;

// Raw compiler output.
std::shared_ptr<void> binary;
tt::runtime::Binary m_binary;

// Original code fed to the compiler. Stored for debugging.
const std::string code;
const std::string m_code;

size_t m_arg_count;
size_t m_result_count;
size_t num_addressable_devices;

size_t arg_count;
size_t result_count;
// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;
};

} // namespace tt::pjrt
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ lit
pybind11
pytest
transformers
fsspec
einops
torch
ml_collections
5 changes: 5 additions & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ target_include_directories(TTPJRTCommon PUBLIC
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include/ttmlir/Target/Common
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/include
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/runtime/include
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/shardy/
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/
${TTMLIR_TOOLCHAIN_DIR}/include
${TTMLIR_TOOLCHAIN_DIR}/src/shardy
${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo
)

Expand All @@ -63,6 +65,8 @@ ChloOps
Version
VhloOps
VhloTypes
SdyDialect
SdyRegister
StablehloOps
StablehloRegister
StablehloReferenceToken
Expand Down Expand Up @@ -108,6 +112,7 @@ target_link_libraries(TTPJRTCommon PUBLIC
TTPJRTCommonDylibPlatform
TTMLIRStatic
TTMLIRTosaToTTIR
TTMLIRTTIRToLinalg
MLIRTTIRPipelines
TTMLIRStableHLOToTTIR
${STABLEHLO_LIBS}
Expand Down
22 changes: 2 additions & 20 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
namespace tt::pjrt {

ModuleBuilder::ModuleBuilder()
: m_status(tt_pjrt_status::kSuccess), m_num_inputs(0), m_num_outputs(0) {
: m_status(tt_pjrt_status::kSuccess), m_flatbuffer_binary(nullptr) {
m_context = std::make_unique<mlir::MLIRContext>();

// Register all the required dialects and passes.
Expand All @@ -66,11 +66,6 @@ ModuleBuilder::ModuleBuilder()
m_context->appendDialectRegistry(registry);
}

bool ModuleBuilder::isOutputScalar(const size_t index) const {
assert(index < m_is_output_scalar.size() && "Output index out of range");
return m_is_output_scalar[index];
}

tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
const std::string_view &format) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule");
Expand Down Expand Up @@ -219,22 +214,9 @@ void ModuleBuilder::createFlatbufferBinary(
const mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
m_flatbuffer_binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());

if (m_flatbuffer_binary == nullptr) {
if (m_flatbuffer_binary.handle == nullptr) {
DLOG_F(ERROR, "Failed to generate flatbuffer binary");
m_status = tt_pjrt_status::kInternal;
return;
}

tt::runtime::Binary runtime_binary_handle(m_flatbuffer_binary);
m_num_inputs = runtime_binary_handle.getProgramInputs(0).size();
m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size();

if (m_num_outputs != m_is_output_scalar.size()) {
DLOG_F(ERROR,
"Created flatbuffer binary contains different number of outputs %ld "
"than expected %ld",
m_num_outputs, m_is_output_scalar.size());
m_status = tt_pjrt_status::kInternal;
}
}

Expand Down
Loading

0 comments on commit abc0c17

Please sign in to comment.