diff --git a/.github/get-docker-tag.sh b/.github/get-docker-tag.sh index 9ed9f0f7..8dea07ba 100755 --- a/.github/get-docker-tag.sh +++ b/.github/get-docker-tag.sh @@ -8,26 +8,24 @@ # Exit immediately if a command exits with a non-zero status set -e - -# Execute this in a separate bash process -( - # Read tt-mlir version from third_party/CMakeLists.txt and clone third_party/tt-mlir +MLIR_DOCKER_TAG=$( + # Read tt-mlir version from third_party/CMakeLists.txt + # clone tt-mlir version to tmp/third_party/tt-mlir # Get the MLIR docker tag + TT_MLIR_PATH=tmp/third_party/tt-mlir TT_MLIR_VERSION=$(grep -oP 'set\(TT_MLIR_VERSION "\K[^"]+' third_party/CMakeLists.txt) - if [ ! -d "third_party/tt-mlir" ]; then - git clone https://github.com/tenstorrent/tt-mlir.git third_party/tt-mlir --quiet + if [ ! -d $TT_MLIR_PATH ]; then + git clone https://github.com/tenstorrent/tt-mlir.git $TT_MLIR_PATH --quiet fi - cd third_party/tt-mlir + cd $TT_MLIR_PATH git fetch --quiet git checkout $TT_MLIR_VERSION --quiet if [ -f ".github/get-docker-tag.sh" ]; then - MLIR_DOCKER_TAG=$(.github/get-docker-tag.sh) + .github/get-docker-tag.sh else - MLIR_DOCKER_TAG="default-tag" + echo "default-tag" fi - cd ../.. ) - -DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci" -DOCKERFILE_HASH=$( (echo $MLIR_DOCKER_TAG; sha256sum $DOCKERFILE_HASH_FILES) | sha256sum | cut -d ' ' -f 1) -echo dt-$DOCKERFILE_HASH +DOCKERFILE_HASH=$( (cat .github/Dockerfile.base .github/Dockerfile.ci | sha256sum) | cut -d ' ' -f 1) +COMBINED_HASH=$( (echo $DOCKERFILE_HASH $MLIR_DOCKER_TAG | sha256sum) | cut -d ' ' -f 1) +echo dt-$COMBINED_HASH diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..4a91d818 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,12 @@ +### Ticket +Link to Github Issue + +### Problem description +Provide context for the problem. + +### What's changed +Describe the approach used to solve the problem. 
+Summarize the changes made and its impact. + +### Checklist +- [ ] New/Existing tests provide coverage for changes diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 5f3d5a98..3466d8de 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -2,7 +2,35 @@ name: Build and Test on: workflow_dispatch: + inputs: + mlir_override: + description: 'Git SHA of commit in tenstorrent/tt-mlir' + required: false + type: string + test_mark: + description: 'Test mark to run' + required: true + default: 'push' + type: choice + options: + - push + - nightly workflow_call: + inputs: + mlir_override: + description: 'Git SHA of commit in tenstorrent/tt-mlir' + required: false + type: string + test_mark: + description: 'Test mark to run' + required: false + default: 'push' + type: string + +permissions: + packages: write + checks: write + pull-requests: write jobs: # build-ttxla: @@ -144,6 +172,13 @@ jobs: submodules: recursive lfs: true + - name: Override tt-mlir SHA mlir_override is set + if: ${{ inputs.mlir_override }} + shell: bash + run: | + # Update the CMakeLists.txt file with the new SHA + sed -i "s/set(TT_MLIR_VERSION \".*\")/set(TT_MLIR_VERSION \"${{ inputs.mlir_override }}\")/" third_party/CMakeLists.txt + - name: Set reusable strings id: strings shell: bash @@ -187,6 +222,21 @@ jobs: cmake --build ${{ steps.strings.outputs.build-output-dir }} cmake --install ${{ steps.strings.outputs.build-output-dir }} + - name: Verify tt-mlir SHA override + if: ${{ inputs.mlir_override }} + continue-on-error: true + shell: bash + run: | + cd third_party/tt-mlir + branch_name=$(git rev-parse --abbrev-ref HEAD) + commit_sha=$(git rev-parse HEAD) + commit_title=$(git log -1 --pretty=%s) + echo "Branch name: $branch_name" + echo "Commit SHA: $commit_sha" + echo "Commit title: $commit_title" + echo "::notice::Using tt-mlir: $branch_name, commit: $commit_sha, title: $commit_title" + cd ../.. 
+ - name: Run tests shell: bash run: | @@ -203,14 +253,36 @@ jobs: path: ${{ steps.strings.outputs.test_report_path }} - name: Show Test Report - uses: mikepenz/action-junit-report@v4 + uses: mikepenz/action-junit-report@v5 if: success() || failure() with: report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: TT-XLA Tests + comment: true + updateComment: true + detailed_summary: true + group_suite: true - name: Prepare code coverage report + if: matrix.build.runs-on == 'n300' && (success() || failure()) run: | lcov --directory build --capture --output-file coverage.info - lcov --extract coverage.info '**/src/*' --output-file coverage.info + lcov --extract coverage.info '**/tt-xla/src/*' --output-file coverage.info + sed -i 's|SF:/__w/tt-xla/tt-xla/src/|SF:src/|' coverage.info lcov --list coverage.info + + - name: Upload coverage reports to Codecov + if: matrix.build.runs-on == 'n300' && (success() || failure()) + uses: codecov/codecov-action@v5 + with: + files: coverage.info + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} + + - name: Upload test results to Codecov + if: success() || failure() + uses: codecov/test-results-action@v1 + with: + files: ${{ steps.strings.outputs.test_report_path }} + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/on-nightly.yml b/.github/workflows/on-nightly.yml new file mode 100644 index 00000000..f1b90093 --- /dev/null +++ b/.github/workflows/on-nightly.yml @@ -0,0 +1,13 @@ +name: On nightly + +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' + +jobs: + build-and-test: + uses: ./.github/workflows/build-and-test.yml + secrets: inherit + with: + test_mark: 'nightly' diff --git a/.github/workflows/on-pr.yml b/.github/workflows/on-pr.yml index 34ea5ef6..5c407bbf 100644 --- a/.github/workflows/on-pr.yml +++ b/.github/workflows/on-pr.yml @@ -2,6 +2,11 @@ name: On PR on: workflow_dispatch: + inputs: + mlir_override: + description: 'Git SHA of commit in 
tenstorrent/tt-mlir' + required: false + type: string pull_request: branches: [ "main" ] @@ -20,3 +25,5 @@ jobs: needs: [pre-commit, spdx] uses: ./.github/workflows/build-and-test.yml secrets: inherit + with: + mlir_override: ${{ inputs.mlir_override }} diff --git a/.github/workflows/produce_data.yml b/.github/workflows/produce_data.yml index 1b26c3d0..ca089055 100644 --- a/.github/workflows/produce_data.yml +++ b/.github/workflows/produce_data.yml @@ -6,6 +6,7 @@ on: - "On PR" - "On push" - "Build and Test" + - "On nightly" types: - completed diff --git a/inc/common/pjrt_implementation/client_instance.h b/inc/common/pjrt_implementation/client_instance.h index 2e74a271..9334cd82 100644 --- a/inc/common/pjrt_implementation/client_instance.h +++ b/inc/common/pjrt_implementation/client_instance.h @@ -55,11 +55,6 @@ class ClientInstance { return cached_platform_version_; } - // Checks if the output on the i-th index is a scalar. - bool isOutputScalar(const size_t index) const { - return module_builder_->isOutputScalar(index); - } - // Compiles. // See TODOs in PJRT_Client_Compile. 
PJRT_Error * diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h index 212cddb3..b244419d 100644 --- a/inc/common/pjrt_implementation/device_description.h +++ b/inc/common/pjrt_implementation/device_description.h @@ -63,4 +63,4 @@ class DeviceDescription { } // namespace tt::pjrt -#endif +#endif \ No newline at end of file diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h index 319ca6e9..32fb3a9e 100644 --- a/inc/common/pjrt_implementation/executable_image.h +++ b/inc/common/pjrt_implementation/executable_image.h @@ -14,6 +14,9 @@ #include "xla/pjrt/c/pjrt_c_api.h" +// tt-mlir includes +#include "tt/runtime/types.h" + #ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_EXECUTABLE_IMAGE_H_ #define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_EXECUTABLE_IMAGE_H_ @@ -22,10 +25,10 @@ namespace tt::pjrt { class ExecutableImage { public: - ExecutableImage(std::shared_ptr binary, std::string code, - size_t arg_count, size_t result_count) - : ref_count(1), binary(std::move(binary)), code(code), - arg_count(arg_count), result_count(result_count) {} + ExecutableImage(const tt::runtime::Binary &binary, std::string code, + const std::vector &is_output_scalar, + size_t num_addressable_devices); + operator PJRT_Executable *() { return reinterpret_cast(this); } @@ -34,33 +37,44 @@ class ExecutableImage { } static void BindApi(PJRT_Api *api); - void AddRef() { ref_count.fetch_add(1); } + void AddRef() { m_ref_count.fetch_add(1); } void DecRef() { - if (ref_count.fetch_sub(1) == 0) { + if (m_ref_count.fetch_sub(1) == 0) { delete this; } } - const size_t get_arg_count() const { return arg_count; } + const size_t get_arg_count() const { return m_arg_count; } + + const size_t get_result_count() const { return m_result_count; } - const size_t get_result_count() const { return result_count; } + const tt::runtime::Binary &get_binary() const { return m_binary; } - 
std::shared_ptr get_binary() { return binary; } + const std::string &get_code() const { return m_code; } - const std::string &get_code() const { return code; } + // Checks if the output on the i-th index is a scalar. + bool isOutputScalar(size_t index) const; + + const size_t get_num_addressable_devices() const { + return num_addressable_devices; + } private: // The reference count. Must be disposed when reaching zero. - std::atomic ref_count; + std::atomic m_ref_count; // Raw compiler output. - std::shared_ptr binary; + tt::runtime::Binary m_binary; // Original code fed to the compiler. Stored for debugging. - const std::string code; + const std::string m_code; + + size_t m_arg_count; + size_t m_result_count; + size_t num_addressable_devices; - size_t arg_count; - size_t result_count; + // For every output, holds if the type is a scalar or not. + std::vector m_is_output_scalar; }; } // namespace tt::pjrt diff --git a/requirements.txt b/requirements.txt index 38d4414f..781ab540 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,7 @@ lit pybind11 pytest transformers +fsspec +einops +torch +ml_collections diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index bcc3b88a..1ca6f741 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -39,8 +39,10 @@ target_include_directories(TTPJRTCommon PUBLIC ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include/ttmlir/Target/Common ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/include ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/runtime/include + ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/shardy/ ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/ ${TTMLIR_TOOLCHAIN_DIR}/include + ${TTMLIR_TOOLCHAIN_DIR}/src/shardy ${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo ) @@ -63,6 +65,8 @@ ChloOps Version VhloOps VhloTypes +SdyDialect +SdyRegister StablehloOps StablehloRegister StablehloReferenceToken @@ -108,6 +112,7 @@ 
target_link_libraries(TTPJRTCommon PUBLIC TTPJRTCommonDylibPlatform TTMLIRStatic TTMLIRTosaToTTIR + TTMLIRTTIRToLinalg MLIRTTIRPipelines TTMLIRStableHLOToTTIR ${STABLEHLO_LIBS} diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index f39276b1..55cb112f 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -43,7 +43,7 @@ namespace tt::pjrt { ModuleBuilder::ModuleBuilder() - : m_status(tt_pjrt_status::kSuccess), m_num_inputs(0), m_num_outputs(0) { + : m_status(tt_pjrt_status::kSuccess), m_flatbuffer_binary(nullptr) { m_context = std::make_unique(); // Register all the required dialects and passes. @@ -66,11 +66,6 @@ ModuleBuilder::ModuleBuilder() m_context->appendDialectRegistry(registry); } -bool ModuleBuilder::isOutputScalar(const size_t index) const { - assert(index < m_is_output_scalar.size() && "Output index out of range"); - return m_is_output_scalar[index]; -} - tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code, const std::string_view &format) { DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule"); @@ -219,22 +214,9 @@ void ModuleBuilder::createFlatbufferBinary( const mlir::OwningOpRef &mlir_module) { m_flatbuffer_binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); - if (m_flatbuffer_binary == nullptr) { + if (m_flatbuffer_binary.handle == nullptr) { DLOG_F(ERROR, "Failed to generate flatbuffer binary"); m_status = tt_pjrt_status::kInternal; - return; - } - - tt::runtime::Binary runtime_binary_handle(m_flatbuffer_binary); - m_num_inputs = runtime_binary_handle.getProgramInputs(0).size(); - m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size(); - - if (m_num_outputs != m_is_output_scalar.size()) { - DLOG_F(ERROR, - "Created flatbuffer binary contains different number of outputs %ld " - "than expected %ld", - m_num_outputs, m_is_output_scalar.size()); - m_status = tt_pjrt_status::kInternal; } } diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 
9ea1f570..548cc321 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -15,6 +15,9 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +// tt-mlir includes +#include "tt/runtime/types.h" + // tt-xla includes #include "status.h" @@ -27,13 +30,15 @@ class ModuleBuilder { tt_pjrt_status buildModule(const std::string_view &code, const std::string_view &format); - std::shared_ptr getBinary() const { return m_flatbuffer_binary; } - - size_t getNumInputs() const { return m_num_inputs; }; + const tt::runtime::Binary &getBinary() const { return m_flatbuffer_binary; } - size_t getNumOutputs() const { return m_num_outputs; }; + const std::vector &getIsOutputScalar() const { + return m_is_output_scalar; + }; - bool isOutputScalar(size_t index) const; + // This needs to return the number of addressable devices from the StableHLO + // code. Currently hardcoded to one, as we only support one-chip execution. + size_t getNumAddressableDevices() const { return 1; } private: // Creates VHLO module from the input program code. @@ -66,14 +71,8 @@ class ModuleBuilder { // MLIR context handle. std::unique_ptr m_context; - // Flatbuffer binary handle. - std::shared_ptr m_flatbuffer_binary; - - // Number of binary program inputs. - size_t m_num_inputs; - - // Number of binary program outputs. - size_t m_num_outputs; + // Flatbuffer binary. + tt::runtime::Binary m_flatbuffer_binary; // Holds status of the last builder action. 
tt_pjrt_status m_status; diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc index f48ad181..cfe87830 100644 --- a/src/common/pjrt_implementation/client_instance.cc +++ b/src/common/pjrt_implementation/client_instance.cc @@ -164,14 +164,13 @@ void ClientInstance::BindApi(PJRT_Api *api) { tt_pjrt_status ClientInstance::PopulateDevices() { DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int device_info_count_ = - 1; // TODO: revert to chip_ids.size(); once - // https://github.com/tenstorrent/tt-xla/issues/9 is fixed + int devices_count = chip_ids.size(); - devices_.resize(device_info_count_); - for (size_t i = 0; i < device_info_count_; ++i) { + devices_.resize(devices_count); + for (size_t i = 0; i < devices_count; ++i) { devices_[i] = new DeviceInstance(i, *this, system_desc->chip_descs()->Get(i)->arch()); + } // For now, just make all devices addressable. 
@@ -198,8 +197,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, *this, new ExecutableImage(module_builder_->getBinary(), std::string(program->code, program->code_size), - module_builder_->getNumInputs(), - module_builder_->getNumOutputs()), + module_builder_->getIsOutputScalar(), + module_builder_->getNumAddressableDevices()), addressable_devices_); *out_executable = executable.release(); diff --git a/src/common/pjrt_implementation/device_description.cc b/src/common/pjrt_implementation/device_description.cc index 60b1e977..3feaf962 100644 --- a/src/common/pjrt_implementation/device_description.cc +++ b/src/common/pjrt_implementation/device_description.cc @@ -28,7 +28,7 @@ void DeviceDescription::BindApi(PJRT_Api *api) { api->PJRT_DeviceDescription_Id = +[](PJRT_DeviceDescription_Id_Args *args) -> PJRT_Error * { args->id = - DeviceDescription::Unwrap(args->device_description)->getClientId(); + DeviceDescription::Unwrap(args->device_description)->getDeviceId(); return nullptr; }; api->PJRT_DeviceDescription_ProcessIndex = diff --git a/src/common/pjrt_implementation/executable_image.cc b/src/common/pjrt_implementation/executable_image.cc index 852cc9ad..322619a2 100644 --- a/src/common/pjrt_implementation/executable_image.cc +++ b/src/common/pjrt_implementation/executable_image.cc @@ -19,6 +19,25 @@ namespace tt::pjrt { const std::string_view kMlirFormat = "mlir"; +ExecutableImage::ExecutableImage(const tt::runtime::Binary &binary, + std::string code, + const std::vector &is_output_scalar, + size_t num_addressable_devices) + : m_ref_count(1), m_binary(binary), m_code(code), + m_arg_count(binary.getProgramInputs(0).size()), + m_result_count(binary.getProgramOutputs(0).size()), + m_is_output_scalar(is_output_scalar), + num_addressable_devices(num_addressable_devices) { + if (m_result_count != m_is_output_scalar.size()) { + // TODO: We should throw error instead, otherwise execution will continue + // and crash later. 
+ DLOG_F(ERROR, + "Created flatbuffer binary contains different number of outputs %ld " + "than expected %ld", + m_result_count, m_is_output_scalar.size()); + } +} + void ExecutableImage::BindApi(PJRT_Api *api) { api->PJRT_Executable_Destroy = +[](PJRT_Executable_Destroy_Args *args) -> PJRT_Error * { @@ -48,7 +67,7 @@ void ExecutableImage::BindApi(PJRT_Api *api) { +[](PJRT_Executable_NumOutputs_Args *args) -> PJRT_Error * { DLOG_F(LOG_DEBUG, "ExecutableImage::PJRT_Executable_NumOutputs"); ExecutableImage *exec = ExecutableImage::Unwrap(args->executable); - args->num_outputs = exec->result_count; + args->num_outputs = exec->get_result_count(); return nullptr; }; api->PJRT_Executable_NumPartitions = @@ -86,15 +105,14 @@ void ExecutableImage::BindApi(PJRT_Api *api) { PJRT_Program *program = args->program; program->format = kMlirFormat.data(); program->format_size = kMlirFormat.size(); - size_t code_size = executable->code.size(); + size_t code_size = executable->get_code().size(); if (program->code == nullptr) { program->code_size = code_size; } else { if (program->code_size < code_size) { return ErrorInstance::MakeError(tt_pjrt_status::kInvalidArgument); } - std::memcpy(program->code, executable->code.c_str(), - executable->code.size()); + std::memcpy(program->code, executable->get_code().c_str(), code_size); } return nullptr; }; @@ -121,4 +139,9 @@ void ExecutableImage::BindApi(PJRT_Api *api) { }; } +bool ExecutableImage::isOutputScalar(const size_t index) const { + assert(index < m_is_output_scalar.size() && "Output index out of range"); + return m_is_output_scalar[index]; +} + } // namespace tt::pjrt diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index e8a6e72b..9b025172 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -10,6 +10,8 @@ #include 
"common/pjrt_implementation/loaded_executable_instance.h" +#include + #include "common/pjrt_implementation/buffer_instance.h" #include "common/pjrt_implementation/client_instance.h" #include "common/pjrt_implementation/error_instance.h" @@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) { DLOG_F( LOG_DEBUG, "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices"); - const std::vector &devices = - LoadedExecutableInstance::Unwrap(args->executable) - ->addressable_devices(); + LoadedExecutableInstance *loaded_executable = + LoadedExecutableInstance::Unwrap(args->executable); + const std::vector &addressable_devices = + loaded_executable->addressable_devices(); + int num_addressable_devices = + loaded_executable->image_->get_num_addressable_devices(); args->addressable_devices = const_cast( - reinterpret_cast(devices.data())); - args->num_addressable_devices = devices.size(); + reinterpret_cast(addressable_devices.data())); + args->num_addressable_devices = num_addressable_devices; return nullptr; }; api->PJRT_LoadedExecutable_Delete = @@ -77,23 +82,41 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int dev_0 = chip_ids[0]; - tt::runtime::Device device = tt::runtime::openDevice({dev_0}); + // Sanity check, as we only support execution on one chip currently. 
assert(args->num_devices == 1); + int dev_index = 0; - tt::runtime::Binary binary(image_->get_binary()); + const tt::runtime::Binary &binary = image_->get_binary(); std::vector rt_inputs; rt_inputs.reserve(args->num_args); + std::unordered_set device_ids; + for (size_t i = 0; i < args->num_args; ++i) { BufferInstance *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); rt_inputs.emplace_back(buffer->tensor()); + int64_t buffer_device_id = + buffer->device().device_description()->getDeviceId(); + device_ids.insert(chip_ids[buffer_device_id]); DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); } + std::vector device_ids_vector(device_ids.begin(), device_ids.end()); + + // If there are no input buffers, we still want to run on a device. + // TODO: Now we will run only on the first one, but this should be somehow + // explicit. + if (device_ids.size() == 0) { + device_ids_vector.push_back(chip_ids[0]); + } + + assert(device_ids_vector.size() == 1); + + tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector); + std::vector rt_outputs = tt::runtime::submit(device, binary, 0, rt_inputs); std::vector output_specs = @@ -102,7 +125,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { assert(rt_outputs.size() == output_specs.size()); for (size_t i = 0; i < output_specs.size(); ++i) { - bool is_scalar = client_.isOutputScalar(i); + bool is_scalar = image_->isOutputScalar(i); // PJRT expects an empty shape for scalars. std::vector output_shape = is_scalar ? 
std::vector() : output_specs[i].shape; diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..affb3e03 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from enum import Enum +from typing import Callable + +import pytest + + +class RecordProperties(Enum): + """Properties we can record.""" + + # Timestamp of test start. + START_TIMESTAMP = "start_timestamp" + # Timestamp of test end. + END_TIMESTAMP = "end_timestamp" + # Frontend or framework used to run the test. + FRONTEND = "frontend" + # Kind of operation. e.g. eltwise. + OP_KIND = "op_kind" + # Name of the operation in the framework. e.g. torch.conv2d. + FRAMEWORK_OP_NAME = "framework_op_name" + # Name of the operation. e.g. ttir.conv2d. + OP_NAME = "op_name" + # Name of the model in which this op appears. + MODEL_NAME = "model_name" + + +@pytest.fixture(scope="function", autouse=True) +def record_test_timestamp(record_property: Callable): + """ + Autouse fixture used to capture execution time of a test. + + Parameters: + ---------- + record_property: Callable + A pytest built-in function used to record test metadata, such as custom + properties or additional information about the test execution. + + Yields: + ------- + Callable + The `record_property` callable, allowing tests to add additional properties if + needed. + + + Example: + -------- + ``` + def test_model(fixture1, fixture2, ..., record_tt_xla_property): + record_tt_xla_property("key", value) + + # Test logic... + ``` + """ + start_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") + record_property(RecordProperties.START_TIMESTAMP.value, start_timestamp) + + # Run the test. 
+ yield + + end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") + record_property(RecordProperties.END_TIMESTAMP.value, end_timestamp) + + +@pytest.fixture(scope="function", autouse=True) +def record_tt_xla_property(record_property: Callable): + """ + Autouse fixture that automatically records some test properties for each test + function. + + It also yields back callable which can be explicitly used in tests to record + additional properties. + + Example: + + ``` + def test_model(fixture1, fixture2, ..., record_tt_xla_property): + record_tt_xla_property("key", value) + + # Test logic... + ``` + + Parameters: + ---------- + record_property: Callable + A pytest built-in function used to record test metadata, such as custom + properties or additional information about the test execution. + + Yields: + ------- + Callable + The `record_property` callable, allowing tests to add additional properties if + needed. + """ + # Record default properties for tt-xla. + record_property(RecordProperties.FRONTEND.value, "tt-xla") + + # Run the test. + yield record_property diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py index 40800989..c8a799b7 100644 --- a/tests/infra/comparison.py +++ b/tests/infra/comparison.py @@ -34,17 +34,21 @@ class AtolConfig(ConfigBase): required_atol: float = 1.6e-1 -@dataclass -class PccConfig(ConfigBase): - required_pcc: float = 0.99 - - @dataclass class AllcloseConfig(ConfigBase): rtol: float = 1e-2 atol: float = 1e-2 +# When tensors are too close, pcc will output NaN values. +# Therefore, for each test it should be possible to separately tune the threshold of allclose.rtol and allclose.atol +# below which pcc won't be calculated and therefore test will be able to pass without pcc comparison. 
+@dataclass +class PccConfig(ConfigBase): + required_pcc: float = 0.99 + allclose: AllcloseConfig = AllcloseConfig() + + @dataclass class ComparisonConfig: equal: EqualConfig = EqualConfig(False) @@ -106,9 +110,7 @@ def compare_pcc( # If tensors are really close, pcc will be nan. Handle that before calculating pcc. try: - compare_allclose( - device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2) - ) + compare_allclose(device_output, golden_output, pcc_config.allclose) except AssertionError: pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten()) pcc = jnp.min(pcc) diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py index 35750949..eb2c13d1 100644 --- a/tests/infra/device_connector.py +++ b/tests/infra/device_connector.py @@ -69,9 +69,9 @@ def is_initialized(self) -> bool: return False - def connect_tt_device(self) -> jax.Device: + def connect_tt_device(self, device_num: int = 0) -> jax.Device: """Returns TTDevice handle.""" - return self.connect_device(DeviceType.TT) + return self.connect_device(DeviceType.TT, device_num) def connect_cpu(self) -> jax.Device: """Returns CPUDevice handle.""" @@ -81,9 +81,23 @@ def connect_gpu(self) -> jax.Device: """Returns GPUDevice handle.""" return self.connect_device(DeviceType.GPU) - def connect_device(self, device_type: DeviceType) -> jax.Device: - """Returns handle for device identified by `device_type`.""" - return jax.devices(device_type.value)[0] + def connect_device( + self, device_type: DeviceType, device_num: int = 0 + ) -> jax.Device: + """ + Returns handle for device identified by `device_type`. + + If there are multiple available devices of `device_type`, `device_num` makes it + possible to choose between them. By default, returns first available device. 
+ """ + assert device_num < self._number_of_devices(device_type) + assert device_num >= 0 + + return jax.devices(device_type.value)[device_num] + + def _number_of_devices(self, device_type: DeviceType) -> int: + """Returns the number of available devices of specified type.""" + return len(jax.devices(device_type.value)) def _supported_devices(self) -> Sequence[DeviceType]: """Returns list of supported device types.""" diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py index 76349ec8..08551893 100644 --- a/tests/infra/device_runner.py +++ b/tests/infra/device_runner.py @@ -19,14 +19,14 @@ class DeviceRunner: """ @staticmethod - def run_on_tt_device(workload: Workload) -> Tensor: + def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor: """Runs `workload` on TT device.""" - return DeviceRunner._run_on_device(DeviceType.TT, workload) + return DeviceRunner._run_on_device(workload, DeviceType.TT, device_num) @staticmethod def run_on_cpu(workload: Workload) -> Tensor: """Runs `workload` on CPU.""" - return DeviceRunner._run_on_device(DeviceType.CPU, workload) + return DeviceRunner._run_on_device(workload, DeviceType.CPU) @staticmethod def run_on_gpu(workload: Workload) -> Tensor: @@ -34,14 +34,14 @@ def run_on_gpu(workload: Workload) -> Tensor: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def put_on_tt_device(workload: Workload) -> Workload: + def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload: """Puts `workload` on TT device.""" - return DeviceRunner._put_on_device(DeviceType.TT, workload) + return DeviceRunner._put_on_device(workload, DeviceType.TT, device_num) @staticmethod def put_on_cpu(workload: Workload) -> Workload: """Puts `workload` on CPU.""" - return DeviceRunner._put_on_device(DeviceType.CPU, workload) + return DeviceRunner._put_on_device(workload, DeviceType.CPU) @staticmethod def put_on_gpu(workload: Workload) -> Workload: @@ -64,18 +64,22 @@ def 
put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor: + def _run_on_device( + workload: Workload, device_type: DeviceType, device_num: int = 0 + ) -> Tensor: """Runs `workload` on device identified by `device_type`.""" - device_workload = DeviceRunner._put_on_device(device_type, workload) - device = device_connector.connect_device(device_type) + device_workload = DeviceRunner._put_on_device(workload, device_type, device_num) + device = device_connector.connect_device(device_type, device_num) with jax.default_device(device): return device_workload.execute() @staticmethod - def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload: + def _put_on_device( + workload: Workload, device_type: DeviceType, device_num: int = 0 + ) -> Workload: """Puts `workload` on device and returns it.""" - device = device_connector.connect_device(device_type) + device = device_connector.connect_device(device_type, device_num) return DeviceRunner._safely_put_workload_on_device(workload, device) @staticmethod diff --git a/tests/jax/graphs/test_MLP_regression.py b/tests/jax/graphs/test_MLP_regression.py new file mode 100644 index 00000000..68f4f1a5 --- /dev/null +++ b/tests/jax/graphs/test_MLP_regression.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import jax +import jax.numpy as jnp +import pytest +from infra import ComparisonConfig, run_graph_test_with_random_inputs + + +@pytest.fixture +def comparison_config() -> ComparisonConfig: + config = ComparisonConfig() + config.pcc.allclose.atol = 0.03 + config.pcc.allclose.rtol = 0.03 + return config + + +@pytest.mark.parametrize( + ["W1", "b1", "W2", "b2", "X", "y"], + [ + [(784, 64), (32, 64), (64, 10), (32, 10), (32, 784), (32, 10)] + ], # 32 samples, 784 features (28x28), 10 output classes +) +def 
test_nn_with_relu(W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig): + def simple_nn(W1, b1, W2, b2, X, y): + def forward(W1, b1, W2, b2, X): + hidden = jnp.dot(X, W1) + b1 + hidden = jnp.maximum(0, hidden) + output = jnp.dot(hidden, W2) + b2 + return output + + def loss(W1, b1, W2, b2, X, y): + output = forward(W1, b1, W2, b2, X) + return jnp.mean((output - y) ** 2) + + @jax.jit + def update_params(W1, b1, W2, b2, X, y, lr=0.01): + grads = jax.grad(loss, argnums=(0, 1, 2, 3))(W1, b1, W2, b2, X, y) + W1 -= lr * grads[0] + b1 -= lr * grads[1] + W2 -= lr * grads[2] + b2 -= lr * grads[3] + return W1, b1, W2, b2, grads + + for i in range(50): + W1, b1, W2, b2, grads = update_params(W1, b1, W2, b2, X, y, lr=0.01) + + final_loss = loss(W1, b1, W2, b2, X, y) + return final_loss + + run_graph_test_with_random_inputs( + simple_nn, [W1, b1, W2, b2, X, y], comparison_config=comparison_config + ) diff --git a/tests/jax/graphs/test_activation_functions.py b/tests/jax/graphs/test_activation_functions.py index 7dacb282..f715f5bc 100644 --- a/tests/jax/graphs/test_activation_functions.py +++ b/tests/jax/graphs/test_activation_functions.py @@ -9,9 +9,6 @@ @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -@pytest.mark.skip( - "ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast" -) def test_relu(x_shape: tuple): """Test ReLU activation function.""" diff --git a/tests/jax/graphs/test_simple_regression.py b/tests/jax/graphs/test_simple_regression.py index 4449bdc4..439379d6 100644 --- a/tests/jax/graphs/test_simple_regression.py +++ b/tests/jax/graphs/test_simple_regression.py @@ -10,7 +10,6 @@ @pytest.mark.parametrize( ["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]] ) -@pytest.mark.skip("failed to legalize operation 'stablehlo.dot_general'") def test_simple_regression(weights, bias, X, y): def simple_regression(weights, bias, X, y): def loss(weights, bias, X, y): diff --git a/tests/jax/graphs/test_softmax.py 
b/tests/jax/graphs/test_softmax.py index 6853d363..5ec93aa8 100644 --- a/tests/jax/graphs/test_softmax.py +++ b/tests/jax/graphs/test_softmax.py @@ -16,10 +16,6 @@ [(64, 64), 1], ], ) -@pytest.mark.skip( - "tt-metal assert: Index is out of bounds for the rank. " - "Similar to https://github.com/tenstorrent/tt-xla/issues/12" -) def test_softmax(x_shape: tuple, axis: int): def softmax(x: jax.Array) -> jax.Array: return jax.nn.softmax(x, axis=axis) diff --git a/tests/jax/models/albert/v2/base/test_albert_base.py b/tests/jax/models/albert/v2/base/test_albert_base.py index 5c54e308..8d933cb3 100644 --- a/tests/jax/models/albert/v2/base/test_albert_base.py +++ b/tests/jax/models/albert/v2/base/test_albert_base.py @@ -2,13 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest from infra import RunMode +from utils import record_model_test_properties, runtime_fail from ..tester import AlbertV2Tester - MODEL_PATH = "albert/albert-base-v2" +MODEL_NAME = "albert-v2-base" # ----- Fixtures ----- @@ -27,15 +30,28 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) def test_flax_albert_v2_base_inference( inference_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_flax_albert_v2_base_training( training_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/albert/v2/large/test_albert_large.py b/tests/jax/models/albert/v2/large/test_albert_large.py index ff59a8be..fb2f29b7 
100644 --- a/tests/jax/models/albert/v2/large/test_albert_large.py +++ b/tests/jax/models/albert/v2/large/test_albert_large.py @@ -2,12 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest from infra import RunMode +from utils import record_model_test_properties, runtime_fail from ..tester import AlbertV2Tester MODEL_PATH = "albert/albert-large-v2" +MODEL_NAME = "albert-v2-large" # ----- Fixtures ----- @@ -26,15 +30,28 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) def test_flax_albert_v2_large_inference( inference_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_flax_albert_v2_large_training( training_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py index 93346267..b9676b91 100644 --- a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py +++ b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py @@ -2,12 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest from infra import RunMode +from utils import record_model_test_properties, runtime_fail from ..tester import AlbertV2Tester MODEL_PATH = "albert/albert-xlarge-v2" +MODEL_NAME = "albert-v2-xlarge" # ----- Fixtures ----- @@ -26,15 +30,28 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize 
operation 'stablehlo.dot_general'") +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) def test_flax_albert_v2_xlarge_inference( inference_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_flax_albert_v2_xlarge_training( training_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py index cce0ef8f..0b642d86 100644 --- a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py +++ b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py @@ -2,12 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest from infra import RunMode +from utils import record_model_test_properties, runtime_fail from ..tester import AlbertV2Tester MODEL_PATH = "albert/albert-xxlarge-v2" +MODEL_NAME = "albert-v2-xxlarge" # ----- Fixtures ----- @@ -26,15 +30,28 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) def test_flax_albert_v2_xxlarge_inference( inference_tester: AlbertV2Tester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_flax_albert_v2_xxlarge_training( training_tester: AlbertV2Tester, + 
record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/bart/__init__.py b/tests/jax/models/bart/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/base/__init__.py b/tests/jax/models/bart/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/base/test_bart_base.py b/tests/jax/models/bart/base/test_bart_base.py new file mode 100644 index 00000000..b530195d --- /dev/null +++ b/tests/jax/models/bart/base/test_bart_base.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import FlaxBartForCausalLMTester + +MODEL_PATH = "facebook/bart-base" +MODEL_NAME = "bart-base" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) +def test_flax_bart_base_inference( + inference_tester: FlaxBartForCausalLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bart_base_training( + training_tester: FlaxBartForCausalLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git 
a/tests/jax/models/bart/large/__init__.py b/tests/jax/models/bart/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/large/test_bart_large.py b/tests/jax/models/bart/large/test_bart_large.py new file mode 100644 index 00000000..88189d6b --- /dev/null +++ b/tests/jax/models/bart/large/test_bart_large.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import compile_fail, record_model_test_properties + +from ..tester import FlaxBartForCausalLMTester + +MODEL_PATH = "facebook/bart-large" +MODEL_NAME = "bart-large" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=compile_fail( + "Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)" + ) +) +def test_flax_bart_large_inference( + inference_tester: FlaxBartForCausalLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bart_large_training( + training_tester: FlaxBartForCausalLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bart/tester.py b/tests/jax/models/bart/tester.py new file mode 100644 index 00000000..d6681adf --- /dev/null +++ b/tests/jax/models/bart/tester.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + 
+import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxBartForCausalLM + + +class FlaxBartForCausalLMTester(ModelTester): + """Tester for BART model variants with a language modeling head on top.""" + + # TODO(mrakita): Add tests for other variants. + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxBartForCausalLM.from_pretrained(self._model_name) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + } diff --git a/tests/jax/models/bert/__init__.py b/tests/jax/models/bert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/base/__init__.py b/tests/jax/models/bert/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/base/test_bert_base.py b/tests/jax/models/bert/base/test_bert_base.py new file mode 100644 index 00000000..bca12b24 --- /dev/null +++ b/tests/jax/models/bert/base/test_bert_base.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import FlaxBertForMaskedLMTester + +MODEL_PATH = "google-bert/bert-base-uncased" +MODEL_NAME = "bert-base" + +# ----- 
Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) +) +def test_flax_bert_base_inference( + inference_tester: FlaxBertForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bert_base_training( + training_tester: FlaxBertForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bert/large/__init__.py b/tests/jax/models/bert/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/large/test_bert_large.py b/tests/jax/models/bert/large/test_bert_large.py new file mode 100644 index 00000000..5520777c --- /dev/null +++ b/tests/jax/models/bert/large/test_bert_large.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import FlaxBertForMaskedLMTester + +MODEL_PATH = "google-bert/bert-large-uncased" +MODEL_NAME = "bert-large" + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH, 
RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) +def test_flax_bert_large_inference( + inference_tester: FlaxBertForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bert_large_training( + training_tester: FlaxBertForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bert/tester.py b/tests/jax/models/bert/tester.py new file mode 100644 index 00000000..e96f83c5 --- /dev/null +++ b/tests/jax/models/bert/tester.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxBertForMaskedLM + + +class FlaxBertForMaskedLMTester(ModelTester): + """Tester for BERT model variants on masked language modeling task.""" + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxBertForMaskedLM.from_pretrained(self._model_name) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello [MASK]", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert 
hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + } diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py index 06d3b785..ff86a6dc 100644 --- a/tests/jax/models/distilbert/test_distilbert.py +++ b/tests/jax/models/distilbert/test_distilbert.py @@ -2,15 +2,17 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Sequence +from typing import Callable, Dict, Sequence import jax import pytest from flax import linen as nn from infra import ModelTester, RunMode from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM +from utils import record_model_test_properties, runtime_fail MODEL_PATH = "distilbert/distilbert-base-uncased" +MODEL_NAME = "distilbert" # ----- Tester ----- @@ -53,15 +55,26 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) def test_flax_distilbert_inference( inference_tester: FlaxDistilBertForMaskedLMTester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_flax_distilbert_training( training_tester: FlaxDistilBertForMaskedLMTester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py index d08aedc1..adb7adb2 100644 --- 
a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py +++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py @@ -7,7 +7,7 @@ import jax import pytest from flax import nnx -from infra import ComparisonConfig, ModelTester, RunMode +from infra import ModelTester, RunMode from ..model import ExampleModel diff --git a/tests/jax/models/example_model/only_args/test_example_model_only_args.py b/tests/jax/models/example_model/only_args/test_example_model_only_args.py index 15ff0358..47ef50e5 100644 --- a/tests/jax/models/example_model/only_args/test_example_model_only_args.py +++ b/tests/jax/models/example_model/only_args/test_example_model_only_args.py @@ -7,7 +7,7 @@ import jax import pytest from flax import nnx -from infra import ComparisonConfig, ModelTester, RunMode +from infra import ModelTester, RunMode from ..model import ExampleModel diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py index c4922e99..e74ad2c5 100644 --- a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py +++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py @@ -7,7 +7,7 @@ import jax import pytest from flax import nnx -from infra import ComparisonConfig, ModelTester, RunMode +from infra import ModelTester, RunMode from ..model import ExampleModel diff --git a/tests/jax/models/gpt2/__init__.py b/tests/jax/models/gpt2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/gpt2/base/__init__.py b/tests/jax/models/gpt2/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/gpt2/base/test_gpt2_base.py b/tests/jax/models/gpt2/base/test_gpt2_base.py new file mode 100644 index 00000000..beba9d21 --- /dev/null +++ b/tests/jax/models/gpt2/base/test_gpt2_base.py @@ -0,0 +1,54 @@ 
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import GPT2Tester + +MODEL_PATH = "openai-community/gpt2" +MODEL_NAME = "gpt2-base" + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) +def test_gpt2_base_inference( + inference_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_gpt2_base_training( + training_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/gpt2/large/__init__.py b/tests/jax/models/gpt2/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/gpt2/large/test_gpt2_large.py b/tests/jax/models/gpt2/large/test_gpt2_large.py new file mode 100644 index 00000000..d9b92b90 --- /dev/null +++ b/tests/jax/models/gpt2/large/test_gpt2_large.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import GPT2Tester + +MODEL_PATH = "openai-community/gpt2-large" +MODEL_NAME = "gpt2-large" + + +# ----- Fixtures
----- + + +@pytest.fixture +def inference_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) +def test_gpt2_large_inference( + inference_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_gpt2_large_training( + training_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/gpt2/medium/__init__.py b/tests/jax/models/gpt2/medium/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/gpt2/medium/test_gpt2_medium.py b/tests/jax/models/gpt2/medium/test_gpt2_medium.py new file mode 100644 index 00000000..ba709611 --- /dev/null +++ b/tests/jax/models/gpt2/medium/test_gpt2_medium.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import GPT2Tester + +MODEL_PATH = "openai-community/gpt2-medium" +MODEL_NAME = "gpt2-medium" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + 
"(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) +def test_gpt2_medium_inference( + inference_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_gpt2_medium_training( + training_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/gpt2/test_gpt2.py b/tests/jax/models/gpt2/test_gpt2.py deleted file mode 100644 index efb9a344..00000000 --- a/tests/jax/models/gpt2/test_gpt2.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -from typing import Sequence, Dict - -import jax -import pytest -from flax import linen as nn -from infra import ModelTester, RunMode -from transformers import AutoTokenizer, FlaxGPT2LMHeadModel - -MODEL_PATH = "openai-community/gpt2" - - -class GPT2Tester(ModelTester): - """Tester for GPT2 for autoregressive text generation.""" - - # @override - def _get_model(self) -> nn.Module: - return FlaxGPT2LMHeadModel.from_pretrained(MODEL_PATH) - - # @override - def _get_input_activations(self) -> Sequence[jax.Array]: - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) - inputs = tokenizer("Hello", return_tensors="np") - return inputs["input_ids"] - - # @override - def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: - assert hasattr(self._model, "params") - return { - "params": self._model.params, - "input_ids": self._get_input_activations(), - } - - -# ----- Fixtures ----- - - -@pytest.fixture -def inference_tester() -> GPT2Tester: - return GPT2Tester() - - -@pytest.fixture -def training_tester() -> GPT2Tester: - return GPT2Tester(RunMode.TRAINING) - - -# ----- Tests ----- - - -@pytest.mark.skip(reason="failed to legalize operation 
'stablehlo.dot_general'") -def test_gp2_inference( - inference_tester: GPT2Tester, -): - inference_tester.test() - - -@pytest.mark.skip(reason="Support for training not implemented") -def test_gpt2_training( - training_tester: GPT2Tester, -): - training_tester.test() diff --git a/tests/jax/models/gpt2/tester.py b/tests/jax/models/gpt2/tester.py new file mode 100644 index 00000000..ae28caf1 --- /dev/null +++ b/tests/jax/models/gpt2/tester.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxGPT2LMHeadModel + + +class GPT2Tester(ModelTester): + """Tester for GPT2 for autoregressive text generation.""" + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxGPT2LMHeadModel.from_pretrained(self._model_name) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + } diff --git a/tests/jax/models/gpt2/xl/__init__.py b/tests/jax/models/gpt2/xl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/gpt2/xl/test_gpt2_xl.py b/tests/jax/models/gpt2/xl/test_gpt2_xl.py new file mode 100644 index 00000000..758998ac --- /dev/null +++ b/tests/jax/models/gpt2/xl/test_gpt2_xl.py @@ -0,0 +1,52 @@ +# 
SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import GPT2Tester + +MODEL_PATH = "openai-community/gpt2-xl" +MODEL_NAME = "gpt2-xl" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> GPT2Tester: + return GPT2Tester(MODEL_PATH, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip( + reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)" +) +def test_gpt2_xl_inference( + inference_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_gpt2_xl_training( + training_tester: GPT2Tester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py index 18688ed0..605f8342 100644 --- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py +++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py @@ -2,13 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest -from flax import linen as nn from infra import RunMode +from utils import record_model_test_properties from ..tester import LLamaTester MODEL_PATH = "openlm-research/open_llama_3b_v2" +MODEL_NAME = "open-llama-3b-v2" # ----- Fixtures ----- @@ -27,15 +30,23 @@ def training_tester() -> LLamaTester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +@pytest.mark.skip( + 
reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)" +) def test_openllama3b_inference( inference_tester: LLamaTester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") def test_openllama3b_training( training_tester: LLamaTester, + record_tt_xla_property: Callable, ): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + training_tester.test() diff --git a/tests/jax/models/llama/tester.py b/tests/jax/models/llama/tester.py index e915e0db..82b43330 100644 --- a/tests/jax/models/llama/tester.py +++ b/tests/jax/models/llama/tester.py @@ -21,8 +21,8 @@ def __init__( comparison_config: ComparisonConfig = ComparisonConfig(), run_mode: RunMode = RunMode.INFERENCE, ) -> None: - super().__init__(comparison_config, run_mode) self._model_name = model_name + super().__init__(comparison_config, run_mode) # @override def _get_model(self) -> nn.Module: diff --git a/tests/jax/models/mlpmixer/__init__.py b/tests/jax/models/mlpmixer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py new file mode 100644 index 00000000..03679e4b --- /dev/null +++ b/tests/jax/models/mlpmixer/model_implementation.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +# This file incorporates work covered by the following copyright and permission +# notice: +# SPDX-FileCopyrightText: Copyright 2024 Google LLC. 
+# SPDX-License-Identifier: Apache-2.0 + +# This code is based on google-research/vision_transformer + +from typing import Any, Optional + +import einops +import flax.linen as nn +import jax.numpy as jnp +import jax + + +class MlpBlock(nn.Module): + mlp_dim: int + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + y = nn.Dense(self.mlp_dim)(x) + y = nn.gelu(y) + return nn.Dense(x.shape[-1])(y) + + +class MixerBlock(nn.Module): + """Mixer block layer.""" + + tokens_mlp_dim: int + channels_mlp_dim: int + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + y = nn.LayerNorm()(x) + y = jnp.swapaxes(y, 1, 2) + y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y) + y = jnp.swapaxes(y, 1, 2) + x = x + y + + y = nn.LayerNorm()(x) + y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y) + y = x + y + + return y + + +class MlpMixer(nn.Module): + """Mixer architecture.""" + + patches: Any + num_classes: int + num_blocks: int + hidden_dim: int + tokens_mlp_dim: int + channels_mlp_dim: int + model_name: Optional[str] = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + x = nn.Conv( + self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem" + )( + inputs + ) # Patch embedding + x = einops.rearrange(x, "n h w c -> n (h w) c") + + for _ in range(self.num_blocks): + x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x) + + x = nn.LayerNorm(name="pre_head_layer_norm")(x) + x = jnp.mean(x, axis=1) + + if self.num_classes: + x = nn.Dense( + self.num_classes, kernel_init=nn.initializers.zeros, name="head" + )(x) + + return x diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py new file mode 100644 index 00000000..94ac0872 --- /dev/null +++ b/tests/jax/models/mlpmixer/test_mlpmixer.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Sequence + 
+import flax.traverse_util +import fsspec +import jax +import ml_collections +import numpy +import pytest +from flax import linen as nn +from infra import ModelTester, RunMode +from utils import record_model_test_properties, runtime_fail + +from .model_implementation import MlpMixer + +# Hyperparameters for Mixer-B/16 +patch_size = 16 +num_classes = 21843 +num_blocks = 12 +hidden_dim = 768 +token_mlp_dim = 384 +channel_mlp_dim = 3072 + + +class MlpMixerTester(ModelTester): + """Tester for MlpMixer model.""" + + # @override + def _get_model(self) -> nn.Module: + patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)}) + return MlpMixer( + patches=patch, + num_classes=num_classes, + num_blocks=num_blocks, + hidden_dim=hidden_dim, + tokens_mlp_dim=token_mlp_dim, + channels_mlp_dim=channel_mlp_dim, + ) + + @staticmethod + def _retrieve_pretrained_weights() -> Dict: + # TODO(stefan): Discuss how weights should be handled org wide + link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz" + with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f: + weights = numpy.load(f, encoding="bytes") + state_dict = {k: v for k, v in weights.items()} + pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/") + return {"params": pytree} + + # @override + def _get_forward_method_name(self) -> str: + return "apply" + + # @override + def _get_input_activations(self) -> jax.Array: + key = jax.random.PRNGKey(42) + random_image = jax.random.normal(key, (1, 224, 224, 3)) + return random_image + + # @override + def _get_forward_method_args(self) -> Sequence[Any]: + ins = self._get_input_activations() + weights = self._retrieve_pretrained_weights() + + # Alternatively, weights could be randomly initialized like this: + # weights = self._model.init(jax.random.PRNGKey(42), ins) + + # JAX frameworks have a convention of passing weights as the first argument + return [weights, ins] + + +# ----- Fixtures ----- + + +@pytest.fixture +def 
inference_tester() -> MlpMixerTester: + return MlpMixerTester() + + +@pytest.fixture +def training_tester() -> MlpMixerTester: + return MlpMixerTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip( + reason=runtime_fail( + "Statically allocated circular buffers in program 16 clash with L1 buffers " + "on core range [(x=0,y=0) - (x=6,y=0)]. L1 buffer allocated at 475136 and " + "static circular buffer region ends at 951136 " + "(https://github.com/tenstorrent/tt-xla/issues/187)" + ) +) # segfault +def test_mlpmixer_inference( + inference_tester: MlpMixerTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, "mlpmixer") + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_mlpmixer_training( + training_tester: MlpMixerTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, "mlpmixer") + + training_tester.test() diff --git a/tests/jax/models/mnist/cnn/__init__.py b/tests/jax/models/mnist/cnn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/mnist/cnn/dropout/__init__.py b/tests/jax/models/mnist/cnn/dropout/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/mnist/cnn/dropout/model_implementation.py b/tests/jax/models/mnist/cnn/dropout/model_implementation.py new file mode 100644 index 00000000..9528c170 --- /dev/null +++ b/tests/jax/models/mnist/cnn/dropout/model_implementation.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from flax import linen as nn + + +class MNISTCNNDropoutModel(nn.Module): + @nn.compact + def __call__(self, x, *, train: bool): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.relu(x) + + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + + x = 
nn.Dropout(rate=0.25)(x, deterministic=not train)
+
+        x = x.reshape((x.shape[0], -1))
+
+        x = nn.Dense(features=128)(x)
+        x = nn.relu(x)
+        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
+
+        x = nn.Dense(features=10)(x)
+        x = nn.softmax(x)
+
+        return x
diff --git a/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py b/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py
new file mode 100644
index 00000000..daae4d18
--- /dev/null
+++ b/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable
+
+import pytest
+from infra import RunMode
+from utils import record_model_test_properties
+
+from ..tester import MNISTCNNTester
+from .model_implementation import MNISTCNNDropoutModel
+
+# ----- Fixtures -----
+
+
+@pytest.fixture
+def inference_tester() -> MNISTCNNTester:
+    return MNISTCNNTester(MNISTCNNDropoutModel)
+
+
+@pytest.fixture
+def training_tester() -> MNISTCNNTester:
+    return MNISTCNNTester(MNISTCNNDropoutModel, RunMode.TRAINING)
+
+
+# ----- Tests -----
+
+
+def test_mnist_cnn_dropout_inference(
+    inference_tester: MNISTCNNTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, "mnist-cnn-dropout")
+
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_mnist_cnn_dropout_training(
+    training_tester: MNISTCNNTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, "mnist-cnn-dropout")
+
+    training_tester.test()
diff --git a/tests/jax/models/mnist/cnn/nodropout/__init__.py b/tests/jax/models/mnist/cnn/nodropout/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/jax/models/mnist/cnn/model_implementation.py b/tests/jax/models/mnist/cnn/nodropout/model_implementation.py
similarity index 96%
rename from 
tests/jax/models/mnist/cnn/model_implementation.py rename to tests/jax/models/mnist/cnn/nodropout/model_implementation.py index 2e957af2..720427da 100644 --- a/tests/jax/models/mnist/cnn/model_implementation.py +++ b/tests/jax/models/mnist/cnn/nodropout/model_implementation.py @@ -2,11 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -import jax from flax import linen as nn -class MNISTCNNModel(nn.Module): +class MNISTCNNNoDropoutModel(nn.Module): @nn.compact def __call__(self, x, *, train: bool): x = nn.Conv(features=32, kernel_size=(3, 3), padding="SAME")(x) diff --git a/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py b/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py new file mode 100644 index 00000000..f22b227e --- /dev/null +++ b/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import record_model_test_properties + +from ..tester import MNISTCNNTester +from .model_implementation import MNISTCNNNoDropoutModel + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> MNISTCNNTester: + return MNISTCNNTester(MNISTCNNNoDropoutModel) + + +@pytest.fixture +def training_tester() -> MNISTCNNTester: + return MNISTCNNTester(MNISTCNNNoDropoutModel, RunMode.TRAINING) + + +# ----- Tests ----- + + +def test_mnist_cnn_nodropout_inference( + inference_tester: MNISTCNNTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, "mnist-cnn-nodropout") + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_mnist_cnn_nodropout_training( + training_tester: MNISTCNNTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, "mnist-cnn-nodropout") + + training_tester.test() diff --git 
a/tests/jax/models/mnist/cnn/test_mnist_cnn.py b/tests/jax/models/mnist/cnn/tester.py
similarity index 57%
rename from tests/jax/models/mnist/cnn/test_mnist_cnn.py
rename to tests/jax/models/mnist/cnn/tester.py
index 1a6f8fcf..8bf2f8be 100644
--- a/tests/jax/models/mnist/cnn/test_mnist_cnn.py
+++ b/tests/jax/models/mnist/cnn/tester.py
@@ -6,19 +6,20 @@
 
 import jax
 import jax.numpy as jnp
-import pytest
 from flax import linen as nn
-from infra import ModelTester, RunMode
-
-from tests.jax.models.mnist.cnn.model_implementation import MNISTCNNModel
+from infra import ModelTester, RunMode
 
 
 class MNISTCNNTester(ModelTester):
     """Tester for MNIST CNN model."""
 
+    def __init__(self, cls, run_mode: RunMode = RunMode.INFERENCE):
+        self._model_class = cls
+        super().__init__(run_mode=run_mode)
+
     # @override
     def _get_model(self) -> nn.Module:
-        return MNISTCNNModel()
+        return self._model_class()
 
     # @override
     def _get_forward_method_name(self) -> str:
@@ -47,39 +48,3 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
     # @override
     def _get_static_argnames(self):
         return ["train"]
-
-
-# ----- Fixtures -----
-
-
-@pytest.fixture
-def inference_tester() -> MNISTCNNTester:
-    return MNISTCNNTester()
-
-
-@pytest.fixture
-def training_tester() -> MNISTCNNTester:
-    return MNISTCNNTester(RunMode.TRAINING)
-
-
-# ----- Tests -----
-
-
-@pytest.mark.skip(
-    reason='void mlir::OperationConverter::finalize(mlir::ConversionPatternRewriter &): Assertion `newValue && "replacement value not found"\' failed.'
-) -def test_mnist_inference( - inference_tester: MNISTCNNTester, -): - inference_tester.test() - - -@pytest.mark.skip(reason="Support for training not implemented") -def test_mnist_training( - training_tester: MNISTCNNTester, -): - training_tester.test() - - -if __name__ == "__main__": - MNISTCNNTester().test() diff --git a/tests/jax/models/mnist/mlp/__init__.py b/tests/jax/models/mnist/mlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/mnist/mlp/model_implementation.py b/tests/jax/models/mnist/mlp/model_implementation.py new file mode 100644 index 00000000..25b7eac9 --- /dev/null +++ b/tests/jax/models/mnist/mlp/model_implementation.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from flax import linen as nn + + +class MNISTMLPModel(nn.Module): + hidden_sizes: tuple[int] + + @nn.compact + def __call__(self, x): + x = x.reshape((x.shape[0], -1)) + + for h in self.hidden_sizes: + x = nn.Dense(features=h)(x) + x = nn.relu(x) + + x = nn.Dense(features=10)(x) + x = nn.softmax(x) + + return x diff --git a/tests/jax/models/mnist/mlp/test_mnist_mlp.py b/tests/jax/models/mnist/mlp/test_mnist_mlp.py new file mode 100644 index 00000000..c458fbdd --- /dev/null +++ b/tests/jax/models/mnist/mlp/test_mnist_mlp.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Sequence + +import jax +import pytest +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from utils import record_model_test_properties + +from .model_implementation import MNISTMLPModel + + +class MNISTMLPTester(ModelTester): + """Tester for MNIST MLP model.""" + + def __init__( + self, + hidden_sizes: Sequence[int], + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._hidden_sizes = hidden_sizes + 
super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return MNISTMLPModel(self._hidden_sizes) + + # @override + def _get_forward_method_name(self) -> str: + return "apply" + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + key = jax.random.PRNGKey(37) + img = jax.random.normal(key, (4, 28, 28, 1)) # B, H, W, C + # Channels is 1 as MNIST is in grayscale. + return img + + # @override + def _get_forward_method_args(self): + inp = self._get_input_activations() + + parameters = self._model.init(jax.random.PRNGKey(42), inp) + + return [parameters, inp] + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester(request) -> MNISTMLPTester: + return MNISTMLPTester(request.param) + + +@pytest.fixture +def training_tester(request) -> MNISTMLPTester: + return MNISTMLPTester(request.param, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.parametrize( + "inference_tester", + [ + (128,), + (128, 128), + (192, 128), + (512, 512), + (128, 128, 128), + (256, 128, 64), + ], + indirect=True, + ids=lambda val: f"{val}", +) +def test_mnist_mlp_inference( + inference_tester: MNISTMLPTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, "mnist-mlp") + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_mnist_mlp_training( + training_tester: MNISTMLPTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MNISTMLPModel.__qualname__) + + training_tester.test() diff --git a/tests/jax/models/roberta/__init__.py b/tests/jax/models/roberta/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/roberta/base/__init__.py b/tests/jax/models/roberta/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/roberta/base/test_roberta_base.py 
b/tests/jax/models/roberta/base/test_roberta_base.py
new file mode 100644
index 00000000..5f3d104d
--- /dev/null
+++ b/tests/jax/models/roberta/base/test_roberta_base.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable
+
+import pytest
+from infra import RunMode
+from utils import record_model_test_properties, runtime_fail
+
+from ..tester import FlaxRobertaForMaskedLMTester
+
+MODEL_PATH = "FacebookAI/roberta-base"
+MODEL_NAME = "roberta-base"
+
+
+# ----- Fixtures -----
+
+
+@pytest.fixture
+def inference_tester() -> FlaxRobertaForMaskedLMTester:
+    return FlaxRobertaForMaskedLMTester(MODEL_PATH)
+
+
+@pytest.fixture
+def training_tester() -> FlaxRobertaForMaskedLMTester:
+    return FlaxRobertaForMaskedLMTester(MODEL_PATH, run_mode=RunMode.TRAINING)
+
+
+# ----- Tests -----
+
+
+@pytest.mark.xfail(
+    reason=runtime_fail(
+        "Cannot get the device from a tensor with host storage "
+        "(https://github.com/tenstorrent/tt-xla/issues/171)"
+    )
+)
+def test_flax_roberta_base_inference(
+    inference_tester: FlaxRobertaForMaskedLMTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_flax_roberta_base_training(
+    training_tester: FlaxRobertaForMaskedLMTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
+    training_tester.test()
diff --git a/tests/jax/models/roberta/large/__init__.py b/tests/jax/models/roberta/large/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/jax/models/roberta/large/test_roberta_large.py b/tests/jax/models/roberta/large/test_roberta_large.py
new file mode 100644
index 00000000..0bad0a41
--- /dev/null
+++ b/tests/jax/models/roberta/large/test_roberta_large.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: (c) 2024 
Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import RunMode +from utils import record_model_test_properties, runtime_fail + +from ..tester import FlaxRobertaForMaskedLMTester + +MODEL_PATH = "FacebookAI/roberta-large" +MODEL_NAME = "roberta-large" + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) +) +def test_flax_roberta_large_inference( + inference_tester: FlaxRobertaForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_roberta_large_training( + training_tester: FlaxRobertaForMaskedLMTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/roberta/test_roberta.py b/tests/jax/models/roberta/tester.py similarity index 50% rename from tests/jax/models/roberta/test_roberta.py rename to tests/jax/models/roberta/tester.py index f883bf86..71fc3ad4 100644 --- a/tests/jax/models/roberta/test_roberta.py +++ b/tests/jax/models/roberta/tester.py @@ -5,26 +5,32 @@ from typing import Dict, Sequence import jax -import pytest from flax import linen as nn -from infra import ModelTester, RunMode +from infra import ComparisonConfig, ModelTester, RunMode from transformers import AutoTokenizer, FlaxRobertaForMaskedLM -MODEL_PATH = "FacebookAI/roberta-base" - -# ----- Tester ----- - 
class FlaxRobertaForMaskedLMTester(ModelTester): """Tester for Roberta model on a masked language modeling task.""" + # TODO(mrakita): Add tests for other variants. + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + # @override def _get_model(self) -> nn.Module: - return FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH) + return FlaxRobertaForMaskedLM.from_pretrained(self._model_name) # @override def _get_input_activations(self) -> Sequence[jax.Array]: - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + tokenizer = AutoTokenizer.from_pretrained(self._model_name) inputs = tokenizer("Hello .", return_tensors="np") return inputs["input_ids"] @@ -39,33 +45,3 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: # @ override def _get_static_argnames(self): return ["train"] - - -# ----- Fixtures ----- - - -@pytest.fixture -def inference_tester() -> FlaxRobertaForMaskedLMTester: - return FlaxRobertaForMaskedLMTester() - - -@pytest.fixture -def training_tester() -> FlaxRobertaForMaskedLMTester: - return FlaxRobertaForMaskedLMTester(RunMode.TRAINING) - - -# ----- Tests ----- - - -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") -def test_roberta_inference( - inference_tester: FlaxRobertaForMaskedLMTester, -): - inference_tester.test() - - -@pytest.mark.skip(reason="Support for training not implemented") -def test_flax_roberta_training( - training_tester: FlaxRobertaForMaskedLMTester, -): - training_tester.test() diff --git a/tests/jax/models/squeezebert/__init__.py b/tests/jax/models/squeezebert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/squeezebert/model_implementation.py b/tests/jax/models/squeezebert/model_implementation.py new file mode 100644 index 00000000..9ff83e5f --- /dev/null +++ 
b/tests/jax/models/squeezebert/model_implementation.py @@ -0,0 +1,363 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Any, Dict, Tuple + +import einops +import flax.traverse_util +import jax +import jax.numpy as jnp +from flax import linen as nn +from transformers import SqueezeBertConfig + + +class SqueezeBertEmbedding(nn.Module): + """Embedding layer for SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.word_embedding = nn.Embed( + self.config.vocab_size, self.config.embedding_size + ) + self.position_embedding = nn.Embed( + self.config.max_position_embeddings, + self.config.embedding_size, + ) + self.token_type_embedding = nn.Embed( + self.config.type_vocab_size, + self.config.embedding_size, + ) + + self.layernorm = nn.LayerNorm() + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__( + self, + input_ids: jax.Array, + token_type_ids: jax.Array = None, + position_ids: jax.Array = None, + deterministic: bool = False, + ) -> jax.Array: + if position_ids is None: + position_ids = jax.numpy.arange(input_ids.shape[1]) + if token_type_ids is None: + token_type_ids = jax.numpy.zeros_like(input_ids) + + word_embeddings = self.word_embedding(input_ids) + position_embeddings = self.position_embedding(position_ids) + token_type_embeddings = self.token_type_embedding(token_type_ids) + + embeddings = word_embeddings + position_embeddings + token_type_embeddings + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +class SqueezeBertSelfAttention(nn.Module): + """Self-attention layer for SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.query = nn.Conv( + features=self.config.hidden_size, + kernel_size=(1,), + feature_group_count=self.config.q_groups, + ) + self.key = nn.Conv( + features=self.config.hidden_size, + 
kernel_size=(1,), + feature_group_count=self.config.k_groups, + ) + self.value = nn.Conv( + features=self.config.hidden_size, + kernel_size=(1,), + feature_group_count=self.config.v_groups, + ) + self.output = nn.Conv( + features=self.config.hidden_size, + kernel_size=(1,), + feature_group_count=self.config.post_attention_groups, + ) + + self.attn_dropout = nn.Dropout(rate=self.config.attention_probs_dropout_prob) + self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.layernorm = nn.LayerNorm() + + def __call__( + self, + hidden_states: jax.Array, + attention_mask: jax.Array, + deterministic: bool = False, + ) -> jax.Array: + head_dim = self.config.hidden_size // self.config.num_attention_heads + query = self.query(hidden_states) + key = self.key(hidden_states) + value = self.value(hidden_states) + + query = einops.rearrange( + query, + "b s (H d) -> b s H d", # batch sequence Heads dim_head + H=self.config.num_attention_heads, + d=head_dim, + ) + key = einops.rearrange( + key, + "b s (H d) -> b s H d", + H=self.config.num_attention_heads, + d=head_dim, + ) + value = einops.rearrange( + value, + "b s (H d) -> b s H d", + H=self.config.num_attention_heads, + d=head_dim, + ) + + attention_scores = jnp.einsum("B s H d ,B S H d -> B H s S", query, key) + attention_scores = attention_scores / jnp.sqrt(head_dim) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = nn.activation.softmax(attention_scores, axis=-1) + attention_probs = self.attn_dropout( + attention_probs, deterministic=deterministic + ) + + context = jnp.einsum("B H s S, B S H d -> B s H d", attention_probs, value) + context = einops.rearrange(context, "b s H d -> b s (H d)") + + output = self.output(context) + output = self.resid_dropout(output, deterministic=deterministic) + output = hidden_states + output + output = self.layernorm(output) + return output + + +class SqueezeBertMLP(nn.Module): + """MLP layer for SqueezeBERT 
model.""" + + config: SqueezeBertConfig + + def setup(self): + self.w1 = nn.Conv( + features=self.config.intermediate_size, + kernel_size=(1,), + feature_group_count=self.config.intermediate_groups, + ) + if self.config.hidden_act == "gelu": + self.act = nn.gelu + else: + raise ValueError( + f"Activation function {self.config.hidden_act} not supported." + ) + self.w2 = nn.Conv( + features=self.config.hidden_size, + kernel_size=(1,), + feature_group_count=self.config.output_groups, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.layernorm = nn.LayerNorm() + + def __call__( + self, hidden_states: jax.Array, deterministic: bool = False + ) -> jax.Array: + x = self.w1(hidden_states) + x = self.act(x) + x = self.w2(x) + x = self.dropout(x, deterministic=deterministic) + output = hidden_states + x + output = self.layernorm(output) + return output + + +class SqueezeBertLayer(nn.Module): + """Layer for SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.attention = SqueezeBertSelfAttention(self.config) + self.mlp = SqueezeBertMLP(self.config) + + def __call__( + self, + hidden_states: jax.Array, + attention_mask: jax.Array, + deterministic: bool = False, + ) -> jax.Array: + attention_output = self.attention( + hidden_states, attention_mask, deterministic=deterministic + ) + output = self.mlp(attention_output, deterministic=deterministic) + return output + + +class SqueezeBertEncoder(nn.Module): + """Encoder for SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.layers = [ + SqueezeBertLayer(self.config) for _ in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states: jax.Array, + attention_mask: jax.Array, + deterministic: bool = False, + ) -> jax.Array: + for layer in self.layers: + hidden_states = layer( + hidden_states, attention_mask, deterministic=deterministic + ) + return hidden_states + + +class SqueezeBertPooler(nn.Module): + """Pooler layer for 
SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size) + self.activation = nn.tanh + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class SqueezeBertModel(nn.Module): + """SqueezeBERT model.""" + + config: SqueezeBertConfig + + def setup(self): + self.embeddings = SqueezeBertEmbedding(self.config) + self.encoder = SqueezeBertEncoder(self.config) + self.pooler = SqueezeBertPooler(self.config) + + def __call__( + self, + input_ids: jax.Array, + attention_mask: jax.Array, + token_type_ids: jax.Array = None, + position_ids: jax.Array = None, + *, + train: bool, + ) -> Tuple[jax.Array, jax.Array]: + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + embeddings = self.embeddings( + input_ids, token_type_ids, position_ids, deterministic=not train + ) + encoder_output = self.encoder( + embeddings, attention_mask, deterministic=not train + ) + pooled_output = self.pooler(encoder_output) + return encoder_output, pooled_output + + +class SqueezeBertForMaskedLM(nn.Module): + """SqueezeBERT model with masked language modeling head.""" + + config: SqueezeBertConfig + + def setup(self): + self.squeezebert = SqueezeBertModel(self.config) + self.transform_dense = nn.Dense(self.config.hidden_size) + + if self.config.hidden_act == "gelu": + self.transform_act = nn.gelu + else: + raise ValueError( + f"Activation function {self.config.hidden_act} not supported." 
+ ) + + self.transform_layernorm = nn.LayerNorm() + + self.decoder = nn.Dense(self.config.vocab_size) + # TODO(stefan): Figure out if SqueezeBERT uses tied weights for embeddings and output layer + # that is only relevant for training + + def __call__( + self, + input_ids: jax.Array, + attention_mask: jax.Array = None, + token_type_ids: jax.Array = None, + position_ids: jax.Array = None, + *, + train: bool, + ) -> jax.Array: + hidden_states, _ = self.squeezebert( + input_ids, attention_mask, token_type_ids, position_ids, train=train + ) + hidden_states = self.transform_dense(hidden_states) + hidden_states = self.transform_act(hidden_states) + hidden_states = self.transform_layernorm(hidden_states) + + prediction_scores = self.decoder(hidden_states) + return prediction_scores + + @staticmethod + def init_from_pytorch_statedict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + # Key substitutions for remapping huggingface checkpoints to this implementation + PATTERNS = [ + ("transformer.", "squeezebert."), + ("LayerNorm", "layernorm"), + ("layernorm.weight", "layernorm.scale"), + ("_embeddings.weight", "_embedding.embedding"), + ("encoder.layers.", "encoder.layers_"), + ("attention.query.weight", "attention.query.kernel"), + ("attention.key.weight", "attention.key.kernel"), + ("attention.value.weight", "attention.value.kernel"), + ("post_attention.conv1d.weight", "attention.output.kernel"), + ("post_attention.conv1d.bias", "attention.output.bias"), + ("post_attention.layernorm", "attention.layernorm"), + ("intermediate.conv1d.weight", "mlp.w1.kernel"), + ("intermediate.conv1d.bias", "mlp.w1.bias"), + ("output.conv1d.weight", "mlp.w2.kernel"), + ("output.conv1d.bias", "mlp.w2.bias"), + ("output.layernorm", "mlp.layernorm"), + ("pooler.dense.weight", "pooler.dense.kernel"), + ("cls.predictions.transform.dense.weight", "transform_dense.kernel"), + ("cls.predictions.transform.dense.bias", "transform_dense.bias"), + ("cls.predictions.transform.layernorm", 
"transform_layernorm"), + ("cls.predictions.decoder.weight", "decoder.kernel"), + ("cls.predictions.bias", "decoder.bias"), + ] + + def is_banned_key(key: str) -> bool: + return "seq_relationship" in key + + def rewrite_key(key: str) -> str: + for pattern in PATTERNS: + key = re.sub(pattern[0], pattern[1], key) + return key + + def process_value(k: str, v) -> jnp.ndarray: + if "kernel" in k: + if len(v.shape) == 2: + return jnp.transpose(v) + if len(v.shape) == 3: + return jnp.transpose(v, (2, 1, 0)) + return v + + for k, v in state_dict.items(): + # Inplace conversion might lower peak memory usage + state_dict[k] = jnp.array(v) + + state_dict = { + rewrite_key(k): v for k, v in state_dict.items() if not is_banned_key(k) + } + state_dict = {k: process_value(k, v) for k, v in state_dict.items()} + state_dict = flax.traverse_util.unflatten_dict(state_dict, sep=".") + return {"params": state_dict} diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py new file mode 100644 index 00000000..9a91ae74 --- /dev/null +++ b/tests/jax/models/squeezebert/test_squeezebert.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Dict, Sequence + +import jax +import pytest +import torch +from flax import linen as nn +from huggingface_hub import hf_hub_download +from infra import ModelTester, RunMode +from transformers import AutoTokenizer +from utils import compile_fail, record_model_test_properties + +from .model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM + +MODEL_PATH = "squeezebert/squeezebert-uncased" +MODEL_NAME = "squeezebert" + +# ----- Tester ----- + + +class SqueezeBertTester(ModelTester): + """Tester for SqueezeBERT model on a masked language modeling task""" + + # @override + def _get_model(self) -> nn.Module: + config = SqueezeBertConfig.from_pretrained(MODEL_PATH) + return 
SqueezeBertForMaskedLM(config) + + # @override + def _get_forward_method_name(self): + return "apply" + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + inputs = tokenizer("The [MASK] barked at me", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + model_file = hf_hub_download( + repo_id="squeezebert/squeezebert-uncased", filename="pytorch_model.bin" + ) + state_dict = torch.load(model_file, weights_only=True) + + params = self._model.init_from_pytorch_statedict(state_dict) + + return { + "variables": params, # JAX frameworks have a convention of passing weights as the first argument + "input_ids": self._get_input_activations(), + "train": False, + } + + # @override + def _get_static_argnames(self): + return ["train"] + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> SqueezeBertTester: + return SqueezeBertTester() + + +@pytest.fixture +def training_tester() -> SqueezeBertTester: + return SqueezeBertTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason=compile_fail("Failed to legalize operation 'ttir.convolution'") +) +def test_squeezebert_inference( + inference_tester: SqueezeBertTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_squeezebert_training( + training_tester: SqueezeBertTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/ops/test_abs.py b/tests/jax/ops/test_abs.py index d5c0827f..c7bd9c6c 100644 --- a/tests/jax/ops/test_abs.py +++ b/tests/jax/ops/test_abs.py @@ -2,16 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax 
import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_abs(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_abs(x_shape: tuple, record_tt_xla_property: Callable): def abs(x: jax.Array) -> jax.Array: return jnp.abs(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.abs", + "stablehlo.abs", + ) + # Test both negative and positive values. run_op_test_with_random_inputs(abs, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_add.py b/tests/jax/ops/test_add.py index ce880a7c..ccef3d0a 100644 --- a/tests/jax/ops/test_add.py +++ b/tests/jax/ops/test_add.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_add(x_shape: tuple, y_shape: tuple): +def test_add(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def add(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.add(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.add", + "stablehlo.add", + ) + run_op_test_with_random_inputs(add, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py index f431d061..61f2e24f 100644 --- a/tests/jax/ops/test_broadcast_in_dim.py +++ b/tests/jax/ops/test_broadcast_in_dim.py @@ -2,18 +2,27 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + +import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties 
-@pytest.mark.parametrize("input_shapes", [[(2, 1)]]) -@pytest.mark.skip( - "error: type of return operand 0 doesn't match function result type in " - "function @main" +@pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}") +@pytest.mark.xfail( + reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16" ) -def test_broadcast_in_dim(input_shapes): - def broadcast(a): +def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable): + def broadcast(a: jax.Array): return jnp.broadcast_to(a, (2, 4)) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.broadcast_to", + "stablehlo.broadcast_in_dim", + ) + run_op_test_with_random_inputs(broadcast, input_shapes) diff --git a/tests/jax/ops/test_cbrt.py b/tests/jax/ops/test_cbrt.py index b690034c..c90573d9 100644 --- a/tests/jax/ops/test_cbrt.py +++ b/tests/jax/ops/test_cbrt.py @@ -2,15 +2,24 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_cbrt(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_cbrt(x_shape: tuple, record_tt_xla_property: Callable): def cbrt(x: jax.Array) -> jax.Array: return jnp.cbrt(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.cbrt", + "stablehlo.cbrt", + ) + run_op_test_with_random_inputs(cbrt, [x_shape]) diff --git a/tests/jax/ops/test_compare.py b/tests/jax/ops/test_compare.py index 9d0d63e4..79e5cc1f 100644 --- a/tests/jax/ops/test_compare.py +++ b/tests/jax/ops/test_compare.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties # NOTE TTNN does not support 
boolean data type, so bfloat16 is used instead. # Hence the output of comparison operation is bfloat16. JAX can not perform any @@ -22,8 +23,6 @@ # TODO investigate why this decorator cannot be removed. See issue # https://github.com/tenstorrent/tt-xla/issues/156 -# TODO split this file into multiple files, one per op. - def convert_output_to_bfloat16(f: Callable): """Decorator to work around the mentioned issue.""" @@ -35,34 +34,100 @@ def wrapper(*args, **kwargs): return wrapper -@convert_output_to_bfloat16 -def equal(x: jax.Array, y: jax.Array) -> jax.Array: - return x == y +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], + ids=lambda val: f"{val}", +) +def test_compare_equal( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + @convert_output_to_bfloat16 + def equal(x: jax.Array, y: jax.Array) -> jax.Array: + return x == y + + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.equal", + "stablehlo.compare{EQ}", + ) + run_op_test_with_random_inputs(equal, [x_shape, y_shape]) -@convert_output_to_bfloat16 -def not_equal(x: jax.Array, y: jax.Array) -> jax.Array: - return x != y +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], + ids=lambda val: f"{val}", +) +def test_compare_not_equal( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + @convert_output_to_bfloat16 + def not_equal(x: jax.Array, y: jax.Array) -> jax.Array: + return x != y + + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.not_equal", + "stablehlo.compare{NE}", + ) -@convert_output_to_bfloat16 -def greater(x: jax.Array, y: jax.Array) -> jax.Array: - return x > y + run_op_test_with_random_inputs(not_equal, [x_shape, y_shape]) -@convert_output_to_bfloat16 -def greater_or_equal(x: jax.Array, y: jax.Array) -> jax.Array: - return x >= y +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + 
[(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], + ids=lambda val: f"{val}", +) +def test_compare_greater( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + @convert_output_to_bfloat16 + def greater(x: jax.Array, y: jax.Array) -> jax.Array: + return x > y + + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.greater", + "stablehlo.compare{GT}", + ) + + run_op_test_with_random_inputs(greater, [x_shape, y_shape]) -@convert_output_to_bfloat16 -def less(x: jax.Array, y: jax.Array) -> jax.Array: - return x < y +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], + ids=lambda val: f"{val}", +) +def test_compare_greater_equal( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + @convert_output_to_bfloat16 + def greater_equal(x: jax.Array, y: jax.Array) -> jax.Array: + return x >= y + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.greater_equal", + "stablehlo.compare{GE}", + ) -@convert_output_to_bfloat16 -def less_or_equal(x: jax.Array, y: jax.Array) -> jax.Array: - return x <= y + run_op_test_with_random_inputs(greater_equal, [x_shape, y_shape]) @pytest.mark.parametrize( @@ -71,11 +136,41 @@ def less_or_equal(x: jax.Array, y: jax.Array) -> jax.Array: [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_compare(x_shape: tuple, y_shape: tuple): - run_op_test_with_random_inputs(equal, [x_shape, y_shape]) - run_op_test_with_random_inputs(not_equal, [x_shape, y_shape]) - run_op_test_with_random_inputs(greater, [x_shape, y_shape]) - run_op_test_with_random_inputs(greater_or_equal, [x_shape, y_shape]) +def test_compare_less(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): + @convert_output_to_bfloat16 + def less(x: jax.Array, y: jax.Array) -> jax.Array: + return x < y + + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.less", + "stablehlo.compare{LT}", 
+ ) + run_op_test_with_random_inputs(less, [x_shape, y_shape]) - run_op_test_with_random_inputs(less_or_equal, [x_shape, y_shape]) + + +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], + ids=lambda val: f"{val}", +) +def test_compare_less_equal( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + @convert_output_to_bfloat16 + def less_equal(x: jax.Array, y: jax.Array) -> jax.Array: + return x <= y + + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.less_equal", + "stablehlo.compare{LE}", + ) + + run_op_test_with_random_inputs(less_equal, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_concatenate.py b/tests/jax/ops/test_concatenate.py index bf377677..a3deb8e1 100644 --- a/tests/jax/ops/test_concatenate.py +++ b/tests/jax/ops/test_concatenate.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -16,9 +19,18 @@ [(32, 32, 32), (32, 32, 32), 2], [(64, 64, 64, 64), (64, 64, 64, 64), 3], ], + ids=lambda val: f"{val}", ) -def test_concatenate(x_shape: tuple, y_shape: tuple, axis: int): +def test_concatenate( + x_shape: tuple, y_shape: tuple, axis: int, record_tt_xla_property: Callable +): def concat(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.concatenate([x, y], axis=axis) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.concatenate", + "stablehlo.concatenate", + ) + run_op_test_with_random_inputs(concat, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py index 02657e35..e7bb5f39 100644 --- a/tests/jax/ops/test_constant.py +++ b/tests/jax/ops/test_constant.py @@ -2,38 +2,54 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax.numpy as jnp 
import pytest from infra import run_op_test +from utils import record_op_test_properties -@pytest.mark.parametrize("shape", [(32, 32), (1, 1)]) -@pytest.mark.skip( - "error: type of return operand 0 doesn't match function result type in " - "function @main" -) -def test_constant_zeros(shape: tuple): +@pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}") +def test_constant_zeros(shape: tuple, record_tt_xla_property: Callable): def module_constant_zeros(): return jnp.zeros(shape) + record_op_test_properties( + record_tt_xla_property, + "Constant op", + "jax.numpy.zeros", + "stablehlo.constant", + ) + run_op_test(module_constant_zeros, []) -@pytest.mark.parametrize("shape", [(32, 32), (1, 1)]) -@pytest.mark.skip( - "error: type of return operand 0 doesn't match function result type in " - "function @main" -) -def test_constant_ones(shape: tuple): +@pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}") +def test_constant_ones(shape: tuple, record_tt_xla_property: Callable): def module_constant_ones(): return jnp.ones(shape) + record_op_test_properties( + record_tt_xla_property, + "Constant op", + "jax.numpy.ones", + "stablehlo.constant", + ) + run_op_test(module_constant_ones, []) -@pytest.mark.skip("Fails due to: error: failed to legalize operation 'ttir.constant'") -def test_constant_multi_value(): +@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'") +def test_constant_multi_value(record_tt_xla_property: Callable): def module_constant_multi(): return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) + record_op_test_properties( + record_tt_xla_property, + "Constant op", + "jax.numpy.array", + "stablehlo.constant", + ) + run_op_test(module_constant_multi, []) diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index 252e62de..99ba4385 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -2,12 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 +from 
typing import Callable + import jax import jax.lax as jlx import jax.numpy as jnp import pytest from infra import random_tensor, run_op_test from jax._src.typing import DTypeLike +from utils import record_unary_op_test_properties # TODO we need to parametrize with all supported dtypes. @@ -30,10 +33,21 @@ "float64", ], ) -def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike): +@pytest.mark.skip( + f"Skipped unconditionally due to many fails. There is ongoing work on rewriting these tests." +) +def test_convert( + from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable +): def convert(x: jax.Array) -> jax.Array: return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype)) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.lax.convert_element_type", + "stablehlo.convert", + ) + x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized. input = random_tensor(x_shape, dtype=from_dtype) diff --git a/tests/jax/ops/test_convolution.py b/tests/jax/ops/test_convolution.py index 6d348954..49a13aff 100644 --- a/tests/jax/ops/test_convolution.py +++ b/tests/jax/ops/test_convolution.py @@ -2,9 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import pytest from infra import ComparisonConfig, random_tensor, run_op_test +from utils import record_op_test_properties # TODO investigate why conv has such poor precision. 
@@ -25,9 +28,13 @@ def comparison_config() -> ComparisonConfig: ((1, 512, 256), (512, 512, 1)), ((1, 512, 512), (1024, 512, 1)), ], + ids=lambda val: f"{val}", ) def test_conv1d( - img_shape: tuple, kernel_shape: tuple, comparison_config: ComparisonConfig + img_shape: tuple, + kernel_shape: tuple, + comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, ): def conv1d(img, weights): return jax.lax.conv_general_dilated( @@ -42,6 +49,13 @@ def conv1d(img, weights): batch_group_count=1, ) + record_op_test_properties( + record_tt_xla_property, + "Convolution op", + "jax.lax.conv_general_dilated", + "stablehlo.convolution", + ) + img = random_tensor(img_shape, dtype="bfloat16") kernel = random_tensor(kernel_shape, dtype="bfloat16") @@ -78,10 +92,7 @@ def conv1d(img, weights): (1, 256, 256, 14, 14, 3, 3, 1, 1, 1), (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0), (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0), - pytest.param( # TODO This passed in old infra. Investigate. - *(1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), - marks=pytest.mark.skip(reason="Segmentation fault"), - ), + (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), @@ -93,6 +104,7 @@ def conv1d(img, weights): (1, 256, 256, 7, 7, 3, 3, 1, 1, 1), (1, 256, 64, 56, 56, 1, 1, 2, 2, 0), ], + ids=lambda val: f"{val}", ) def test_conv2d( batch_size: int, @@ -106,6 +118,7 @@ def test_conv2d( stride_w: int, padding: int, comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, ): def conv2d(img: jax.Array, kernel: jax.Array): return jax.lax.conv_general_dilated( @@ -116,6 +129,13 @@ def conv2d(img: jax.Array, kernel: jax.Array): dimension_numbers=("NHWC", "OIHW", "NHWC"), ) + record_op_test_properties( + record_tt_xla_property, + "Convolution op", + "jax.lax.conv_general_dilated", + "stablehlo.convolution", + ) + img_shape = (batch_size, input_height, input_width, input_channels) kernel_shape = 
(output_channels, input_channels, filter_height, filter_width) diff --git a/tests/jax/ops/test_divide.py b/tests/jax/ops/test_divide.py index a47eb56d..fe3173af 100644 --- a/tests/jax/ops/test_divide.py +++ b/tests/jax/ops/test_divide.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,14 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_divide(x_shape: tuple, y_shape: tuple): +def test_divide(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def divide(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.divide(x, y) + record_binary_op_test_properties( + record_tt_xla_property, "jax.numpy.divide", "stablehlo.divide" + ) + run_op_test_with_random_inputs(divide, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_dot_general.py b/tests/jax/ops/test_dot_general.py index b8f62bb1..681faac6 100644 --- a/tests/jax/ops/test_dot_general.py +++ b/tests/jax/ops/test_dot_general.py @@ -2,23 +2,78 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax -import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties + + +# Tests for dot_general op where vectors containing indices of contracting dimensions +# are of size 1 and are equal. In training models, besides cases that correspond to matmul, +# this is the most common one we have. 
+@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(1, 32), (1, 32)], + [(1, 32, 64), (1, 32, 32)], + [(2, 32, 64), (2, 32, 64)], + [(2, 16, 32, 64), (2, 16, 64, 32)], + ], + ids=lambda val: f"{val}", +) +def test_dot_general_common( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: + return jax.lax.dot_general(x, y, dimension_numbers=((1, 1), (0, 0))) + + record_binary_op_test_properties( + record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" + ) + + run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) +# Tests for dot_general op where this operation corresponds to regular matmul. @pytest.mark.parametrize( ["x_shape", "y_shape"], [ - [(32, 32), (32, 32)], - [(64, 64), (64, 64)], - [(32, 64), (64, 32)], - [(64, 32), (32, 64)], + [(1, 32, 64), (1, 64, 32)], + [(2, 32, 64), (2, 64, 64)], ], ) -def test_dot_general(x_shape: tuple, y_shape: tuple): +def test_dot_general_matmul( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: - return jnp.dot(x, y) + return jax.lax.dot_general(x, y, dimension_numbers=((2, 1), (0, 0))) + + record_binary_op_test_properties( + record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" + ) + + run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) + + +# Tests for dot_general op where vectors containing indices of +# contracting dimensions are of size greater than 1. 
+@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(1, 16, 16, 8), (1, 16, 8, 16)], + [(2, 8, 8, 16), (2, 8, 16, 8)], + ], +) +def test_dot_general_multiple_contract( + x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable +): + def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: + return jax.lax.dot_general(x, y, dimension_numbers=(((1, 3), (1, 2)), (0, 0))) + + record_binary_op_test_properties( + record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" + ) run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_exponential.py b/tests/jax/ops/test_exponential.py index 12b9c308..b23cddec 100644 --- a/tests/jax/ops/test_exponential.py +++ b/tests/jax/ops/test_exponential.py @@ -2,15 +2,24 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_exponential(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_exponential(x_shape: tuple, record_tt_xla_property: Callable): def exponential(x: jax.Array) -> jax.Array: return jnp.exp(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.exp", + "stablehlo.exponential", + ) + run_op_test_with_random_inputs(exponential, [x_shape]) diff --git a/tests/jax/ops/test_exponential_minus_one.py b/tests/jax/ops/test_exponential_minus_one.py index 50da4abd..059975b0 100644 --- a/tests/jax/ops/test_exponential_minus_one.py +++ b/tests/jax/ops/test_exponential_minus_one.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import ComparisonConfig, run_op_test_with_random_inputs +from utils import record_unary_op_test_properties @pytest.fixture @@ 
-20,11 +23,21 @@ def comparison_config() -> ComparisonConfig: return config -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_exponential_minus_one(x_shape: tuple, comparison_config: ComparisonConfig): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_exponential_minus_one( + x_shape: tuple, + comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, +): def expm1(x: jax.Array) -> jax.Array: return jnp.expm1(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.expm1", + "stablehlo.exponential_minus_one", + ) + run_op_test_with_random_inputs( expm1, [x_shape], comparison_config=comparison_config ) diff --git a/tests/jax/ops/test_log_plus_one.py b/tests/jax/ops/test_log_plus_one.py index 7f620bf0..c192d3c5 100644 --- a/tests/jax/ops/test_log_plus_one.py +++ b/tests/jax/ops/test_log_plus_one.py @@ -2,15 +2,24 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_log1p(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_log1p(x_shape: tuple, record_tt_xla_property: Callable): def log1p(x: jax.Array) -> jax.Array: return jnp.log1p(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.log1p", + "stablehlo.log_plus_one", + ) + run_op_test_with_random_inputs(log1p, [x_shape]) diff --git a/tests/jax/ops/test_maximum.py b/tests/jax/ops/test_maximum.py index 744da193..b96e9794 100644 --- a/tests/jax/ops/test_maximum.py +++ b/tests/jax/ops/test_maximum.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import 
record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_maximum(x_shape: tuple, y_shape: tuple): +def test_maximum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def maximum(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.maximum(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.maximum", + "stablehlo.maximum", + ) + run_op_test_with_random_inputs(maximum, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_minimum.py b/tests/jax/ops/test_minimum.py index e457c2b2..5fe27610 100644 --- a/tests/jax/ops/test_minimum.py +++ b/tests/jax/ops/test_minimum.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_minimum(x_shape: tuple, y_shape: tuple): +def test_minimum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def minimum(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.minimum(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.minimum", + "stablehlo.minimum", + ) + run_op_test_with_random_inputs(minimum, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_multiply.py b/tests/jax/ops/test_multiply.py index 3ec13a52..33d91406 100644 --- a/tests/jax/ops/test_multiply.py +++ b/tests/jax/ops/test_multiply.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda 
val: f"{val}", ) -def test_multiply(x_shape: tuple, y_shape: tuple): +def test_multiply(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def multiply(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.multiply(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.multiply", + "stablehlo.multiply", + ) + run_op_test_with_random_inputs(multiply, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_negate.py b/tests/jax/ops/test_negate.py index 49a05687..00c8489f 100644 --- a/tests/jax/ops/test_negate.py +++ b/tests/jax/ops/test_negate.py @@ -2,16 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_negate(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_negate(x_shape: tuple, record_tt_xla_property: Callable): def negate(x: jax.Array) -> jax.Array: return jnp.negative(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.negative", + "stablehlo.negate", + ) + # Trying both negative and positive values. run_op_test_with_random_inputs(negate, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_reduce.py b/tests/jax/ops/test_reduce.py index 1361557f..b66fc021 100644 --- a/tests/jax/ops/test_reduce.py +++ b/tests/jax/ops/test_reduce.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import ComparisonConfig, run_op_test_with_random_inputs +from utils import record_op_test_properties # TODO investigate why this doesn't pass with default comparison config. @@ -19,22 +22,44 @@ def comparison_config() -> ComparisonConfig: # TODO axis should be parametrized as well. 
-@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_reduce_sum(x_shape: tuple, comparison_config: ComparisonConfig): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_reduce_sum( + x_shape: tuple, + comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, +): def reduce_sum(x: jax.Array) -> jax.Array: return jnp.sum(x) + record_op_test_properties( + record_tt_xla_property, + "Reduce op", + "jax.numpy.sum", + "stablehlo.reduce{SUM}", + ) + run_op_test_with_random_inputs( reduce_sum, [x_shape], comparison_config=comparison_config ) # TODO axis should be parametrized as well. -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_reduce_max(x_shape: tuple, comparison_config: ComparisonConfig): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_reduce_max( + x_shape: tuple, + comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, +): def reduce_max(x: jax.Array) -> jax.Array: return jnp.max(x) + record_op_test_properties( + record_tt_xla_property, + "Reduce op", + "jax.numpy.max", + "stablehlo.reduce{MAX}", + ) + run_op_test_with_random_inputs( reduce_max, [x_shape], comparison_config=comparison_config ) diff --git a/tests/jax/ops/test_reduce_window.py b/tests/jax/ops/test_reduce_window.py index 77259f88..b8e40b8b 100644 --- a/tests/jax/ops/test_reduce_window.py +++ b/tests/jax/ops/test_reduce_window.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import flax import jax import pytest from infra import ComparisonConfig, random_tensor, run_op_test +from utils import record_op_test_properties @pytest.fixture @@ -48,6 +51,7 @@ def comparison_config() -> ComparisonConfig: (1, 128, 128, 64), (1, 128, 128, 128), ], + ids=lambda val: f"{val}", ) @pytest.mark.parametrize( ["window_shape", "strides", "padding"], @@ -63,12 +67,20 @@ def test_reduce_window_max( strides: tuple, padding: 
tuple, comparison_config: ComparisonConfig, + record_tt_xla_property: Callable, ): def maxpool2d(img: jax.Array): return flax.linen.max_pool( img, window_shape=window_shape, strides=strides, padding=padding ) + record_op_test_properties( + record_tt_xla_property, + "Maxpool op", + "flax.linen.max_pool", + "stablehlo.reduce_window{MAX}", + ) + # NOTE Some resnet convolutions seem to require bfloat16, ttnn throws in runtime # otherwise. On another note, MaxPool2d is also only supported for bfloat16 in ttnn, # so we have to run conv in bfloat16 for the time being. diff --git a/tests/jax/ops/test_remainder.py b/tests/jax/ops/test_remainder.py index 91c7402f..65ddfa13 100644 --- a/tests/jax/ops/test_remainder.py +++ b/tests/jax/ops/test_remainder.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.lax as jlx import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_remainder(x_shape: tuple, y_shape: tuple): +def test_remainder(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def remainder(x: jax.Array, y: jax.Array) -> jax.Array: return jlx.rem(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.lax.rem", + "stablehlo.remainder", + ) + run_op_test_with_random_inputs(remainder, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_reshape.py b/tests/jax/ops/test_reshape.py index 2f923e19..bd926b44 100644 --- a/tests/jax/ops/test_reshape.py +++ b/tests/jax/ops/test_reshape.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties @pytest.mark.parametrize( @@ -15,9 +18,16 @@ ((8, 32, 32), (1, 2, 4, 32, 32)), 
((8192, 128), (1, 256, 32, 128)), ], + ids=lambda val: f"{val}", ) -def test_reshape(in_shape: tuple, out_shape: tuple): +def test_reshape(in_shape: tuple, out_shape: tuple, record_tt_xla_property: Callable): def reshape(x: jax.Array): return jnp.reshape(x, out_shape) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.reshape", + "stablehlo.reshape", + ) + run_op_test_with_random_inputs(reshape, [in_shape]) diff --git a/tests/jax/ops/test_rsqrt.py b/tests/jax/ops/test_rsqrt.py index 6f2a14ed..49502e23 100644 --- a/tests/jax/ops/test_rsqrt.py +++ b/tests/jax/ops/test_rsqrt.py @@ -2,16 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.lax as jlx import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_rsqrt(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_rsqrt(x_shape: tuple, record_tt_xla_property: Callable): def rsqrt(x: jax.Array) -> jax.Array: return jlx.rsqrt(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.lax.rsqrt", + "stablehlo.rsqrt", + ) + # Input must be strictly positive because of sqrt(x). 
run_op_test_with_random_inputs(rsqrt, [x_shape], minval=0.1, maxval=10.0) diff --git a/tests/jax/ops/test_sign.py b/tests/jax/ops/test_sign.py index 2c0594ca..601f01d6 100644 --- a/tests/jax/ops/test_sign.py +++ b/tests/jax/ops/test_sign.py @@ -2,16 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_sign(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_sign(x_shape: tuple, record_tt_xla_property: Callable): def sign(x: jax.Array) -> jax.Array: return jnp.sign(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.sign", + "stablehlo.sign", + ) + # Trying both negative and positive values. run_op_test_with_random_inputs(sign, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_slice.py b/tests/jax/ops/test_slice.py index 3ede5e64..091fad6f 100644 --- a/tests/jax/ops/test_slice.py +++ b/tests/jax/ops/test_slice.py @@ -2,9 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties dim0_cases = [] for begin in jnp.arange(10).tolist(): @@ -29,9 +32,11 @@ # TODO investigate if this test can be rewritten to make it easier for understanding. 
@pytest.mark.parametrize( - ["begin", "end", "dim"], [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases] + ["begin", "end", "dim"], + [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases], + ids=lambda val: f"{val}", ) -def test_slice(begin, end, dim): +def test_slice(begin: int, end: int, dim: int, record_tt_xla_property: Callable): def module_slice(a): if dim == 0: return a[begin:end, :, :, :] @@ -42,6 +47,12 @@ def module_slice(a): else: return a[:, :, :, begin:end] + record_unary_op_test_properties( + record_tt_xla_property, + "jax.lax.slice", + "stablehlo.slice", + ) + shape = [10, 10, 10, 10] shape[dim] = 128 diff --git a/tests/jax/ops/test_sqrt.py b/tests/jax/ops/test_sqrt.py index 9b8a1a2d..b5f445ea 100644 --- a/tests/jax/ops/test_sqrt.py +++ b/tests/jax/ops/test_sqrt.py @@ -2,16 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_sqrt(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_sqrt(x_shape: tuple, record_tt_xla_property: Callable): def sqrt(x: jax.Array) -> jax.Array: return jnp.sqrt(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.sqrt", + "stablehlo.sqrt", + ) + # Input must be strictly positive because of sqrt(x). 
run_op_test_with_random_inputs(sqrt, [x_shape], minval=0.1, maxval=10.0) diff --git a/tests/jax/ops/test_subtract.py b/tests/jax/ops/test_subtract.py index 1c87919b..085740ce 100644 --- a/tests/jax/ops/test_subtract.py +++ b/tests/jax/ops/test_subtract.py @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_binary_op_test_properties @pytest.mark.parametrize( @@ -14,9 +17,16 @@ [(32, 32), (32, 32)], [(64, 64), (64, 64)], ], + ids=lambda val: f"{val}", ) -def test_subtract(x_shape: tuple, y_shape: tuple): +def test_subtract(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def subtract(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.subtract(x, y) + record_binary_op_test_properties( + record_tt_xla_property, + "jax.numpy.subtract", + "stablehlo.subtract", + ) + run_op_test_with_random_inputs(subtract, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_transpose.py b/tests/jax/ops/test_transpose.py index c783703e..bda87b72 100644 --- a/tests/jax/ops/test_transpose.py +++ b/tests/jax/ops/test_transpose.py @@ -2,15 +2,24 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs +from utils import record_unary_op_test_properties -@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_transpose(x_shape: tuple): +@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") +def test_transpose(x_shape: tuple, record_tt_xla_property: Callable): def transpose(x: jax.Array) -> jax.Array: return jnp.transpose(x) + record_unary_op_test_properties( + record_tt_xla_property, + "jax.numpy.transpose", + "stablehlo.transpose", + ) + run_op_test_with_random_inputs(transpose, [x_shape]) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 
00000000..e111f038 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +from conftest import RecordProperties + + +def compile_fail(reason: str) -> str: + return f"Compile failed: {reason}" + + +def runtime_fail(reason: str) -> str: + return f"Runtime failed: {reason}" + + +def record_unary_op_test_properties( + record_property: Callable, framework_op_name: str, op_name: str +): + record_property(RecordProperties.OP_KIND.value, "Unary op") + record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) + record_property(RecordProperties.OP_NAME.value, op_name) + + +def record_binary_op_test_properties( + record_property: Callable, framework_op_name: str, op_name: str +): + record_property(RecordProperties.OP_KIND.value, "Binary op") + record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) + record_property(RecordProperties.OP_NAME.value, op_name) + + +def record_op_test_properties( + record_property: Callable, op_kind: str, framework_op_name: str, op_name: str +): + record_property(RecordProperties.OP_KIND.value, op_kind) + record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) + record_property(RecordProperties.OP_NAME.value, op_name) + + +def record_model_test_properties(record_property: Callable, model_name: str): + record_property(RecordProperties.MODEL_NAME.value, model_name) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 59686167..ba6c2ae8 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "f6b5753630b6f2005f559d3ce223a72d94bf01d5") +set(TT_MLIR_VERSION "48ade5e60b133d9d5d0bf3c7c5c8ac4e85c649c3") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON")