From b10b497b1a75fc60c9e79a129c11df80b7386251 Mon Sep 17 00:00:00 2001 From: Wooseok Lee Date: Fri, 21 Feb 2025 10:56:18 -0600 Subject: [PATCH] Initial Shardy integration. (#2149) ### Ticket Shardy support (#2019) ### Problem description TT-xla with shardy partitioner requires shardy dialect conversion. ### What's changed - Added initial support of shardy dialect conversion - Parsed mesh info added to module attribute - Initial support of detecting pre-sharded input tensor from frontend - Allowed Conversion of variadic input/output pairs in stablehlo all_reduce op - Changed detecting mechanism of dim in stablehlo all_reduce op - Addressed pending code restructuring and return type change issues ### Checklist - [ o ] New/Existing tests provide coverage for changes --- .../StableHLOToTTIR/ShardingUtils.h | 80 ++++- .../Conversion/StableHLOToTTIR/ShardyToTTIR.h | 22 ++ include/ttmlir/Dialect/TT/IR/TTOpsTypes.td | 19 + include/ttmlir/Dialect/TT/Utils/Mesh.h | 66 ++++ include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 5 +- lib/Conversion/StableHLOToTTIR/CMakeLists.txt | 1 + .../StableHLOToTTIR/ShardingUtils.cpp | 297 ++++++++++------ .../StableHLOToTTIR/ShardyToTTIRPatterns.cpp | 225 ++++++++++++ .../StableHLOToTTIR/StableHLOToTTIRPass.cpp | 21 +- .../StableHLOToTTIRPatterns.cpp | 275 ++++++++------- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 8 +- lib/Dialect/TTIR/IR/TTIROps.cpp | 10 +- lib/Dialect/TTIR/Transforms/Utility.cpp | 8 +- lib/Dialect/TTNN/IR/TTNNOps.cpp | 7 +- .../lib/ttnn/operations/ccl/mesh_shard.cpp | 18 +- .../StableHLOToTTIR/ccl/all_reduce.mlir | 26 ++ .../{ccl_ops.mlir => ccl/ccl_ops_gspmd.mlir} | 66 +++- .../StableHLOToTTIR/ccl/ccl_ops_shardy.mlir | 332 ++++++++++++++++++ .../StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir | 192 ++++++++++ .../StableHLOToTTIR/ccl/e2e_dp_shardy.mlir | 163 +++++++++ .../StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir | 205 +++++++++++ .../StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir | 175 +++++++++ .../ccl/e2e_fsdp_tp_gspmd.mlir | 166 +++++++++ 
.../ccl/e2e_fsdp_tp_shardy.mlir | 136 +++++++ .../StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir | 157 +++++++++ .../StableHLOToTTIR/ccl/e2e_tp_shardy.mlir | 103 ++++++ test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir | 12 +- 27 files changed, 2513 insertions(+), 282 deletions(-) create mode 100644 include/ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h create mode 100644 include/ttmlir/Dialect/TT/Utils/Mesh.h create mode 100644 lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/all_reduce.mlir rename test/ttmlir/Conversion/StableHLOToTTIR/{ccl_ops.mlir => ccl/ccl_ops_gspmd.mlir} (99%) create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir diff --git a/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h b/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h index 665dbef03a..ddb3126dff 100644 --- a/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h +++ b/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h @@ -8,20 +8,82 @@ #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "mlir/IR/BuiltinOps.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" + namespace mlir::tt::sharding_utils { #if 
TTMLIR_ENABLE_STABLEHLO -struct MeshSharding { - mlir::tt::MeshShardDirection shardDirection; - mlir::tt::MeshShardType shardType; - bool lastTileDimReplicate; - llvm::SmallVector shardShape; - llvm::SmallVector shardDims; - llvm::SmallVector meshShape; + +class MeshSharding { +public: + MeshSharding() {}; + ~MeshSharding() {}; + + // Convert mhlo.sharding string to meshSharding. + llvm::Expected + convertGSPMDShardingToMeshSharding(StringRef shardingStr); + + // Convert sdy.sharding to meshSharding. + llvm::Expected + convertSdyShardingToMeshSharding(mlir::sdy::TensorShardingAttr sdySharding, + mlir::sdy::MeshAttr mesh, + mlir::tt::MeshShardDirection direction); + + // Force dummy sharding op by setting shard_type to manual. The mesh_shard op + // will be ignored at runtime by simply copying input tensor to output. + void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; } + + // Getter functions. + mlir::tt::MeshShardDirection getShardDirection() const { + return shardDirection; + } + mlir::tt::MeshShardType getShardType() const { return shardType; } + llvm::ArrayRef getShardShape() const { return shardShape; } + llvm::ArrayRef getShardDims() const { return shardDims; } + llvm::ArrayRef getMeshShape() const { return meshShape; } + +private: + // Parse GSPMD devices string and fill out MeshSharding info. + llvm::Expected parseGSPMDDevicesStr(StringRef devicesStr); + + // Based on current MeshSharding info, finalize sharding dimensions. + llvm::Expected determineGSPMDShardingDims(); + + // Set sharyType other than devices and reset values. + void setNonDevicesShardType(tt::MeshShardType targetShardType) { + assert(targetShardType != tt::MeshShardType::Devices); + shardType = targetShardType; + // Specific values are required to fill corresponding attributes in + // mesh_shard operation. 
+ shardShape = llvm::SmallVector{1}; + shardDims = llvm::SmallVector{-1}; + meshShape = llvm::SmallVector{-1}; + } + +private: + mlir::tt::MeshShardDirection shardDirection = + mlir::tt::MeshShardDirection::ShardToFull; + mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Manual; + llvm::SmallVector shardShape{-1}; + llvm::SmallVector shardDims{-1}; + llvm::SmallVector meshShape{-1}; + llvm::SmallVector deviceIds{-1}; + bool lastTileDimReplicate = false; }; -LogicalResult parseGSPMDShardingAttr(StringRef shardingStr, - MeshSharding &meshSharding); +// Sharding related string definitions from open-xla +// https://github.com/openxla/xla/blob/main/xla/service/spmd/shardy/constants.h + +inline constexpr llvm::StringRef kShardingCustomCallTargetName = "Sharding"; +inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName = + "SPMDFullToShardShape"; +inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName = + "SPMDShardToFullShape"; +inline constexpr llvm::StringRef kXlaShardingAttr = "mhlo.sharding"; + #endif } // namespace mlir::tt::sharding_utils diff --git a/include/ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h b/include/ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h new file mode 100644 index 0000000000..e449ee3b54 --- /dev/null +++ b/include/ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H +#define TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::tt { + +#ifdef TTMLIR_ENABLE_STABLEHLO + +void populateShardyToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter); + +#endif + +} // namespace mlir::tt + +#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td 
b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index c0a0d5947d..224e8e1547 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -474,6 +474,25 @@ def TT_MeshShardTypeAttr : EnumAttr let assemblyFormat = "`<` $value `>`"; } +def TT_MeshAttr : TT_Attr<"Mesh", "mesh", []> { + let summary = "Mesh reference attribute in TT dialect."; + let description = [{ + Describes a mesh config including name and shape. + }]; + let parameters = (ins "StringAttr":$name, + ArrayRefParameter<"int64_t">:$shape); + let assemblyFormat = "`<` $name `=` custom($shape) `>`"; +} + +def TT_MeshesAttr : TT_Attr<"Meshes", "meshes"> { + let summary = "TT system meshes attribute"; + let description = [{ + TT system meshes attribute that can include multiple mesh configs used for networks. + }]; + let parameters = (ins ArrayRefParameter<"MeshAttr">:$meshes); + let assemblyFormat = "`<` `[` $meshes `]` `>`"; +} + //===----------------------------------------------------------------------===// // TT type definitions //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TT/Utils/Mesh.h b/include/ttmlir/Dialect/TT/Utils/Mesh.h new file mode 100644 index 0000000000..c4148e66ef --- /dev/null +++ b/include/ttmlir/Dialect/TT/Utils/Mesh.h @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TT_UTILS_MESH_H +#define TTMLIR_DIALECT_TT_UTILS_MESH_H + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +#include "mlir/IR/BuiltinOps.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" + +namespace mlir::tt::utils { + +// Add a new mesh info to module attribute. 
+inline void addMeshToModuleAttribute(PatternRewriter &rewriter, + mlir::ModuleOp module, StringAttr meshName, + llvm::ArrayRef meshShape) { + MLIRContext *context = rewriter.getContext(); + llvm::SmallVector meshes; + if (auto meshesAttr = + module->getAttrOfType(tt::MeshesAttr::name)) { + meshes = llvm::SmallVector(meshesAttr.getMeshes()); + } + // Avoid adding multiple meshes with the same name and shape as GSPMD may try + // to add the same meshes. + if (llvm::all_of(meshes, + [&](tt::MeshAttr m) { return m.getName() != meshName; })) { + meshes.push_back(mlir::tt::MeshAttr::get(context, meshName, meshShape)); + rewriter.modifyOpInPlace(module, [&]() { + module->setAttr(tt::MeshesAttr::name, + tt::MeshesAttr::get(context, meshes)); + }); + } +} + +// Determine hardware mesh config for DeviceAttr. +// If none exists, the empty meshShape leads to single device config. +// If either option.meshShape or meshes exists, use one of them. +// If both exist, compare mesh and throw error if they are different. +inline llvm::Expected> +determineMeshShape(mlir::ModuleOp module, llvm::ArrayRef meshShape) { + if (auto meshesAttr = + module->getAttrOfType(tt::MeshesAttr::name)) { + llvm::ArrayRef meshAttr = meshesAttr.getMeshes(); + if (meshAttr.empty()) { + return llvm::SmallVector(meshShape); + } + // For now, use the first meshShape. + llvm::ArrayRef meshFromMeshes = meshAttr[0].getShape(); + // If both meshes exist, they should be identical. Otherwise, throw error. 
+ if (!meshShape.empty() && !llvm::equal(meshShape, meshFromMeshes)) { + return llvm::createStringError( + std::errc::invalid_argument, + "Option.meshShape and mesh info from graph should be identical."); + } + return llvm::SmallVector(meshFromMeshes); + } + return llvm::SmallVector(meshShape); +} + +} // namespace mlir::tt::utils + +#endif // TTMLIR_DIALECT_TT_UTILS_MESH_H diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 46368ffc74..aa919829c2 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -14,6 +14,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/OpBase.td" @@ -1922,7 +1923,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> { AllReduce op. 
}]; - let arguments = (ins Variadic:$inputs, + let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, I64ElementsAttr:$replica_groups, SI32Attr:$dim, @@ -1931,7 +1932,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> { TT_ReduceTypeAttr:$reduce_type ); - let results = (outs Variadic:$results); + let results = (outs AnyRankedTensor:$result); let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } diff --git a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt index 0a833684a3..ce154ccd1b 100644 --- a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt +++ b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(TTMLIRStableHLOToTTIR ShardingUtils.cpp StableHLOToTTIRPass.cpp StableHLOToTTIRPatterns.cpp + ShardyToTTIRPatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR diff --git a/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp b/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp index abfd991252..90b5273ef5 100644 --- a/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp +++ b/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp @@ -16,48 +16,86 @@ namespace mlir { namespace tt { namespace sharding_utils { +// Parse GSPMD devices string and fill out MeshSharding info. +llvm::Expected MeshSharding::parseGSPMDDevicesStr(StringRef devicesStr) { + // This function extract dimensions from targetDimsStr "[x,y,z]" and saves it + // to targetDims. 
+ auto parseDimsFromDimensionStr = + [](StringRef targetDimsStr, SmallVector &targetDims) -> bool { + if (!targetDimsStr.consume_front("[") || !targetDimsStr.consume_back("]")) { + return false; + } + SmallVector dimsStr; + targetDimsStr.split(dimsStr, ","); + targetDims.clear(); + for (auto dim : dimsStr) { + int64_t d; + if (dim.getAsInteger(10, d)) { + return false; + } + targetDims.push_back(d); + } + return true; + }; + + // devicesStr is generated by splitting whole string using space " ". Thus, it + // is not supposed to include any trailing space. e.g., "[4,2,1]<=[2,4]T(1,0)" + auto [axesStr, restStr] = devicesStr.split("<="); + // Parse devices string before "<=" e.g., [4,2,1]. + if (!parseDimsFromDimensionStr(axesStr, shardShape)) { + return llvm::createStringError("Fail to parse GSPMD devices axes string: " + + axesStr); + } + // Parse devices string after "<=" e.g., [8] or [2,4]T(1,0). + auto [reshapeStr, unused] = restStr.split("T"); + // Parse reshape[0] string e.g., [8] or [2,4]. + if (!parseDimsFromDimensionStr(reshapeStr, meshShape)) { + return llvm::createStringError( + "Fail to parse GSPMD devices reshape string: " + reshapeStr); + } + return true; +} + // Based on current MeshSharding info, finalize sharding dimensions. -static LogicalResult -determineGSPMDShardingDims(MeshSharding &meshSharding, - const bool lastTileDimReplicate) { +llvm::Expected MeshSharding::determineGSPMDShardingDims() { // This code is based on following assumption. // 1. Hardware mesh is two dimenion such as 2x4, 1x2, ... // 2. Hardware mesh only supports either line or mesh config // e.g., t3k 1x8 or 2x4 - SmallVector shardShape = meshSharding.shardShape; - if (meshSharding.lastTileDimReplicate) { - meshSharding.shardShape.pop_back(); + SmallVector orgShardShape = shardShape; + if (lastTileDimReplicate) { + shardShape.pop_back(); } // Determine obvious properties first. 
- bool reverseOrder = meshSharding.meshShape.size() != 1; + bool reverseOrder = meshShape.size() != 1; // totalDevices is the total number of multi-chips such as 8 for t3k. Thus, no // overflow is expected with int64_t. - int64_t totalDevices = std::accumulate( - meshSharding.meshShape.begin(), meshSharding.meshShape.end(), int64_t{1}, - std::multiplies()); + int64_t totalDevices = + std::accumulate(meshShape.begin(), meshShape.end(), int64_t{1}, + std::multiplies()); // Detect line device config (1xN). bool isLineDeviceConfig = - llvm::any_of(shardShape, [&](int64_t s) { return s == totalDevices; }); + llvm::any_of(orgShardShape, [&](int64_t s) { return s == totalDevices; }); // Detect hardware mesh. For reverse order sharding, meshShape already // includes hardware mesh. For non reverse order case, extract hardware mesh // by traversing from front to back and picking none-zero values. if (!reverseOrder) { if (isLineDeviceConfig) { // Device with line config must be 1xN, not Nx1. - meshSharding.meshShape = {1, meshSharding.meshShape[0]}; + meshShape = {1, meshShape[0]}; } else { - // e.g., shardShape [1,2,4] or [2,1,4] leads to [2,4] - SmallVector meshShape(shardShape.size(), 0); - auto *it = llvm::copy_if(shardShape, meshShape.begin(), - [](int64_t s) { return !(s == int64_t{1}); }); - meshShape.resize(std::distance(meshShape.begin(), it)); - meshSharding.meshShape = meshShape; + meshShape.clear(); + // e.g., orgShardShape [1,2,4] or [2,1,4] leads to [2,4] + llvm::copy_if(orgShardShape, std::back_inserter(meshShape), + [](int64_t s) { return s != int64_t{1}; }); } } - if (meshSharding.meshShape.size() != 2) { + if (meshShape.size() != 2) { // Currently, we are only supporting 2d hardware mesh config. - return failure(); + return llvm::createStringError( + "Only support 2d hardware mesh config. mesh.size()=%d", + meshShape.size()); } // Determine shardDims based on the shardShape and meshShape. 
@@ -66,137 +104,178 @@ determineGSPMDShardingDims(MeshSharding &meshSharding, // on the sharding intention. // For example, if shardShape is [1,2,1,4], shard_dims is supposed to be [1, // 3] or if shardShape is [1,4,1,2], then shard_dims should be [3, 1]. - meshSharding.shardDims.assign(meshSharding.meshShape.size(), -1); + shardDims.assign(meshShape.size(), -1); // Skip the first 1 of 1xN hardware. uint64_t shardingCnt = isLineDeviceConfig; - for (uint64_t i = 0; i < meshSharding.shardShape.size(); ++i) { + for (uint64_t i = 0; i < shardShape.size(); ++i) { // Check sharding dimension only. - if (meshSharding.shardShape[i] != 1) { - auto shardDimIdx = (reverseOrder) - ? (meshSharding.meshShape.size() - 1 - shardingCnt) - : shardingCnt; - if (meshSharding.shardShape[i] != meshSharding.meshShape[shardDimIdx]) { - // shardShape[i] and meshShape[shardDimIdx] is supposed to be identical. - return failure(); + if (shardShape[i] != 1) { + auto shardDimIdx = + (reverseOrder) ? (meshShape.size() - 1 - shardingCnt) : shardingCnt; + // Positive shardShape[i] and meshShape[shardDimIdx] is supposed to be + // identical. + if (shardShape[i] > 0 && shardShape[i] != meshShape[shardDimIdx]) { + return llvm::createStringError( + "Fail to determine shardDims. shardShape[%d] (%d) != meshShape[%d] " + "(%d)", + i, shardShape[i], shardDimIdx, meshShape[shardDimIdx]); } - meshSharding.shardDims[shardDimIdx] = i; + shardDims[shardDimIdx] = i; shardingCnt++; } } - return success(); -} - -// Parse GSPMD devices string and fill out MeshSharding info. -static LogicalResult parseGSPMDDevicesStr(const StringRef devicesStr, - MeshSharding &meshSharding) { - // This function extract dimensions from targetDimsStr "[x,y,z]" and saves it - // to targetDims. 
- auto parseDimsFromDimensionStr = - [](StringRef targetDimsStr, - SmallVector &targetDims) -> LogicalResult { - if (!targetDimsStr.consume_front("[") || !targetDimsStr.consume_back("]")) { - return failure(); - } - SmallVector dimsStr; - targetDimsStr.split(dimsStr, ","); - for (auto dim : dimsStr) { - int64_t d; - if (dim.getAsInteger(10, d)) { - return failure(); - } - targetDims.push_back(d); - } - return success(); - }; - - // devicesStr is generated by splitting whole string using space " ". Thus, it - // is not supposed to include any trailing space. e.g., "[4,2,1]<=[2,4]T(1,0)" - auto [axesStr, restStr] = devicesStr.split("<="); - // Parse devices string before "<=" e.g., [4,2,1]. - if (failed(parseDimsFromDimensionStr(axesStr, meshSharding.shardShape))) { - return failure(); - } - // Parse devices string after "<=" e.g., [8] or [2,4]T(1,0). - SmallVector reshapeStr; - restStr.split(reshapeStr, "T"); - // Parse reshape[0] string e.g., [8] or [2,4]. - if (failed( - parseDimsFromDimensionStr(reshapeStr[0], meshSharding.meshShape))) { - return failure(); - } - return success(); + return true; } // OpenXLA has its own lexer, but we will use simple string-based parser here. // This parsing is mainly based on "Sharding Attribute" section in // https://github.com/sdasgup3/stablehlo/blob/80082431d1af0933e6202ecc8a6f8801e039235b/docs/spec.md#sharding-attribute -LogicalResult parseGSPMDShardingAttr(StringRef shardingStr, - MeshSharding &meshSharding) { - MeshShardType shardType = mlir::tt::MeshShardType::Manual; - bool lastTileDimReplicate = false; +llvm::Expected +MeshSharding::convertGSPMDShardingToMeshSharding(StringRef shardingStr) { + shardType = mlir::tt::MeshShardType::Manual; + lastTileDimReplicate = false; - // Parse sting and tokenize. + // Parse string and tokenize. 
if (!shardingStr.consume_front("{") || !shardingStr.consume_back("}")) { - return failure(); + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); } SmallVector shardingStrTokens; shardingStr.split(shardingStrTokens, " "); - // Parse tokens. + // Parse string tokens. for (auto str : shardingStrTokens) { if (str.contains("manual")) { - assert(shardType == mlir::tt::MeshShardType::Manual && - "Fail to parse sharding info."); // manual: already sharded, so no action is needed - meshSharding.shardShape.push_back(1); + if (shardType != tt::MeshShardType::Manual) { + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); + } + setNonDevicesShardType(tt::MeshShardType::Manual); } else if (str.contains("replicated")) { - assert(shardType == mlir::tt::MeshShardType::Manual && - "Fail to parse sharding info."); // replicated: all devices have whole data - shardType = mlir::tt::MeshShardType::Replicate; - meshSharding.shardShape.push_back(1); + if (shardType != tt::MeshShardType::Manual) { + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); + } + setNonDevicesShardType(tt::MeshShardType::Replicate); } else if (str.contains("maximal")) { - assert(shardType == mlir::tt::MeshShardType::Manual && - "Fail to parse sharding info."); // maximal: one device has whole data - shardType = mlir::tt::MeshShardType::Maximal; - meshSharding.shardShape.push_back(1); + if (shardType != tt::MeshShardType::Manual) { + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); + } + setNonDevicesShardType(tt::MeshShardType::Maximal); } else if (str.consume_front("device=")) { // maximal should followed by "device" to put data on - assert(shardType == mlir::tt::MeshShardType::Maximal && - "Fail to parse sharding info."); + if (shardType != tt::MeshShardType::Maximal) { + return llvm::createStringError(std::errc::invalid_argument, + 
"Fail to parse GSPMD sharding."); + } int64_t d; if (str.getAsInteger(10, d)) { - return failure(); + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); } - meshSharding.shardShape.push_back(d); + deviceIds.push_back(d); } else if (str.consume_front("devices=")) { // other: "devices" detail sharding plan - assert(shardType == mlir::tt::MeshShardType::Manual && - "Fail to parse sharding info."); - shardType = mlir::tt::MeshShardType::Devices; - if (failed(parseGSPMDDevicesStr(str, meshSharding))) { - return failure(); + if (shardType != tt::MeshShardType::Manual) { + return llvm::createStringError(std::errc::invalid_argument, + "Fail to parse GSPMD sharding."); + } + shardType = tt::MeshShardType::Devices; + auto error = parseGSPMDDevicesStr(str); + if (auto e = error.takeError()) { + return e; } } else if (str.contains("last_tile_dim_replicate")) { - assert(shardType == mlir::tt::MeshShardType::Devices && - "Fail to parse sharding info."); // other: replicate last tile dim + if (shardType != tt::MeshShardType::Devices) { + return llvm::createStringError( + std::errc::invalid_argument, + "Fail to parse GSPMD sharding in last_tile_dim_replicate."); + } lastTileDimReplicate = true; + } else { + return llvm::createStringError("Unknown GSPMD sharding: " + str); + } + } + + // Determine shard dims for devices. + if (shardType == tt::MeshShardType::Devices) { + auto error = determineGSPMDShardingDims(); + if (auto e = error.takeError()) { + return e; + } + } + + return true; +} + +// Convert sdy.sharding to meshSharding based on sdy::MeshAttr. 
+llvm::Expected MeshSharding::convertSdyShardingToMeshSharding( + sdy::TensorShardingAttr sdySharding, sdy::MeshAttr meshAttr, + tt::MeshShardDirection direction) { + + shardDirection = direction; + + if (meshAttr.getAxes().empty()) { + if (meshAttr.getDeviceIds().empty()) { + // replicated + setNonDevicesShardType(mlir::tt::MeshShardType::Replicate); + } else { + // maximal + setNonDevicesShardType(mlir::tt::MeshShardType::Maximal); + deviceIds = llvm::SmallVector(meshAttr.getDeviceIds()); + } + return true; + } + + shardType = tt::MeshShardType::Devices; + shardShape.assign(sdySharding.getRank(), 1); + shardDims.assign(meshAttr.getAxes().size(), -1); + + llvm::SmallDenseMap<::llvm::StringRef, int64_t> axisPosition; + for (auto [idx, meshAxisAttr] : llvm::enumerate(meshAttr.getAxes())) { + axisPosition[meshAxisAttr.getName()] = idx; + meshShape.push_back(meshAxisAttr.getSize()); + } + + if (!sdySharding.isFullyClosed()) { + return llvm::createStringError( + "Sharding with open dimension is currently not supported."); + } + + // Iterate each dimSharding in TensorShardingAttr + for (auto [dimIdx, dimSharding] : + llvm::enumerate(sdySharding.getDimShardings())) { + for (auto [axisIdx, axes] : llvm::enumerate(dimSharding.getAxes())) { + // Check if there is any subaxis sharding + if (auto subAxis = axes.getSubAxisInfo()) { + return llvm::createStringError( + "Sharding with subaxis partitioning is currently not supported."); + } + shardShape[dimIdx] *= axes.getSize(meshAttr); + // Sharding makes sense when it is higher than 1. + if (axes.getSize(meshAttr) > 1) { + shardDims[axisPosition[axes.getName()]] = dimIdx; + } } } - meshSharding.shardType = shardType; - meshSharding.lastTileDimReplicate = lastTileDimReplicate; - // Parse devices - if (meshSharding.shardType == mlir::tt::MeshShardType::Devices) { - return determineGSPMDShardingDims(meshSharding, lastTileDimReplicate); + // totalPartition is the total number of multi-chips such as 8 for t3k. 
Thus, + // no overflow is expected with int64_t. + int64_t totalPartition = + std::accumulate(shardShape.begin(), shardShape.end(), int64_t{1}, + std::multiplies()); + // No partition indicates replicate to all devices. + if (totalPartition == 1) { + setNonDevicesShardType(mlir::tt::MeshShardType::Replicate); } - meshSharding.shardDims.push_back(-1); - meshSharding.meshShape.push_back(-1); - return success(); + return true; } } // namespace sharding_utils diff --git a/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp new file mode 100644 index 0000000000..3ca86d29c2 --- /dev/null +++ b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp @@ -0,0 +1,225 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h" + +#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h" +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TT/Utils/Mesh.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "ttmlir/Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "shardy/dialect/sdy/ir/constants.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/ErrorHandling.h" + +namespace { + +class ShardyToTTIRManualComputationOpConversionPattern + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern< + 
mlir::sdy::ManualComputationOp>::OpConversionPattern; + +public: + llvm::LogicalResult + matchAndRewrite(mlir::sdy::ManualComputationOp srcOp, + mlir::sdy::ManualComputationOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto module = srcOp->getParentOfType(); + if (!module) { + llvm_unreachable("mlir::sdy::ManualComputationOp requires module as one " + "of parent ops."); + } + mlir::SymbolTable symbolTable(module); + mlir::Location loc = srcOp.getLoc(); + + auto shardings = llvm::concat( + srcOp.getInShardings().getShardings(), + srcOp.getOutShardings().getShardings()); + if (shardings.begin() == shardings.end()) { + // Inline the body with no in/out shardings. + rewriter.eraseOp(getBodyTerminator(srcOp)); + rewriter.inlineBlockBefore(&srcOp.getBody().front(), srcOp, + srcOp.getOperands()); + rewriter.eraseOp(srcOp); + return llvm::success(); + } + + // ManualComputationOp include one mesh for all in/out shardings, so we can + // pick up first sharding and get mesh info. + mlir::sdy::TensorShardingAttr firstSharding = *shardings.begin(); + mlir::sdy::MeshAttr targetMesh = firstSharding.getMesh(symbolTable); + if (!targetMesh) { + llvm_unreachable( + "mlir::sdy::TensorShardingAttr requires mesh definition."); + } + + // Currently, sharding operation on device memory is not supported, so + // remove any sharding in body of manual computation op and compute + // with replicated tensor. 
+ srcOp.getBody().front().walk( + [&](mlir::Operation *opInBody) { + if (mlir::isa(opInBody)) { + return mlir::WalkResult::skip(); + } + mlir::sdy::TensorShardingPerValueAttr shardingPerValue = + opInBody->getAttrOfType( + mlir::sdy::kShardingAttr); + if (!shardingPerValue) { + return mlir::WalkResult::advance(); + } + rewriter.modifyOpInPlace(opInBody, [&]() { + opInBody->removeAttr(mlir::sdy::kShardingAttr); + }); + return mlir::WalkResult::advance(); + }); + + auto funcOp = srcOp->getParentOfType(); + + // Add mesh_shard (FullToShardShape) for inputs. + llvm::SmallVector fullToShardResults; + for (auto [globalOperand, argSharding, localArgType] : llvm::zip_equal( + srcOp.getOperands(), srcOp.getInShardings().getShardings(), + srcOp.getBody().getArgumentTypes())) { + + mlir::tt::sharding_utils::MeshSharding meshSharding; + auto error = meshSharding.convertSdyShardingToMeshSharding( + argSharding, targetMesh, mlir::tt::MeshShardDirection::FullToShard); + if (auto e = error.takeError()) { + return rewriter.notifyMatchFailure(srcOp, llvm::toString(std::move(e))); + } + + // JAX automatic sharding pre-shards input tensors and provides multiple + // buffers. Thus, mesh sharding operations should not shard the tensors + // twice if they are function arguments and pre-sharded by frontend. + // Runtime ignores mesh sharding operation if it is set as manual + // sharding. 
+ if (auto blockArg = mlir::dyn_cast(globalOperand)) { + auto argNum = blockArg.getArgNumber(); + if (mlir::sdy::TensorShardingAttr argShardingAttr = + funcOp.getArgAttrOfType( + argNum, mlir::sdy::kShardingAttr)) { + if (argShardingAttr == argSharding) { + meshSharding.setDummyShardingOp(); + rewriter.modifyOpInPlace(funcOp, [&]() { + funcOp.removeArgAttr(argNum, mlir::sdy::kShardingAttr); + }); + } else { + llvm_unreachable("Manual computation op and function argument " + "shardings are different."); + } + } + } + + auto outputType = mlir::cast( + getTypeConverter()->convertType(localArgType)); + + auto meshShardOp = + ttmlir::utils::createDPSOp( + rewriter, loc, outputType, globalOperand, + meshSharding.getShardType(), meshSharding.getShardDirection(), + meshSharding.getShardShape(), meshSharding.getShardDims()); + + fullToShardResults.push_back(meshShardOp.getResult()); + } + + // Add mesh_shard (ShardToFullShape) for outputs. + rewriter.setInsertionPointAfter(srcOp); + mlir::Operation *sdyReturn = getBodyTerminator(srcOp); + for (auto [returnOperand, outSharding, opResult] : llvm::zip_equal( + sdyReturn->getOpOperands(), srcOp.getOutShardings().getShardings(), + srcOp.getResults())) { + + mlir::tt::sharding_utils::MeshSharding meshSharding; + auto error = meshSharding.convertSdyShardingToMeshSharding( + outSharding, targetMesh, mlir::tt::MeshShardDirection::ShardToFull); + if (auto e = error.takeError()) { + return rewriter.notifyMatchFailure(srcOp, llvm::toString(std::move(e))); + } + + auto inputOperand = returnOperand.get(); + auto inputType = mlir::cast( + getTypeConverter()->convertType(inputOperand.getType())); + if (inputType != inputOperand.getType()) { + inputOperand.setType(inputType); + } + + auto outputType = mlir::cast( + getTypeConverter()->convertType(opResult.getType())); + + auto meshShardOp = + ttmlir::utils::createDPSOp( + rewriter, loc, outputType, inputOperand, + meshSharding.getShardType(), meshSharding.getShardDirection(), + 
meshSharding.getShardShape(), meshSharding.getShardDims()); + + rewriter.replaceAllUsesWith(opResult, meshShardOp.getResult()); + } + + // Inline inner block ops. + rewriter.inlineBlockBefore(&srcOp.getBody().front(), srcOp, + fullToShardResults); + rewriter.eraseOp(sdyReturn); + rewriter.eraseOp(srcOp); + + return llvm::success(); + } +}; + +class ShardyToTTIRMeshOpConversionPattern + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + +public: + llvm::LogicalResult + matchAndRewrite(mlir::sdy::MeshOp srcOp, mlir::sdy::MeshOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // The goal of this conversion is to extract hardware mesh information from + // sdy.mesh op and store it as module attribute. + auto module = srcOp->getParentOfType(); + if (!module) { + llvm_unreachable( + "mlir::sdy::MeshOp requires module as one of parent ops."); + } + + mlir::StringAttr meshName = srcOp.getSymNameAttr(); + llvm::SmallVector meshShape; + mlir::sdy::MeshAttr sdyMesh = srcOp.getMesh(); + for (auto meshAxisAttr : sdyMesh.getAxes()) { + meshShape.push_back(meshAxisAttr.getSize()); + } + mlir::tt::utils::addMeshToModuleAttribute(rewriter, module, meshName, + meshShape); + rewriter.eraseOp(srcOp); + return llvm::success(); + } +}; + +} // namespace + +namespace mlir::tt { + +void populateShardyToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, + ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::tt diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp index 996abd5f76..7d179f2cd4 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp @@ -20,15 +20,31 @@ #include #include "ttmlir/Conversion/StableHLOToTTIR/EmptyOpTypeConversion.h" +#include 
"ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "llvm/ADT/ArrayRef.h" + using namespace mlir; using namespace mlir::tt; namespace mlir::tt::ttir { #define GEN_PASS_DEF_CONVERTSTABLEHLOTOTTIR +#define GEN_PASS_DEF_CONVERTSHARDYTOTTIR #include "ttmlir/Conversion/Passes.h.inc" } // namespace mlir::tt::ttir @@ -94,8 +110,8 @@ struct ConvertStableHLOToTTIRPass mlir::ConversionTarget target(getContext()); target.addIllegalDialect(); + target.addIllegalDialect(); - target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); target.addLegalOp(); @@ -127,6 +143,8 @@ struct ConvertStableHLOToTTIRPass [&](tensor::EmptyOp op) { return typeConverter.isLegal(op); }); populateStableHLOToTTIRPatterns(&getContext(), patterns, typeConverter); + populateShardyToTTIRPatterns(&getContext(), patterns, typeConverter); + // Apply conversion. 
if (failed( applyFullConversion(getOperation(), target, std::move(patterns)))) { @@ -135,7 +153,6 @@ struct ConvertStableHLOToTTIRPass } } }; - } // namespace namespace mlir::tt { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 144d2564cc..e3506d4f49 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -12,17 +12,20 @@ #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" #include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h" #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TT/Utils/Mesh.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/IR/TTIRGenericRegionOps.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Utils.h" #include +#include #include #include #include @@ -1526,67 +1529,24 @@ class StableHLOToTTIRAllReduceOpConversionPattern return err; } - // Create the output tensor type based on inputs - auto outputType = mlir::cast( - getTypeConverter()->convertType(srcOp.getResult(0).getType())); - - // Create an empty output tensor with the computed shape - tensor::EmptyOp outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - - SmallVector ttirTypes; - if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), - ttirTypes))) { - return failure(); - } - - auto ttirOperands = srcOp.getOperandsMutable(); - ttirOperands.append(ValueRange(outputTensor)); - - SmallVector srcAttrs = to_vector(srcOp->getAttrs()); - SmallVector ttirAttrs; - for (auto srcAttr : srcAttrs) { - StringAttr srcName = srcAttr.getName(); - if (srcName == "channel_handle") { - auto channelHandle = 
srcOp.getChannelHandle().value(); - - // channelType is supposed to be DEVICE_TO_DEVICE for CCL ops. - // Currently, we ensure if it is DEVICE_TO_DEVICE communication. - // Consider preserving this information in the future if the attribute - // is non-DEVICE_TO_DEVICE values. - auto channelType = static_cast(channelHandle.getType()); - if (channelType != kChannelTypeDeviceToDevice) { - return failure(); - } - - IntegerAttr channelHandleAttr = rewriter.getSI32IntegerAttr( - static_cast(channelHandle.getHandle())); - if (!channelHandleAttr) { - return failure(); - } - ttirAttrs.push_back({srcName, channelHandleAttr}); - } else { - ttirAttrs.push_back(srcAttr); + IntegerAttr channelHandleAttr; + if (auto srcChannelHandleAttr = adaptor.getChannelHandleAttr()) { + // channelType is supposed to be DEVICE_TO_DEVICE or Invalid for CCL ops. + // Currently, we ensure if it is DEVICE_TO_DEVICE communication. + // Consider preserving this information in the future if the attribute + // is non-DEVICE_TO_DEVICE values. 
+ auto channelType = static_cast(srcChannelHandleAttr.getType()); + if (channelType != kChannelTypeDeviceToDevice && + channelType != kChannelTypeInvalid) { + return failure(); } - } - // Algorithm: search for first non-one working dimension from back - auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape(); - size_t dim = replicaGroupsShape.size() - 1; - for (auto s = replicaGroupsShape.rbegin(); s != replicaGroupsShape.rend(); - ++s, --dim) { - if (*s != 1) { - break; - } - } - if (dim < 0) { - // all one shape, then select the fastest dim - dim = replicaGroupsShape.size(); + channelHandleAttr = rewriter.getSI32IntegerAttr( + static_cast(srcChannelHandleAttr.getHandle())); } - StringAttr dimName = StringAttr::get(this->getContext(), "dim"); - IntegerAttr dimAttr = - rewriter.getSI32IntegerAttr(static_cast(dim)); - ttirAttrs.push_back({dimName, dimAttr}); + mlir::DenseIntElementsAttr replicaGroupsAttr = + adaptor.getReplicaGroupsAttr(); + bool useGlobalDeviceIds = adaptor.getUseGlobalDeviceIds(); // Parse computation in region and add it to ttirAttrs ReduceType reduceType; @@ -1594,16 +1554,43 @@ class StableHLOToTTIRAllReduceOpConversionPattern return rewriter.notifyMatchFailure( srcOp, "AllReduceOp cannot specify reduce type."); } - StringAttr reduceTypeAttrName = - StringAttr::get(this->getContext(), "reduce_type"); - Attribute reduceTypeAttr = rewriter.getAttr(reduceType); - ttirAttrs.push_back({reduceTypeAttrName, reduceTypeAttr}); - auto ttirAllReduceOp = rewriter.create( - srcOp.getLoc(), ttirTypes, ValueRange(ttirOperands.getAsOperandRange()), - ttirAttrs); + // stablehlo all_reduce op has no dimension defined in the op. Thus, we + // estimate possible all reduce dimension. Current algorithm is to search + // for first non-one dimension of input tensor from back. 
+ auto estimateDim = [](mlir::RankedTensorType inputType) -> int32_t { + if (inputType.getRank() == 1) { + return 0; + } + auto inputShape = inputType.getShape(); + auto nonOneIt = std::find_if(inputShape.rbegin(), inputShape.rend(), + [](int64_t s) { return s != 1; }); + int32_t dim = inputType.getRank() - 1 - + std::distance(inputShape.rbegin(), nonOneIt); + // all one shape, then select the deepest dim + if (dim < 0) { + dim = inputType.getRank() - 1; + } + return dim; + }; + + // Handle variadic input/output pairs by creating multiple AllReduceOps. + llvm::SmallVector allReduceOpResults; + for (auto [inputOperand, resultOperand] : + llvm::zip_equal(adaptor.getOperands(), srcOp->getResults())) { + auto inputType = mlir::cast(inputOperand.getType()); + auto outputType = mlir::cast( + getTypeConverter()->convertType(resultOperand.getType())); - rewriter.replaceOp(srcOp, ttirAllReduceOp); + auto allReduceOp = + ttmlir::utils::createDPSOp( + rewriter, srcOp.getLoc(), outputType, inputOperand, + replicaGroupsAttr, estimateDim(inputType), channelHandleAttr, + useGlobalDeviceIds, reduceType); + + allReduceOpResults.push_back(allReduceOp.getResult()); + } + rewriter.replaceOp(srcOp, allReduceOpResults); return success(); } @@ -1613,9 +1600,9 @@ class StableHLOToTTIRAllReduceOpConversionPattern checkBasicLegality(mlir::stablehlo::AllReduceOp &srcOp, mlir::stablehlo::AllReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (srcOp.getOperands().empty() || srcOp.getOperands().size() > 1) { + if (srcOp.getOperands().empty()) { return rewriter.notifyMatchFailure( - srcOp, "AllReduceOp must have one input/output for now."); + srcOp, "AllReduceOp must have at least one input/output."); } return success(); @@ -1743,101 +1730,121 @@ class StableHLOToTTIRCustomCallOpConversionPattern return err; } - const std::string kShardingTarget = "Sharding"; - const std::string kSPMDFullToShardShapeTarget = "SPMDFullToShardShape"; - const std::string 
kSPMDShardToFullShapeTarget = "SPMDShardToFullShape"; - auto callTargetName = adaptor.getCallTargetNameAttr(); // Currently stablehlo.custom_call with following functions from // jax/openxla are supported - if (callTargetName != kShardingTarget && - callTargetName != kSPMDFullToShardShapeTarget && - callTargetName != kSPMDShardToFullShapeTarget) { + if (callTargetName != + mlir::tt::sharding_utils::kShardingCustomCallTargetName && + callTargetName != + mlir::tt::sharding_utils::kSPMDFullToShardShapeCallTargetName && + callTargetName != + mlir::tt::sharding_utils::kSPMDShardToFullShapeCallTargetName) { return failure(); } - auto shardingAttr = dyn_cast_or_null( - adaptor.getAttributes().get("mhlo.sharding")); + auto shardingAttr = + dyn_cast_if_present(adaptor.getAttributes().get( + mlir::tt::sharding_utils::kXlaShardingAttr)); if (!shardingAttr) { return failure(); } mlir::tt::sharding_utils::MeshSharding meshSharding; - if (failed(mlir::tt::sharding_utils::parseGSPMDShardingAttr( - shardingAttr.getValue(), meshSharding))) { - return failure(); + auto error = meshSharding.convertGSPMDShardingToMeshSharding( + shardingAttr.getValue()); + if (auto e = error.takeError()) { + return rewriter.notifyMatchFailure(srcOp, llvm::toString(std::move(e))); + } + + // For GSPMD, meshShape is extracted by the parser. Then, add it as module + // attribute such that the information is used by later pipeline stage. 
+ auto meshShape = meshSharding.getMeshShape(); + if (meshShape.size() > 1) { + auto module = srcOp->getParentOfType(); + if (!module) { + llvm_unreachable("Require module as one of parent ops."); + } + mlir::tt::utils::addMeshToModuleAttribute( + rewriter, module, StringAttr::get(getContext(), "mesh_gspmd"), + meshShape); } - if (callTargetName == kSPMDFullToShardShapeTarget) { + if (callTargetName == + mlir::tt::sharding_utils::kSPMDFullToShardShapeCallTargetName) { + // @Sharding => @SPMDFullToShardShape pattern Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); if (!shardingOp) { return rewriter.notifyMatchFailure( - srcOp, "requires operand to be defined by an op"); + srcOp, "Requires operand to be defined by prior Sharding op."); } - - // TODO(wooseoklee): a bit rough approach here to match output dim - shardingOp->getResult(0).setType(srcOp->getResult(0).getType()); - srcOp.getResult(0).replaceAllUsesWith(shardingOp->getResult(0)); - rewriter.eraseOp(srcOp); - } else if (callTargetName == kSPMDShardToFullShapeTarget) { + rewriter.replaceOp(srcOp, shardingOp->getResult(0)); + } else if (callTargetName == + mlir::tt::sharding_utils::kSPMDShardToFullShapeCallTargetName) { + // @Sharding => @SPMDShardToFullShape pattern Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); if (!shardingOp) { return rewriter.notifyMatchFailure( - srcOp, "requires operand to be defined by an op"); + srcOp, "Requires operand to be defined by prior Sharding op."); } - // Create the output tensor type based on inputs auto outputType = mlir::cast( getTypeConverter()->convertType(srcOp->getResult(0).getType())); - // Create an empty output tensor with the computed shape - tensor::EmptyOp outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - - SmallVector outputTypes; - if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), - outputTypes))) { - return failure(); - } - - meshSharding.shardDirection 
= mlir::tt::MeshShardDirection::ShardToFull; - rewriter.replaceOpWithNewOp( - srcOp, outputTypes, srcOp.getInputs().front(), outputTensor, - meshSharding.shardType, meshSharding.shardDirection, - meshSharding.shardShape, meshSharding.shardDims); - } else if (callTargetName == kShardingTarget) { - if (meshSharding.shardType == mlir::tt::MeshShardType::Manual) { - // "manual" sharding indicates match between input/output tensor shape - // and no sharding is required. - srcOp.getResult(0).replaceAllUsesWith(srcOp->getOperand(0)); - rewriter.eraseOp(srcOp); + ttmlir::utils::replaceOpWithNewDPSOp( + rewriter, srcOp, outputType, adaptor.getInputs().front(), + meshSharding.getShardType(), + mlir::tt::MeshShardDirection::ShardToFull, + meshSharding.getShardShape(), meshSharding.getShardDims()); + + } else if (callTargetName == + mlir::tt::sharding_utils::kShardingCustomCallTargetName) { + if (meshSharding.getShardType() == mlir::tt::MeshShardType::Manual) { + // @Sharding => @SPMDShardToFullShape pattern + // "manual" sharding indicates no sharding is required. 
+ rewriter.replaceOp(srcOp, srcOp->getOperand(0)); } else { - auto *user = *srcOp.getResult(0).user_begin(); - auto userOp = dyn_cast_or_null(user); - if (!userOp) { + // @Sharding => @SPMDFullToShardShape pattern + auto fullToShardCustomCall = + mlir::dyn_cast_if_present( + *srcOp->user_begin()); + if (!fullToShardCustomCall || !fullToShardCustomCall->hasOneUse()) { return failure(); } - // Create the output tensor type based on inputs - auto outputType = mlir::cast( - getTypeConverter()->convertType(userOp->getResult(0).getType())); - - // Create an empty output tensor with the computed shape - tensor::EmptyOp outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - - SmallVector outputTypes; - if (failed(this->getTypeConverter()->convertTypes( - userOp->getResultTypes(), outputTypes))) { - return failure(); + // JAX automatic sharding pre-shards input tensors and provides multiple + // buffers. Thus, mesh sharding operations should not shard the tensors + // twice if they are function arguments and pre-sharded by frontend. + // Runtime ignores mesh sharding operation if it is set as manual + // sharding. 
+ auto inputOperand = adaptor.getInputs().front(); + auto funcOp = srcOp->getParentOfType(); + if (auto blockArg = mlir::dyn_cast(inputOperand)) { + auto argNum = blockArg.getArgNumber(); + if (auto argShardingAttr = funcOp.getArgAttrOfType( + argNum, mlir::tt::sharding_utils::kXlaShardingAttr)) { + if (argShardingAttr == shardingAttr) { + meshSharding.setDummyShardingOp(); + rewriter.modifyOpInPlace(funcOp, [&]() { + funcOp.removeArgAttr( + argNum, mlir::tt::sharding_utils::kXlaShardingAttr); + }); + } else { + llvm_unreachable("GSPMD customCallOp and function argument " + "shardings are different."); + } + } } - meshSharding.shardDirection = mlir::tt::MeshShardDirection::FullToShard; - rewriter.replaceOpWithNewOp( - srcOp, outputTypes, srcOp.getInputs().front(), outputTensor, - meshSharding.shardType, meshSharding.shardDirection, - meshSharding.shardShape, meshSharding.shardDims); + auto outputType = + mlir::cast(getTypeConverter()->convertType( + fullToShardCustomCall->getResult(0).getType())); + + ttmlir::utils::replaceOpWithNewDPSOp( + rewriter, srcOp, outputType, inputOperand, + meshSharding.getShardType(), + mlir::tt::MeshShardDirection::FullToShard, + meshSharding.getShardShape(), meshSharding.getShardDims()); } } return success(); @@ -1848,9 +1855,9 @@ class StableHLOToTTIRCustomCallOpConversionPattern checkBasicLegality(mlir::stablehlo::CustomCallOp &srcOp, mlir::stablehlo::CustomCallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { - - // Expect single input/output, otherwise do not convert - if (adaptor.getInputs().size() != 1 && srcOp->getResults().size() != 1) { + // Expect single input/output and at least one use of result. 
+ if (srcOp->getNumOperands() != 1 || srcOp->getNumResults() != 1 || + srcOp->getResult(0).use_empty()) { return failure(); } diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 8f420a5fb7..e59a7efc13 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -1299,7 +1299,7 @@ class AllReduceOpConversionPattern using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ttir::AllReduceOp op, OpAdaptor adaptor, + matchAndRewrite(ttir::AllReduceOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape(); @@ -1308,10 +1308,10 @@ class AllReduceOpConversionPattern // pass of reduce_scatter output and all_gather input int32_t scatter_num = replicaGroupsShape[scatter_dim % replicaGroupsShape.size()]; - auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, srcOp); rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType(0)), - adaptor.getInputs().front(), device, scatter_dim, scatter_num, + srcOp, this->getTypeConverter()->convertType(srcOp.getType()), + adaptor.getInput(), device, scatter_dim, scatter_num, adaptor.getReduceType()); return success(); diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 04591a9166..79e630df8c 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1883,8 +1883,7 @@ ::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() { // AllReduceOp verification ::mlir::LogicalResult mlir::tt::ttir::AllReduceOp::verify() { - ::mlir::RankedTensorType inputType = - mlir::cast(getInputs().front().getType()); + ::mlir::RankedTensorType inputType = getInput().getType(); int32_t dim = getDim(); if (dim >= inputType.getRank()) { @@ -1902,10 +1901,9 @@ ::mlir::LogicalResult 
mlir::tt::ttir::AllReduceOp::verify() { ::mlir::LogicalResult mlir::tt::ttir::MeshShardOp::verify() { auto shardType = getShardType(); - // Currently, we are only supporting replicate or devices from StableHLO. - if (shardType != mlir::tt::MeshShardType::Replicate && - shardType != mlir::tt::MeshShardType::Devices) { - return emitOpError("Invalid shard_type for mesh_shard op."); + // Currently, we are not supporting maximal from StableHLO. + if (shardType == mlir::tt::MeshShardType::Maximal) { + return emitOpError("Invalid shard_type (maximal) for mesh_shard op."); } return success(); diff --git a/lib/Dialect/TTIR/Transforms/Utility.cpp b/lib/Dialect/TTIR/Transforms/Utility.cpp index df0e5a6f36..41ff8a2e76 100644 --- a/lib/Dialect/TTIR/Transforms/Utility.cpp +++ b/lib/Dialect/TTIR/Transforms/Utility.cpp @@ -1,6 +1,8 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/Utils/Mesh.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" namespace mlir::tt::ttir { @@ -23,11 +25,15 @@ class TTIRImplicitDevice if (not module->hasAttr(tt::DeviceAttr::name)) { assert(module->hasAttr(tt::SystemDescAttr::name)); auto systemDesc = module->getAttr(tt::SystemDescAttr::name); + auto finalMeshShape = tt::utils::determineMeshShape(module, *meshShape); + if (auto err = finalMeshShape.takeError()) { + return; + } module->setAttr( tt::DeviceAttr::name, tt::DeviceAttr::get(&getContext(), mlir::cast(systemDesc), - meshShape)); + *finalMeshShape)); } } diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 250432cd33..34f6812907 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -1402,10 +1402,9 @@ ::mlir::LogicalResult MeshShardOp::verify() { llvm::ArrayRef shardShape = getShardShape(); ::mlir::tt::MeshShardType shardType = getShardType(); - // Check sharding is one of replicate or devices. 
- if (shardType != ::mlir::tt::MeshShardType::Replicate && - shardType != ::mlir::tt::MeshShardType::Devices) { - return emitOpError("Invalid shard_type for mesh_shard op."); + // Check shard_type is not maximal. + if (shardType == ::mlir::tt::MeshShardType::Maximal) { + return emitOpError("Invalid shard_type (maximal) for mesh_shard op."); } if (shardType == ::mlir::tt::MeshShardType::Devices) { diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp index d4f5bb41f7..35499e8b82 100644 --- a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp @@ -21,10 +21,8 @@ void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, input, *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice)); } else { - DEBUG_ASSERT( - input.get_logical_shape().rank() > 1, - "Sharding requires higher than 2 dimensional tensor. Tensor rank=", - input.get_logical_shape().rank()); + DEBUG_ASSERT(input.get_logical_shape().rank() > 1, + "Sharding requires higher than one dimensional tensor."); ::ttnn::distributed::Shard2dConfig shard2dConfig{std::nullopt, std::nullopt}; if (shardDims[0] >= 0) { @@ -96,6 +94,18 @@ void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) { DEBUG_ASSERT(::tt::runtime::ttnn::utils::isOnHost(input.storage_type()), "Input of ttnn::mesh_shard should be host tensor"); + // Treats manual sharding as a no-op, assuming that the input tensor is + // pre-sharded by the frontend. Thus, no sharding is required, but we need + // to make sure the tensor is a multi-device host tensor. + if (shardType == ::tt::target::ttnn::MeshShardType::Manual) { + LOG_ASSERT( + input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE, + "Input of mesh_shard with manual sharding must be MULTIDEVICE. 
id:", + op->in()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), input); + return; + } + if (shardDirection != ::tt::target::ttnn::MeshShardDirection::FullToShardShape && shardDirection != diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/all_reduce.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/all_reduce.mlir new file mode 100644 index 0000000000..6b9619e35f --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/all_reduce.mlir @@ -0,0 +1,26 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module { + func.func @all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %cst_0 = stablehlo.constant dense<1.600000e+01> : tensor + %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.add %arg0, %cst_0 : tensor + %1 = stablehlo.add %arg1, %cst_1 : tensor + %2:2 = "stablehlo.all_reduce"(%0, %1) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %3 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %3 : tensor + }) : (tensor, tensor) -> (tensor, tensor) + return %2#0, %2#1 : tensor, tensor + } +} + +// CHECK: "ttir.all_reduce" +// CHECK-SAME: channel_handle = 1 +// CHECK-SAME: dim = 0 +// CHECK-SAME: reduce_type = #tt.reduce_type +// CHECK: "ttir.all_reduce" +// CHECK-SAME: channel_handle = 1 +// CHECK-SAME: dim = 0 +// CHECK-SAME: reduce_type = #tt.reduce_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir similarity index 99% rename from test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir rename to test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir index 5dc6f3cb54..ea1be54cab 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir @@ -1,5 +1,5 @@ 
// REQUIRES: stablehlo -// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s // jax/pjrt sharding target 1x2 for n300 all_reduce module @all_reduce_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { @@ -40,6 +40,8 @@ module @all_reduce_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k all_reduce module @all_reduce_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { @@ -79,6 +81,8 @@ module @all_reduce_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 1x8 for t3k all_reduce module @all_reduce_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { @@ -118,6 +122,8 @@ module @all_reduce_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 8x4 for tg all_reduce module @all_reduce_8x4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { @@ -157,6 +163,8 @@ module @all_reduce_8x4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_repl } } +// ----- + // jax/pjrt sharding target 1x32 for tg all_reduce module @all_reduce_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>, %arg1: tensor<800x16384xf32>) -> (tensor<8192x16384xf32> 
{jax.result_info = ""}) { @@ -196,6 +204,8 @@ module @all_reduce_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_rep } } +// ----- + // jax/pjrt sharding target 1x2 for n300 all_gather cluster_axis=0 rank=2 module @all_gather_1x2_rank_2_cluster_0 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = ""}) { @@ -224,6 +234,8 @@ module @all_gather_1x2_rank_2_cluster_0 attributes {mhlo.num_partitions = 2 : i3 } } +// ----- + // jax/pjrt sharding target 1x2 for n300 all_gather rank=2 module @all_gather_1x2_rank_2_cluster_1 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<16384x800xf32> {jax.result_info = ""}) { @@ -252,6 +264,8 @@ module @all_gather_1x2_rank_2_cluster_1 attributes {mhlo.num_partitions = 2 : i3 } } +// ----- + // jax/pjrt sharding target 1x2 for n300 all_gather cluster_axis=0 rank=4 module @all_gather_1x2_rank_4_cluster_0 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { @@ -280,6 +294,8 @@ module @all_gather_1x2_rank_4_cluster_0 attributes {mhlo.num_partitions = 2 : i3 } } +// ----- + // jax/pjrt sharding target 1x2 for n300 all_gather rank=4 module @all_gather_1x2_rank_4_cluster_1 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x16384x784xf32> {jax.result_info = ""}) { @@ -308,6 +324,8 @@ module @all_gather_1x2_rank_4_cluster_1 attributes {mhlo.num_partitions = 2 : i3 } } +// ----- + // jax/pjrt sharding target 1x8 for t3k all_gather cluster_axis=0 rank=2 module @all_gather_1x8_rank_2_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: 
tensor<8192x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = ""}) { @@ -336,6 +354,8 @@ module @all_gather_1x8_rank_2_cluster_0 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 1x8 for t3k all_gather rank=2 module @all_gather_1x8_rank_2_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) { @@ -364,6 +384,8 @@ module @all_gather_1x8_rank_2_cluster_1 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 1x8 for t3k all_gather cluster_axis=0 rank=4 module @all_gather_1x8_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { @@ -392,6 +414,8 @@ module @all_gather_1x8_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 1x8 for t3k all_gather rank=4 module @all_gather_1x8_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x65536x784xf32> {jax.result_info = ""}) { @@ -420,6 +444,8 @@ module @all_gather_1x8_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 2x4 for t3k all_gather cluster_axis=0 rank=2 module @all_gather_2x4_rank_2_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<8192x1600xf32> {jax.result_info = ""}) { @@ -448,6 +474,8 @@ module @all_gather_2x4_rank_2_cluster_0 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 2x4 for t3k all_gather rank=2 module @all_gather_2x4_rank_2_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public 
@main(%arg0: tensor<8192x800xf32>) -> (tensor<32768x800xf32> {jax.result_info = ""}) { @@ -476,6 +504,8 @@ module @all_gather_2x4_rank_2_cluster_1 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 2x4 for t3k all_gather cluster_axis=0 rank=4 module @all_gather_2x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x16384x512xf32> {jax.result_info = ""}) { @@ -504,6 +534,8 @@ module @all_gather_2x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 2x4 for t3k all_gather rank=4 module @all_gather_2x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) { @@ -532,6 +564,8 @@ module @all_gather_2x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i3 } } +// ----- + // jax/pjrt sharding target 1x32 for tg all_gather cluster_axis=0 rank=2 module @all_gather_1x32_rank_2_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = ""}) { @@ -560,6 +594,8 @@ module @all_gather_1x32_rank_2_cluster_0 attributes {mhlo.num_partitions = 32 : } } +// ----- + // jax/pjrt sharding target 1x32 for tg all_gather rank=2 module @all_gather_1x32_rank_2_cluster_1 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<262144x800xf32> {jax.result_info = ""}) { @@ -588,6 +624,8 @@ module @all_gather_1x32_rank_2_cluster_1 attributes {mhlo.num_partitions = 32 : } } +// ----- + // jax/pjrt sharding target 1x32 for tg all_gather cluster_axis=0 rank=4 module @all_gather_1x32_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas 
= 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { @@ -616,6 +654,8 @@ module @all_gather_1x32_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : } } +// ----- + // jax/pjrt sharding target 1x32 for tg all_gather rank4 module @all_gather_1x32_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x262144x784xf32> {jax.result_info = ""}) { @@ -644,6 +684,8 @@ module @all_gather_1x32_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : } } +// ----- + // jax/pjrt sharding target 8x4 for tg all_gather cluster_axis=0 rank=2 module @all_gather_8x4_rank_2_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<8192x6400xf32> {jax.result_info = ""}) { @@ -672,6 +714,8 @@ module @all_gather_8x4_rank_2_cluster_0 attributes {mhlo.num_partitions = 32 : i } } +// ----- + // jax/pjrt sharding target 8x4 for tg all_gather rank=2 module @all_gather_4x8_rank_2_cluster_1 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) { @@ -700,6 +744,8 @@ module @all_gather_4x8_rank_2_cluster_1 attributes {mhlo.num_partitions = 32 : i } } +// ----- + // jax/pjrt sharding target 8x4 for tg all_gather cluster_axis=0 rank=4 module @all_gather_8x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x65536x512xf32> {jax.result_info = ""}) { @@ -728,6 +774,8 @@ module @all_gather_8x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : i } } +// ----- + // jax/pjrt sharding target 8x4 for tg all_gather rank4 module @all_gather_8x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : i32, 
mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) { @@ -756,6 +804,8 @@ module @all_gather_8x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : i } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, "x", None, "y"] module @jit_neg_basic0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -782,6 +832,8 @@ module @jit_neg_basic0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, "x", None, None] module @jit_neg_basic1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -808,6 +860,8 @@ module @jit_neg_basic1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, None, None, "y"] module @jit_neg_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -834,6 +888,8 @@ module @jit_neg_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, "y", None, "x"] module @jit_neg_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -860,6 +916,8 @@ module @jit_neg_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD 
negative, sharding [None, "y", None, None] module @jit_neg_basic4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -886,6 +944,8 @@ module @jit_neg_basic4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, None, None, "x"] module @jit_neg_basic5 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1x1024x1024xf32>) -> (tensor<1x1x1024x1024xf32> {jax.result_info = ""}) { @@ -912,6 +972,8 @@ module @jit_neg_basic5 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 1x8 for t3k - GSPMD negative, sharding [None, None, None, "y"] module @jit_neg_basic6 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { @@ -938,6 +1000,8 @@ module @jit_neg_basic6 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli } } +// ----- + // jax/pjrt sharding target 1x8 for t3k - GSPMD negative, sharding [None, "y", None, None] module @jit_neg_basic7 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir new file mode 100644 index 0000000000..af902ad920 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir @@ -0,0 +1,332 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +// jax/pjrt sharding target 1x2 for t3k - Shardy all_reduce +module 
@jit_matmul_shardy0 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=2]> + func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"y", "x"} (%arg2: tensor<8192x392xf32>, %arg3: tensor<392x16384xf32>) { + %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x392xf32>, tensor<392x16384xf32>) -> tensor<8192x16384xf32> + %2 = "stablehlo.all_reduce"(%1) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %3 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %3 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + sdy.return %2 : tensor<8192x16384xf32> + } : (tensor<8192x784xf32>, tensor<784x16384xf32>) -> tensor<8192x16384xf32> + return %0 : tensor<8192x16384xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy all_reduce +module @jit_matmul_shardy1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh 
@mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x", "y"} (%arg2: tensor<4096x196xf32>, %arg3: tensor<196x16384xf32>) { + %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %2 = "stablehlo.all_reduce"(%1) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %3 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %3 : tensor + }) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + sdy.return %2 : tensor<4096x16384xf32> + } : (tensor<8192x784xf32>, tensor<784x16384xf32>) -> tensor<8192x16384xf32> + return %0 : tensor<8192x16384xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 1x8 for t3k - Shardy all_reduce +module @jit_matmul_shardy2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: 
tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"y", "x"} (%arg2: tensor<8192x98xf32>, %arg3: tensor<98x16384xf32>) { + %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x98xf32>, tensor<98x16384xf32>) -> tensor<8192x16384xf32> + %2 = "stablehlo.all_reduce"(%1) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %3 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %3 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + sdy.return %2 : tensor<8192x16384xf32> + } : (tensor<8192x784xf32>, tensor<784x16384xf32>) -> tensor<8192x16384xf32> + return %0 : tensor<8192x16384xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy negative, sharding [None, "x", None, "y"] +module @jit_neg_shardy0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = 
sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"x"}, {}, {"y"}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {"y"}]>] manual_axes={"y", "x"} (%arg1: tensor<1x512x128x256xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x512x128x256xf32> + sdy.return %1 : tensor<1x512x128x256xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy negative, sharding [None, "x", None, None] +module @jit_neg_shardy1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x512x128x1024xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x512x128x1024xf32> + sdy.return %1 : tensor<1x512x128x1024xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding 
target 2x4 for t3k - Shardy negative, sharding [None, None, None, "y"] +module @jit_neg_shardy2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {}, {}, {"y"}]>] out_shardings=[<@mesh, [{}, {}, {}, {"y"}]>] manual_axes={"y", "x"} (%arg1: tensor<1x1024x128x256xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x1024x128x256xf32> + sdy.return %1 : tensor<1x1024x128x256xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy negative, sharding [None, "y", None, "x"] +module @jit_neg_shardy3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"y"}, {}, {"x"}]>] out_shardings=[<@mesh, [{}, {"y"}, {}, {"x"}]>] manual_axes={"x", "y"} (%arg1: tensor<1x256x128x512xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x256x128x512xf32> + sdy.return %1 : tensor<1x256x128x512xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = 
#tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy negative, sharding [None, "y", None, None] +module @jit_neg_shardy4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"y"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"y"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x256x128x1024xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x256x128x1024xf32> + sdy.return %1 : tensor<1x256x128x1024xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy negative, sharding [None, None, None, "x"] +module @jit_neg_shardy5 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {}, {}, {"x"}]>] out_shardings=[<@mesh, [{}, {}, {}, {"x"}]>] manual_axes={"x", "y"} (%arg1: 
tensor<1x1024x128x512xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x1024x128x512xf32> + sdy.return %1 : tensor<1x1024x128x512xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 1x8 for t3k - Shardy negative, sharding [None, None, None, "y"] +module @jit_neg_shardy6 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {}, {}, {"y"}]>] out_shardings=[<@mesh, [{}, {}, {}, {"y"}]>] manual_axes={"y", "x"} (%arg1: tensor<1x1024x128x128xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x1024x128x128xf32> + sdy.return %1 : tensor<1x1024x128x128xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 1x8 for t3k - Shardy negative, sharding [None, "y", None, None] +module @jit_neg_shardy7 attributes {mhlo.num_partitions = 8 : i32, 
mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"y"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"y"}, {}, {}]>] manual_axes={"y", "x"} (%arg1: tensor<1x128x128x1024xf32>) { + %1 = stablehlo.negate %arg1 : tensor<1x128x128x1024xf32> + sdy.return %1 : tensor<1x128x128x1024xf32> + } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32> + return %0 : tensor<1x1024x128x1024xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type + +// ----- + +// jax/pjrt sharding target 2x4 for t3k - Shardy all_reduce with automatic input sharding +module @jit_matmul_shardy_automatic attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<8192x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<784x16384xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x", "y"} (%arg2: tensor<4096x196xf32>, %arg3: tensor<196x16384xf32>) { + %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %2 = "stablehlo.all_reduce"(%1) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 
2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %3 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %3 : tensor + }) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + sdy.return %2 : tensor<4096x16384xf32> + } : (tensor<8192x784xf32>, tensor<784x16384xf32>) -> tensor<8192x16384xf32> + return %0 : tensor<8192x16384xf32> + } +} +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir new file mode 100644 index 0000000000..2acb787afe --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir @@ -0,0 +1,192 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module @jit_loss_dp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<784x128xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<128xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<128x128xf32> {mhlo.sharding = "{replicated}"}, %arg3: tensor<128xf32> {mhlo.sharding = "{replicated}"}, %arg4: tensor<128x128xf32> {mhlo.sharding = "{replicated}"}, %arg5: tensor<128xf32> {mhlo.sharding = "{replicated}"}, %arg6: tensor<128x128xf32> {mhlo.sharding = "{replicated}"}, %arg7: tensor<128xf32> 
{mhlo.sharding = "{replicated}"}, %arg8: tensor<128x128xf32> {mhlo.sharding = "{replicated}"}, %arg9: tensor<128xf32> {mhlo.sharding = "{replicated}"}, %arg10: tensor<128x8xf32> {mhlo.sharding = "{replicated}"}, %arg11: tensor<8xf32> {mhlo.sharding = "{replicated}"}, %arg12: tensor<32x784xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg13: tensor<32x8xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}) -> (tensor {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{replicated}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %4 = stablehlo.custom_call @Sharding(%arg2) {mhlo.sharding = "{replicated}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %6 = stablehlo.custom_call @Sharding(%arg3) {mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %7 = stablehlo.custom_call @SPMDFullToShardShape(%6) {mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %8 = stablehlo.custom_call @Sharding(%arg4) {mhlo.sharding = "{replicated}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %9 = stablehlo.custom_call @SPMDFullToShardShape(%8) {mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %10 = stablehlo.custom_call @Sharding(%arg5) {mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %11 = stablehlo.custom_call @SPMDFullToShardShape(%10) {mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %12 = stablehlo.custom_call @Sharding(%arg6) {mhlo.sharding 
= "{replicated}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %13 = stablehlo.custom_call @SPMDFullToShardShape(%12) {mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %14 = stablehlo.custom_call @Sharding(%arg7) {mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %15 = stablehlo.custom_call @SPMDFullToShardShape(%14) {mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %16 = stablehlo.custom_call @Sharding(%arg8) {mhlo.sharding = "{replicated}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %17 = stablehlo.custom_call @SPMDFullToShardShape(%16) {mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %18 = stablehlo.custom_call @Sharding(%arg9) {mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %19 = stablehlo.custom_call @SPMDFullToShardShape(%18) {mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %20 = stablehlo.custom_call @Sharding(%arg10) {mhlo.sharding = "{replicated}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %21 = stablehlo.custom_call @SPMDFullToShardShape(%20) {mhlo.sharding = "{manual}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %22 = stablehlo.custom_call @Sharding(%arg11) {mhlo.sharding = "{replicated}"} : (tensor<8xf32>) -> tensor<8xf32> + %23 = stablehlo.custom_call @SPMDFullToShardShape(%22) {mhlo.sharding = "{manual}"} : (tensor<8xf32>) -> tensor<8xf32> + %24 = stablehlo.custom_call @Sharding(%arg12) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<32x784xf32>) -> tensor<32x784xf32> + %25 = stablehlo.custom_call @SPMDFullToShardShape(%24) {mhlo.sharding = "{manual}"} : (tensor<32x784xf32>) -> tensor<4x784xf32> + %26 = stablehlo.custom_call @Sharding(%arg13) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<32x8xf32>) -> tensor<32x8xf32> + %27 = stablehlo.custom_call @SPMDFullToShardShape(%26) {mhlo.sharding = "{manual}"} : (tensor<32x8xf32>) -> tensor<4x8xf32> + %28 = call @shmap_body(%1, %3, %5, %7, 
%9, %11, %13, %15, %17, %19, %21, %23, %25, %27) : (tensor<784x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<4x784xf32>, tensor<4x8xf32>) -> tensor + %29 = stablehlo.custom_call @Sharding(%28) {mhlo.sharding = "{manual}"} : (tensor) -> tensor + %30 = stablehlo.custom_call @SPMDShardToFullShape(%29) {mhlo.sharding = "{replicated}"} : (tensor) -> tensor + return %30 : tensor + } + func.func private @shmap_body(%arg0: tensor<784x128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128xf32>, %arg4: tensor<128x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x128xf32>, %arg7: tensor<128xf32>, %arg8: tensor<128x128xf32>, %arg9: tensor<128xf32>, %arg10: tensor<128x8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<4x784xf32>, %arg13: tensor<4x8xf32>) -> (tensor {jax.result_info = "[]"}) { + %cst = stablehlo.constant dense<8.000000e+00> : tensor + %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor + %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.dot_general %arg12, %arg0, contracting_dims = [1] x [0] : (tensor<4x784xf32>, tensor<784x128xf32>) -> tensor<4x128xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %3 = stablehlo.add %0, %2 : tensor<4x128xf32> + %4 = call @relu(%3) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %5 = stablehlo.dot_general %4, %arg2, contracting_dims = [1] x [0] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %6 = stablehlo.broadcast_in_dim %arg3, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %7 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %8 = stablehlo.add %5, %7 : tensor<4x128xf32> + %9 = call @relu_0(%8) : 
(tensor<4x128xf32>) -> tensor<4x128xf32> + %10 = stablehlo.dot_general %9, %arg4, contracting_dims = [1] x [0] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %11 = stablehlo.broadcast_in_dim %arg5, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %13 = stablehlo.add %10, %12 : tensor<4x128xf32> + %14 = call @relu_1(%13) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %15 = stablehlo.dot_general %14, %arg6, contracting_dims = [1] x [0] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %16 = stablehlo.broadcast_in_dim %arg7, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %18 = stablehlo.add %15, %17 : tensor<4x128xf32> + %19 = call @relu_2(%18) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %20 = stablehlo.dot_general %19, %arg8, contracting_dims = [1] x [0] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %21 = stablehlo.broadcast_in_dim %arg9, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %23 = stablehlo.add %20, %22 : tensor<4x128xf32> + %24 = call @relu_3(%23) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %25 = stablehlo.dot_general %24, %arg10, contracting_dims = [1] x [0] : (tensor<4x128xf32>, tensor<128x8xf32>) -> tensor<4x8xf32> + %26 = stablehlo.broadcast_in_dim %arg11, dims = [1] : (tensor<8xf32>) -> tensor<1x8xf32> + %27 = stablehlo.broadcast_in_dim %26, dims = [0, 1] : (tensor<1x8xf32>) -> tensor<4x8xf32> + %28 = stablehlo.add %25, %27 : tensor<4x8xf32> + %29 = stablehlo.subtract %28, %arg13 : tensor<4x8xf32> + %30 = stablehlo.multiply %29, %29 : tensor<4x8xf32> + %31 = stablehlo.reduce(%30 init: %cst_1) applies stablehlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + %32 = 
stablehlo.reduce(%31 init: %cst_1) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor + %33 = stablehlo.divide %32, %cst_0 : tensor + %34 = "stablehlo.all_reduce"(%33) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %36 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %36 : tensor + }) : (tensor) -> tensor + %35 = stablehlo.divide %34, %cst : tensor + return %35 : tensor + } + func.func private @relu(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_0(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_2(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_3(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + 
%1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } +} + +// CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = 
#tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir new file mode 100644 index 0000000000..facbd69166 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir @@ -0,0 +1,163 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module @jit_loss_dp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<784x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg1: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg2: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg3: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg4: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, 
[{}, {}]>}, %arg5: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg6: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg7: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg8: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg9: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg10: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg11: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg12: tensor<32x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg13: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13) in_shardings=[<@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, []>] manual_axes={"y", "x"} (%arg14: tensor<784x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x128xf32>, %arg17: tensor<128xf32>, %arg18: tensor<128x128xf32>, %arg19: tensor<128xf32>, %arg20: tensor<128x128xf32>, %arg21: tensor<128xf32>, %arg22: tensor<128x128xf32>, %arg23: tensor<128xf32>, %arg24: tensor<128x8xf32>, %arg25: tensor<8xf32>, %arg26: tensor<4x784xf32>, %arg27: tensor<4x8xf32>) { + %1 = stablehlo.dot_general %arg26, %arg14, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x784xf32>, tensor<784x128xf32>) -> tensor<4x128xf32> + %2 = stablehlo.broadcast_in_dim %arg15, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %4 = stablehlo.add %1, %3 : tensor<4x128xf32> + %5 = func.call @relu(%4) : (tensor<4x128xf32>) -> tensor<4x128xf32> 
+ %6 = stablehlo.dot_general %5, %arg16, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %7 = stablehlo.broadcast_in_dim %arg17, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %8 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %9 = stablehlo.add %6, %8 : tensor<4x128xf32> + %10 = func.call @relu_0(%9) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %11 = stablehlo.dot_general %10, %arg18, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %12 = stablehlo.broadcast_in_dim %arg19, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %13 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %14 = stablehlo.add %11, %13 : tensor<4x128xf32> + %15 = func.call @relu_1(%14) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %16 = stablehlo.dot_general %15, %arg20, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %17 = stablehlo.broadcast_in_dim %arg21, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %18 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %19 = stablehlo.add %16, %18 : tensor<4x128xf32> + %20 = func.call @relu_2(%19) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %21 = stablehlo.dot_general %20, %arg22, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %22 = stablehlo.broadcast_in_dim %arg23, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %23 = stablehlo.broadcast_in_dim %22, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %24 = stablehlo.add %21, %23 : tensor<4x128xf32> + %25 = func.call @relu_3(%24) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %26 = stablehlo.dot_general %25, %arg24, contracting_dims = [1] x 
[0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x8xf32>) -> tensor<4x8xf32> + %27 = stablehlo.broadcast_in_dim %arg25, dims = [1] : (tensor<8xf32>) -> tensor<1x8xf32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0, 1] : (tensor<1x8xf32>) -> tensor<4x8xf32> + %29 = stablehlo.add %26, %28 : tensor<4x8xf32> + %30 = stablehlo.subtract %29, %arg27 : tensor<4x8xf32> + %31 = stablehlo.multiply %30, %30 : tensor<4x8xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %32 = stablehlo.reduce(%31 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %33 = stablehlo.reduce(%32 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<4.000000e+00> : tensor + %34 = stablehlo.divide %33, %cst_1 : tensor + %35 = "stablehlo.all_reduce"(%34) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %37 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %37 : tensor + }) : (tensor) -> tensor + %cst_2 = stablehlo.constant dense<8.000000e+00> : tensor + %36 = stablehlo.divide %35, %cst_2 : tensor + sdy.return %36 : tensor + } : (tensor<784x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<32x784xf32>, tensor<32x8xf32>) -> tensor + return %0 : tensor + } + func.func private @relu(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func 
private @relu_0(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_2(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_3(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } +} + +// CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// 
CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: 
shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir new file mode 100644 index 0000000000..94071c41b6 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir @@ -0,0 +1,205 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module @jit_loss_fsdp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<784x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg1: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg2: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg3: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg4: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg5: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg6: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg7: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg8: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg9: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg10: tensor<128x8xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg11: tensor<8xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg12: tensor<32x784xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg13: tensor<32x8xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}) -> (tensor {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x128xf32>) 
-> tensor<98x128xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %4 = stablehlo.custom_call @Sharding(%arg2) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %6 = stablehlo.custom_call @Sharding(%arg3) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %7 = stablehlo.custom_call @SPMDFullToShardShape(%6) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %8 = stablehlo.custom_call @Sharding(%arg4) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %9 = stablehlo.custom_call @SPMDFullToShardShape(%8) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %10 = stablehlo.custom_call @Sharding(%arg5) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %11 = stablehlo.custom_call @SPMDFullToShardShape(%10) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %12 = stablehlo.custom_call @Sharding(%arg6) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %13 = stablehlo.custom_call @SPMDFullToShardShape(%12) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %14 = stablehlo.custom_call @Sharding(%arg7) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %15 = stablehlo.custom_call 
@SPMDFullToShardShape(%14) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %16 = stablehlo.custom_call @Sharding(%arg8) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %17 = stablehlo.custom_call @SPMDFullToShardShape(%16) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %18 = stablehlo.custom_call @Sharding(%arg9) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %19 = stablehlo.custom_call @SPMDFullToShardShape(%18) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %20 = stablehlo.custom_call @Sharding(%arg10) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %21 = stablehlo.custom_call @SPMDFullToShardShape(%20) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x8xf32>) -> tensor<16x8xf32> + %22 = stablehlo.custom_call @Sharding(%arg11) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<8xf32>) -> tensor<8xf32> + %23 = stablehlo.custom_call @SPMDFullToShardShape(%22) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8xf32>) -> tensor<1xf32> + %24 = stablehlo.custom_call @Sharding(%arg12) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<32x784xf32>) -> tensor<32x784xf32> + %25 = stablehlo.custom_call @SPMDFullToShardShape(%24) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x784xf32>) -> tensor<4x784xf32> + %26 = stablehlo.custom_call @Sharding(%arg13) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<32x8xf32>) -> tensor<32x8xf32> + %27 = stablehlo.custom_call @SPMDFullToShardShape(%26) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x8xf32>) -> tensor<4x8xf32> + %28 = call @shmap_body(%1, %3, %5, %7, %9, %11, %13, %15, %17, %19, %21, %23, %25, %27) : 
(tensor<98x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<4x784xf32>, tensor<4x8xf32>) -> tensor + %29 = stablehlo.custom_call @Sharding(%28) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor) -> tensor + %30 = stablehlo.custom_call @SPMDShardToFullShape(%29) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor) -> tensor + return %30 : tensor + } + func.func private @shmap_body(%arg0: tensor<98x128xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16x128xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x128xf32>, %arg5: tensor<16xf32>, %arg6: tensor<16x128xf32>, %arg7: tensor<16xf32>, %arg8: tensor<16x128xf32>, %arg9: tensor<16xf32>, %arg10: tensor<16x8xf32>, %arg11: tensor<1xf32>, %arg12: tensor<4x784xf32>, %arg13: tensor<4x8xf32>) -> (tensor {jax.result_info = "[]"}) { + %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<98x128xf32>) -> tensor<784x128xf32> + %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %2 = stablehlo.dot_general %arg12, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x784xf32>, tensor<784x128xf32>) -> tensor<4x128xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %5 = stablehlo.add %2, %4 : tensor<4x128xf32> + %6 = call @relu(%5) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %7 = "stablehlo.all_gather"(%arg2) <{all_gather_dim = 0 : i64, channel_handle = 
#stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %8 = "stablehlo.all_gather"(%arg3) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %9 = stablehlo.dot_general %6, %7, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %10 = stablehlo.broadcast_in_dim %8, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %12 = stablehlo.add %9, %11 : tensor<4x128xf32> + %13 = call @relu_0(%12) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %14 = "stablehlo.all_gather"(%arg4) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %15 = "stablehlo.all_gather"(%arg5) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %16 = stablehlo.dot_general %13, %14, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %17 = stablehlo.broadcast_in_dim %15, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %18 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %19 = stablehlo.add %16, %18 : tensor<4x128xf32> + %20 = call @relu_1(%19) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %21 = "stablehlo.all_gather"(%arg6) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 
7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %22 = "stablehlo.all_gather"(%arg7) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %23 = stablehlo.dot_general %20, %21, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %24 = stablehlo.broadcast_in_dim %22, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %26 = stablehlo.add %23, %25 : tensor<4x128xf32> + %27 = call @relu_2(%26) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %28 = "stablehlo.all_gather"(%arg8) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %29 = "stablehlo.all_gather"(%arg9) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %30 = stablehlo.dot_general %27, %28, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %31 = stablehlo.broadcast_in_dim %29, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %32 = stablehlo.broadcast_in_dim %31, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %33 = stablehlo.add %30, %32 : tensor<4x128xf32> + %34 = call @relu_3(%33) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %35 = "stablehlo.all_gather"(%arg10) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : 
(tensor<16x8xf32>) -> tensor<128x8xf32> + %36 = "stablehlo.all_gather"(%arg11) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1xf32>) -> tensor<8xf32> + %37 = stablehlo.dot_general %34, %35, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x8xf32>) -> tensor<4x8xf32> + %38 = stablehlo.broadcast_in_dim %36, dims = [1] : (tensor<8xf32>) -> tensor<1x8xf32> + %39 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x8xf32>) -> tensor<4x8xf32> + %40 = stablehlo.add %37, %39 : tensor<4x8xf32> + %41 = stablehlo.subtract %40, %arg13 : tensor<4x8xf32> + %42 = stablehlo.multiply %41, %41 : tensor<4x8xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %43 = stablehlo.reduce(%42 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %44 = stablehlo.reduce(%43 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<4.000000e+00> : tensor + %45 = stablehlo.divide %44, %cst_1 : tensor + %46 = "stablehlo.all_reduce"(%45) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %48 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %48 : tensor + }) : (tensor) -> tensor + %cst_2 = stablehlo.constant dense<8.000000e+00> : tensor + %47 = stablehlo.divide %46, %cst_2 : tensor + return %47 : tensor + } + func.func private @relu(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : 
tensor<4x128xf32> + } + func.func private @relu_0(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_2(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_3(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } +} + +// CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// 
CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: 
shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir new file mode 100644 index 0000000000..8bbe2417ee --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir @@ -0,0 +1,175 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module @jit_loss_fsdp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<784x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg1: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg2: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg3: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg4: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg5: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg6: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg7: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg8: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg9: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg10: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg11: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg12: tensor<32x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg13: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, 
%arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13) in_shardings=[<@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, []>] manual_axes={"x", "y"} (%arg14: tensor<98x128xf32>, %arg15: tensor<16xf32>, %arg16: tensor<16x128xf32>, %arg17: tensor<16xf32>, %arg18: tensor<16x128xf32>, %arg19: tensor<16xf32>, %arg20: tensor<16x128xf32>, %arg21: tensor<16xf32>, %arg22: tensor<16x128xf32>, %arg23: tensor<16xf32>, %arg24: tensor<16x8xf32>, %arg25: tensor<1xf32>, %arg26: tensor<4x784xf32>, %arg27: tensor<4x8xf32>) { + %1 = "stablehlo.all_gather"(%arg14) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<98x128xf32>) -> tensor<784x128xf32> + %2 = "stablehlo.all_gather"(%arg15) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %3 = stablehlo.dot_general %arg26, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x784xf32>, tensor<784x128xf32>) -> tensor<4x128xf32> + %4 = stablehlo.broadcast_in_dim %2, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %5 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %6 = stablehlo.add %3, %5 : tensor<4x128xf32> + %7 = func.call @relu(%6) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %8 = "stablehlo.all_gather"(%arg16) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %9 = 
"stablehlo.all_gather"(%arg17) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %10 = stablehlo.dot_general %7, %8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %11 = stablehlo.broadcast_in_dim %9, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %13 = stablehlo.add %10, %12 : tensor<4x128xf32> + %14 = func.call @relu_0(%13) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %15 = "stablehlo.all_gather"(%arg18) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %16 = "stablehlo.all_gather"(%arg19) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %17 = stablehlo.dot_general %14, %15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %18 = stablehlo.broadcast_in_dim %16, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %20 = stablehlo.add %17, %19 : tensor<4x128xf32> + %21 = func.call @relu_1(%20) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %22 = "stablehlo.all_gather"(%arg20) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %23 = "stablehlo.all_gather"(%arg21) <{all_gather_dim = 0 : 
i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %24 = stablehlo.dot_general %21, %22, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %25 = stablehlo.broadcast_in_dim %23, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %26 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %27 = stablehlo.add %24, %26 : tensor<4x128xf32> + %28 = func.call @relu_2(%27) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %29 = "stablehlo.all_gather"(%arg22) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<128x128xf32> + %30 = "stablehlo.all_gather"(%arg23) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<128xf32> + %31 = stablehlo.dot_general %28, %29, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x128xf32>) -> tensor<4x128xf32> + %32 = stablehlo.broadcast_in_dim %30, dims = [1] : (tensor<128xf32>) -> tensor<1x128xf32> + %33 = stablehlo.broadcast_in_dim %32, dims = [0, 1] : (tensor<1x128xf32>) -> tensor<4x128xf32> + %34 = stablehlo.add %31, %33 : tensor<4x128xf32> + %35 = func.call @relu_3(%34) : (tensor<4x128xf32>) -> tensor<4x128xf32> + %36 = "stablehlo.all_gather"(%arg24) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<16x8xf32>) -> tensor<128x8xf32> + %37 = "stablehlo.all_gather"(%arg25) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, 
replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1xf32>) -> tensor<8xf32> + %38 = stablehlo.dot_general %35, %36, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x128xf32>, tensor<128x8xf32>) -> tensor<4x8xf32> + %39 = stablehlo.broadcast_in_dim %37, dims = [1] : (tensor<8xf32>) -> tensor<1x8xf32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0, 1] : (tensor<1x8xf32>) -> tensor<4x8xf32> + %41 = stablehlo.add %38, %40 : tensor<4x8xf32> + %42 = stablehlo.subtract %41, %arg27 : tensor<4x8xf32> + %43 = stablehlo.multiply %42, %42 : tensor<4x8xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %44 = stablehlo.reduce(%43 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %45 = stablehlo.reduce(%44 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<4.000000e+00> : tensor + %46 = stablehlo.divide %45, %cst_1 : tensor + %47 = "stablehlo.all_reduce"(%46) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %49 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %49 : tensor + }) : (tensor) -> tensor + %cst_2 = stablehlo.constant dense<8.000000e+00> : tensor + %48 = stablehlo.divide %47, %cst_2 : tensor + sdy.return %48 : tensor + } : (tensor<784x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<32x784xf32>, tensor<32x8xf32>) -> tensor + return %0 : tensor + } + func.func private @relu(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor 
+ %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_0(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_2(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } + func.func private @relu_3(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<4x128xf32> + return %1 : tensor<4x128xf32> + } +} + +// CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: 
shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = 
#tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir new file mode 100644 index 0000000000..f696f4b268 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir @@ -0,0 +1,166 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_fsdp_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<784x128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128xf32>, %arg4: tensor<128x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x128xf32>, %arg7: tensor<128xf32>, %arg8: tensor<128x128xf32>, %arg9: tensor<128xf32>, %arg10: tensor<128x8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<32x784xf32>, %arg13: tensor<32x8xf32>) -> (tensor {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x128xf32>) -> tensor<98x128xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<128xf32>) -> tensor<128xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %4 = 
stablehlo.custom_call @Sharding(%arg2) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %6 = stablehlo.custom_call @Sharding(%arg3) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<128xf32>) -> tensor<128xf32> + %7 = stablehlo.custom_call @SPMDFullToShardShape(%6) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %8 = stablehlo.custom_call @Sharding(%arg4) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %9 = stablehlo.custom_call @SPMDFullToShardShape(%8) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %10 = stablehlo.custom_call @Sharding(%arg5) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<128xf32>) -> tensor<128xf32> + %11 = stablehlo.custom_call @SPMDFullToShardShape(%10) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %12 = stablehlo.custom_call @Sharding(%arg6) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %13 = stablehlo.custom_call @SPMDFullToShardShape(%12) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %14 = stablehlo.custom_call @Sharding(%arg7) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<128xf32>) -> tensor<128xf32> + %15 = stablehlo.custom_call @SPMDFullToShardShape(%14) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %16 = stablehlo.custom_call @Sharding(%arg8) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + 
%17 = stablehlo.custom_call @SPMDFullToShardShape(%16) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %18 = stablehlo.custom_call @Sharding(%arg9) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<128xf32>) -> tensor<128xf32> + %19 = stablehlo.custom_call @SPMDFullToShardShape(%18) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %20 = stablehlo.custom_call @Sharding(%arg10) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[2,4]T(1,0)}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %21 = stablehlo.custom_call @SPMDFullToShardShape(%20) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x8xf32>) -> tensor<16x8xf32> + %22 = stablehlo.custom_call @Sharding(%arg11) {backend_config = "", mhlo.sharding = "{devices=[8]<=[2,4]T(1,0)}"} : (tensor<8xf32>) -> tensor<8xf32> + %23 = stablehlo.custom_call @SPMDFullToShardShape(%22) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8xf32>) -> tensor<1xf32> + %24 = stablehlo.custom_call @Sharding(%arg12) {backend_config = "", mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<32x784xf32>) -> tensor<32x784xf32> + %25 = stablehlo.custom_call @SPMDFullToShardShape(%24) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x784xf32>) -> tensor<16x196xf32> + %26 = stablehlo.custom_call @Sharding(%arg13) {backend_config = "", mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<32x8xf32>) -> tensor<32x8xf32> + %27 = stablehlo.custom_call @SPMDFullToShardShape(%26) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x8xf32>) -> tensor<16x2xf32> + %28 = call @shmap_body(%1, %3, %5, %7, %9, %11, %13, %15, %17, %19, %21, %23, %25, %27) : (tensor<98x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<16x196xf32>, 
tensor<16x2xf32>) -> tensor + %29 = stablehlo.custom_call @Sharding(%28) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor) -> tensor + %30 = stablehlo.custom_call @SPMDShardToFullShape(%29) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor) -> tensor + return %30 : tensor + } + func.func private @shmap_body(%arg0: tensor<98x128xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16x128xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x128xf32>, %arg5: tensor<16xf32>, %arg6: tensor<16x128xf32>, %arg7: tensor<16xf32>, %arg8: tensor<16x128xf32>, %arg9: tensor<16xf32>, %arg10: tensor<16x8xf32>, %arg11: tensor<1xf32>, %arg12: tensor<16x196xf32>, %arg13: tensor<16x2xf32>) -> (tensor {jax.result_info = "[]"}) { + %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<98x128xf32>) -> tensor<196x128xf32> + %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %2 = stablehlo.dot_general %arg12, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x196xf32>, tensor<196x128xf32>) -> tensor<16x128xf32> + %3 = "stablehlo.reduce_scatter"(%2) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %5 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %6 = stablehlo.add %3, %5 : 
tensor<16x32xf32> + %7 = call @relu(%6) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %8 = "stablehlo.all_gather"(%arg2) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %9 = "stablehlo.all_gather"(%arg3) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %10 = stablehlo.dot_general %7, %8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %11 = "stablehlo.reduce_scatter"(%10) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %12 = stablehlo.broadcast_in_dim %9, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %13 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %14 = stablehlo.add %11, %13 : tensor<16x32xf32> + %15 = call @relu_0(%14) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %16 = "stablehlo.all_gather"(%arg4) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.all_gather"(%arg5) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %18 = stablehlo.dot_general %15, %16, contracting_dims = [1] x 
[0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %19 = "stablehlo.reduce_scatter"(%18) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %20 = stablehlo.broadcast_in_dim %17, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %21 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %22 = stablehlo.add %19, %21 : tensor<16x32xf32> + %23 = call @relu_1(%22) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %24 = "stablehlo.all_gather"(%arg6) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %25 = "stablehlo.all_gather"(%arg7) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %26 = stablehlo.dot_general %23, %24, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %27 = "stablehlo.reduce_scatter"(%26) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %28 = stablehlo.broadcast_in_dim %25, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %29 = stablehlo.broadcast_in_dim %28, dims = [0, 1] : (tensor<1x32xf32>) 
-> tensor<16x32xf32> + %30 = stablehlo.add %27, %29 : tensor<16x32xf32> + %31 = call @relu_2(%30) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %32 = "stablehlo.all_gather"(%arg8) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %33 = "stablehlo.all_gather"(%arg9) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %34 = stablehlo.dot_general %31, %32, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %35 = "stablehlo.reduce_scatter"(%34) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %36 = stablehlo.broadcast_in_dim %33, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %38 = stablehlo.add %35, %37 : tensor<16x32xf32> + %39 = call @relu_3(%38) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %40 = "stablehlo.all_gather"(%arg10) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x8xf32>) -> tensor<32x8xf32> + %41 = "stablehlo.all_gather"(%arg11) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<1xf32>) -> tensor<2xf32> + %42 
= stablehlo.dot_general %39, %40, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32> + %43 = "stablehlo.reduce_scatter"(%42) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16x8xf32>) -> tensor<16x2xf32> + %44 = stablehlo.broadcast_in_dim %41, dims = [1] : (tensor<2xf32>) -> tensor<1x2xf32> + %45 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<1x2xf32>) -> tensor<16x2xf32> + %46 = stablehlo.add %43, %45 : tensor<16x2xf32> + %47 = stablehlo.subtract %46, %arg13 : tensor<16x2xf32> + %48 = stablehlo.multiply %47, %47 : tensor<16x2xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %49 = stablehlo.reduce(%48 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<16x2xf32>, tensor) -> tensor<16xf32> + %50 = "stablehlo.all_reduce"(%49) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : tensor + }) : (tensor<16xf32>) -> tensor<16xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %51 = stablehlo.reduce(%50 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<16xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<1.600000e+01> : tensor + %52 = stablehlo.divide %51, %cst_1 : tensor + %53 = "stablehlo.all_reduce"(%52) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg14: tensor, %arg15: tensor): + %55 = stablehlo.add %arg14, %arg15 : tensor + stablehlo.return %55 : 
tensor + }) : (tensor) -> tensor + %cst_2 = stablehlo.constant dense<2.000000e+00> : tensor + %54 = stablehlo.divide %53, %cst_2 : tensor + return %54 : tensor + } + func.func private @relu(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_0(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_1(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_2(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_3(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } +} + +// CHECK-LABEL @main diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir new file mode 100644 index 0000000000..5ddd3fb180 --- /dev/null +++ 
b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir @@ -0,0 +1,136 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_fsdp_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=2, "y"=4]> + func.func public @main(%arg0: tensor<784x128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128xf32>, %arg4: tensor<128x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x128xf32>, %arg7: tensor<128xf32>, %arg8: tensor<128x128xf32>, %arg9: tensor<128xf32>, %arg10: tensor<128x8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<32x784xf32>, %arg13: tensor<32x8xf32>) -> (tensor {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13) in_shardings=[<@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"y", "x"}, {}]>, <@mesh, [{"y", "x"}]>, <@mesh, [{"x"}, {"y"}]>, <@mesh, [{"x"}, {"y"}]>] out_shardings=[<@mesh, []>] manual_axes={"x", "y"} (%arg14: tensor<98x128xf32>, %arg15: tensor<16xf32>, %arg16: tensor<16x128xf32>, %arg17: tensor<16xf32>, %arg18: tensor<16x128xf32>, %arg19: tensor<16xf32>, %arg20: tensor<16x128xf32>, %arg21: tensor<16xf32>, %arg22: tensor<16x128xf32>, %arg23: tensor<16xf32>, %arg24: tensor<16x8xf32>, %arg25: tensor<1xf32>, %arg26: tensor<16x196xf32>, %arg27: tensor<16x2xf32>) { + %1 = "stablehlo.all_gather"(%arg14) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<98x128xf32>) -> tensor<196x128xf32> + %2 = "stablehlo.all_gather"(%arg15) 
<{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %3 = stablehlo.dot_general %arg26, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x196xf32>, tensor<196x128xf32>) -> tensor<16x128xf32> + %4 = "stablehlo.reduce_scatter"(%3) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %5 = stablehlo.broadcast_in_dim %2, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %6 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %7 = stablehlo.add %4, %6 : tensor<16x32xf32> + %8 = func.call @relu(%7) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %9 = "stablehlo.all_gather"(%arg16) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %10 = "stablehlo.all_gather"(%arg17) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %11 = stablehlo.dot_general %8, %9, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %12 = "stablehlo.reduce_scatter"(%11) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add 
%arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %13 = stablehlo.broadcast_in_dim %10, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %14 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %15 = stablehlo.add %12, %14 : tensor<16x32xf32> + %16 = func.call @relu_0(%15) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %17 = "stablehlo.all_gather"(%arg18) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %18 = "stablehlo.all_gather"(%arg19) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %19 = stablehlo.dot_general %16, %17, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %20 = "stablehlo.reduce_scatter"(%19) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %21 = stablehlo.broadcast_in_dim %18, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %23 = stablehlo.add %20, %22 : tensor<16x32xf32> + %24 = func.call @relu_1(%23) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %25 = "stablehlo.all_gather"(%arg20) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : 
(tensor<16x128xf32>) -> tensor<32x128xf32> + %26 = "stablehlo.all_gather"(%arg21) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %27 = stablehlo.dot_general %24, %25, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %28 = "stablehlo.reduce_scatter"(%27) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %29 = stablehlo.broadcast_in_dim %26, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %30 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %31 = stablehlo.add %28, %30 : tensor<16x32xf32> + %32 = func.call @relu_2(%31) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %33 = "stablehlo.all_gather"(%arg22) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x128xf32>) -> tensor<32x128xf32> + %34 = "stablehlo.all_gather"(%arg23) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16xf32>) -> tensor<32xf32> + %35 = stablehlo.dot_general %32, %33, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x128xf32>) -> tensor<16x128xf32> + %36 = "stablehlo.reduce_scatter"(%35) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : 
i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x128xf32>) -> tensor<16x32xf32> + %37 = stablehlo.broadcast_in_dim %34, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<16x32xf32> + %39 = stablehlo.add %36, %38 : tensor<16x32xf32> + %40 = func.call @relu_3(%39) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %41 = "stablehlo.all_gather"(%arg24) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<16x8xf32>) -> tensor<32x8xf32> + %42 = "stablehlo.all_gather"(%arg25) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> : (tensor<1xf32>) -> tensor<2xf32> + %43 = stablehlo.dot_general %40, %41, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32> + %44 = "stablehlo.reduce_scatter"(%43) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16x8xf32>) -> tensor<16x2xf32> + %45 = stablehlo.broadcast_in_dim %42, dims = [1] : (tensor<2xf32>) -> tensor<1x2xf32> + %46 = stablehlo.broadcast_in_dim %45, dims = [0, 1] : (tensor<1x2xf32>) -> tensor<16x2xf32> + %47 = stablehlo.add %44, %46 : tensor<16x2xf32> + %48 = stablehlo.subtract %47, %arg27 : tensor<16x2xf32> + %49 = stablehlo.multiply %48, %48 : tensor<16x2xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %50 = stablehlo.reduce(%49 init: %cst) applies 
stablehlo.add across dimensions = [1] : (tensor<16x2xf32>, tensor) -> tensor<16xf32> + %51 = "stablehlo.all_reduce"(%50) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor<16xf32>) -> tensor<16xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %52 = stablehlo.reduce(%51 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<16xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<1.600000e+01> : tensor + %53 = stablehlo.divide %52, %cst_1 : tensor + %54 = "stablehlo.all_reduce"(%53) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg28: tensor, %arg29: tensor): + %56 = stablehlo.add %arg28, %arg29 : tensor + stablehlo.return %56 : tensor + }) : (tensor) -> tensor + %cst_2 = stablehlo.constant dense<2.000000e+00> : tensor + %55 = stablehlo.divide %54, %cst_2 : tensor + sdy.return %55 : tensor + } : (tensor<784x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x128xf32>, tensor<128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<32x784xf32>, tensor<32x8xf32>) -> tensor + return %0 : tensor + } + func.func private @relu(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_0(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> 
tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_1(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_2(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } + func.func private @relu_3(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16x32xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<16x32xf32> + return %1 : tensor<16x32xf32> + } +} + +// CHECK-LABEL @main diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir new file mode 100644 index 0000000000..68675eca89 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir @@ -0,0 +1,157 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<784x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg1: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg2: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg3: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg4: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg5: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg6: 
tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg7: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg8: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg9: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg10: tensor<128x8xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg11: tensor<8xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg12: tensor<32x784xf32> {mhlo.sharding = "{devices=[1,8]<=[8]}"}, %arg13: tensor<32x8xf32> {mhlo.sharding = "{devices=[1,8]<=[8]}"}) -> (tensor {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg12) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x784xf32>) -> tensor<32x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x784xf32>) -> tensor<32x98xf32> + %2 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x128xf32>) -> tensor<98x128xf32> + %4 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %6 = call @shmap_body(%1, %3, %5) : (tensor<32x98xf32>, tensor<98x128xf32>, tensor<16xf32>) -> tensor<32x16xf32> + %7 = stablehlo.custom_call @Sharding(%6) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x16xf32>) -> tensor<32x16xf32> + %8 = stablehlo.custom_call @SPMDShardToFullShape(%7) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x16xf32>) -> tensor<32x128xf32> + %9 = call @relu(%8) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %10 = stablehlo.custom_call @Sharding(%9) {backend_config = 
"", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x128xf32>) -> tensor<32x128xf32> + %11 = stablehlo.custom_call @SPMDFullToShardShape(%10) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x128xf32>) -> tensor<32x16xf32> + %12 = stablehlo.custom_call @Sharding(%arg2) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %13 = stablehlo.custom_call @SPMDFullToShardShape(%12) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %14 = stablehlo.custom_call @Sharding(%arg3) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %15 = stablehlo.custom_call @SPMDFullToShardShape(%14) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %16 = call @shmap_body_0(%11, %13, %15) : (tensor<32x16xf32>, tensor<16x128xf32>, tensor<16xf32>) -> tensor<32x16xf32> + %17 = stablehlo.custom_call @Sharding(%16) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x16xf32>) -> tensor<32x16xf32> + %18 = stablehlo.custom_call @SPMDShardToFullShape(%17) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x16xf32>) -> tensor<32x128xf32> + %19 = call @relu(%18) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %20 = stablehlo.custom_call @Sharding(%19) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x128xf32>) -> tensor<32x128xf32> + %21 = stablehlo.custom_call @SPMDFullToShardShape(%20) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x128xf32>) -> tensor<32x16xf32> + %22 = stablehlo.custom_call @Sharding(%arg4) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %23 = stablehlo.custom_call @SPMDFullToShardShape(%22) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %24 = stablehlo.custom_call @Sharding(%arg5) 
{backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %25 = stablehlo.custom_call @SPMDFullToShardShape(%24) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %26 = call @shmap_body_1(%21, %23, %25) : (tensor<32x16xf32>, tensor<16x128xf32>, tensor<16xf32>) -> tensor<32x16xf32> + %27 = stablehlo.custom_call @Sharding(%26) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x16xf32>) -> tensor<32x16xf32> + %28 = stablehlo.custom_call @SPMDShardToFullShape(%27) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x16xf32>) -> tensor<32x128xf32> + %29 = call @relu(%28) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %30 = stablehlo.custom_call @Sharding(%29) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x128xf32>) -> tensor<32x128xf32> + %31 = stablehlo.custom_call @SPMDFullToShardShape(%30) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x128xf32>) -> tensor<32x16xf32> + %32 = stablehlo.custom_call @Sharding(%arg6) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %33 = stablehlo.custom_call @SPMDFullToShardShape(%32) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %34 = stablehlo.custom_call @Sharding(%arg7) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %35 = stablehlo.custom_call @SPMDFullToShardShape(%34) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %36 = call @shmap_body_2(%31, %33, %35) : (tensor<32x16xf32>, tensor<16x128xf32>, tensor<16xf32>) -> tensor<32x16xf32> + %37 = stablehlo.custom_call @Sharding(%36) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x16xf32>) -> tensor<32x16xf32> + %38 = stablehlo.custom_call @SPMDShardToFullShape(%37) {backend_config = "", mhlo.sharding = 
"{devices=[1,8]<=[8]}"} : (tensor<32x16xf32>) -> tensor<32x128xf32> + %39 = call @relu(%38) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %40 = stablehlo.custom_call @Sharding(%39) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x128xf32>) -> tensor<32x128xf32> + %41 = stablehlo.custom_call @SPMDFullToShardShape(%40) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x128xf32>) -> tensor<32x16xf32> + %42 = stablehlo.custom_call @Sharding(%arg8) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x128xf32>) -> tensor<128x128xf32> + %43 = stablehlo.custom_call @SPMDFullToShardShape(%42) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x128xf32>) -> tensor<16x128xf32> + %44 = stablehlo.custom_call @Sharding(%arg9) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<128xf32>) -> tensor<128xf32> + %45 = stablehlo.custom_call @SPMDFullToShardShape(%44) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<16xf32> + %46 = call @shmap_body_3(%41, %43, %45) : (tensor<32x16xf32>, tensor<16x128xf32>, tensor<16xf32>) -> tensor<32x16xf32> + %47 = stablehlo.custom_call @Sharding(%46) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x16xf32>) -> tensor<32x16xf32> + %48 = stablehlo.custom_call @SPMDShardToFullShape(%47) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x16xf32>) -> tensor<32x128xf32> + %49 = call @relu(%48) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %50 = stablehlo.custom_call @Sharding(%49) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x128xf32>) -> tensor<32x128xf32> + %51 = stablehlo.custom_call @SPMDFullToShardShape(%50) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x128xf32>) -> tensor<32x16xf32> + %52 = stablehlo.custom_call @Sharding(%arg10) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %53 = 
stablehlo.custom_call @SPMDFullToShardShape(%52) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x8xf32>) -> tensor<16x8xf32> + %54 = stablehlo.custom_call @Sharding(%arg11) {backend_config = "", mhlo.sharding = "{devices=[8]<=[8]}"} : (tensor<8xf32>) -> tensor<8xf32> + %55 = stablehlo.custom_call @SPMDFullToShardShape(%54) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8xf32>) -> tensor<1xf32> + %56 = call @shmap_body_4(%51, %53, %55) : (tensor<32x16xf32>, tensor<16x8xf32>, tensor<1xf32>) -> tensor<32x1xf32> + %57 = stablehlo.custom_call @Sharding(%56) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x1xf32>) -> tensor<32x1xf32> + %58 = stablehlo.custom_call @SPMDShardToFullShape(%57) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<32x1xf32>) -> tensor<32x8xf32> + %59 = stablehlo.subtract %58, %arg13 : tensor<32x8xf32> + %60 = stablehlo.multiply %59, %59 : tensor<32x8xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %61 = stablehlo.reduce(%60 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<32x8xf32>, tensor) -> tensor<32xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %62 = stablehlo.reduce(%61 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<32xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<3.200000e+01> : tensor + %63 = stablehlo.divide %62, %cst_1 : tensor + return %63 : tensor + } + func.func private @shmap_body(%arg0: tensor<32x98xf32>, %arg1: tensor<98x128xf32>, %arg2: tensor<16xf32>) -> (tensor<32x16xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x98xf32>, tensor<98x128xf32>) -> tensor<32x128xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, 
use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %4 = stablehlo.add %1, %3 : tensor<32x16xf32> + return %4 : tensor<32x16xf32> + } + func.func private @relu(%arg0: tensor<32x128xf32>) -> tensor<32x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<32x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<32x128xf32> + return %1 : tensor<32x128xf32> + } + func.func private @shmap_body_0(%arg0: tensor<32x16xf32>, %arg1: tensor<16x128xf32>, %arg2: tensor<16xf32>) -> (tensor<32x16xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %4 = stablehlo.add %1, %3 : tensor<32x16xf32> + return %4 : tensor<32x16xf32> + } + func.func private @shmap_body_1(%arg0: tensor<32x16xf32>, %arg1: tensor<16x128xf32>, %arg2: tensor<16xf32>) -> (tensor<32x16xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], 
precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %4 = stablehlo.add %1, %3 : tensor<32x16xf32> + return %4 : tensor<32x16xf32> + } + func.func private @shmap_body_2(%arg0: tensor<32x16xf32>, %arg1: tensor<16x128xf32>, %arg2: tensor<16xf32>) -> (tensor<32x16xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %4 = stablehlo.add %1, %3 : tensor<32x16xf32> + return %4 : tensor<32x16xf32> + } + func.func private @shmap_body_3(%arg0: tensor<32x16xf32>, %arg1: tensor<16x128xf32>, %arg2: tensor<16xf32>) -> (tensor<32x16xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : 
(tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %4 = stablehlo.add %1, %3 : tensor<32x16xf32> + return %4 : tensor<32x16xf32> + } + func.func private @shmap_body_4(%arg0: tensor<32x16xf32>, %arg1: tensor<16x8xf32>, %arg2: tensor<1xf32>) -> (tensor<32x1xf32> {jax.result_info = "[None, ('y',)]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x8xf32>) -> tensor<32x8xf32> + %1 = "stablehlo.reduce_scatter"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %5 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %5 : tensor + }) : (tensor<32x8xf32>) -> tensor<32x1xf32> + %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<1xf32>) -> tensor<1x1xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<32x1xf32> + %4 = stablehlo.add %1, %3 : tensor<32x1xf32> + return %4 : tensor<32x1xf32> + } +} + + +// CHECK-LABEL @main diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir new file mode 100644 index 0000000000..f2e3b07be8 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir @@ -0,0 +1,103 @@ +// REQUIRES: 
stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["x"=1, "y"=8]> + func.func public @main(%arg0: tensor<784x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg1: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg2: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg3: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg4: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg5: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg6: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg7: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg8: tensor<128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg9: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg10: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}, %arg11: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}, %arg12: tensor<32x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg13: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) -> (tensor {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg12, %arg0, %arg1) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x98xf32>, %arg15: tensor<98x128xf32>, %arg16: tensor<16xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x98xf32>, tensor<98x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, 
use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %20 = stablehlo.add %17, %19 : tensor<32x16xf32> + sdy.return %20 : tensor<32x16xf32> + } : (tensor<32x784xf32>, tensor<784x128xf32>, tensor<128xf32>) -> tensor<32x128xf32> + %1 = call @relu(%0) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %2 = sdy.manual_computation(%1, %arg2, %arg3) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x16xf32>, %arg15: tensor<16x128xf32>, %arg16: tensor<16xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %20 = stablehlo.add %17, %19 : tensor<32x16xf32> + sdy.return %20 : tensor<32x16xf32> + } : (tensor<32x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<32x128xf32> + %3 = call @relu(%2) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %4 = sdy.manual_computation(%3, %arg4, %arg5) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] 
out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x16xf32>, %arg15: tensor<16x128xf32>, %arg16: tensor<16xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %20 = stablehlo.add %17, %19 : tensor<32x16xf32> + sdy.return %20 : tensor<32x16xf32> + } : (tensor<32x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<32x128xf32> + %5 = call @relu(%4) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %6 = sdy.manual_computation(%5, %arg6, %arg7) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x16xf32>, %arg15: tensor<16x128xf32>, %arg16: tensor<16xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : 
(tensor<16xf32>) -> tensor<1x16xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %20 = stablehlo.add %17, %19 : tensor<32x16xf32> + sdy.return %20 : tensor<32x16xf32> + } : (tensor<32x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<32x128xf32> + %7 = call @relu(%6) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %8 = sdy.manual_computation(%7, %arg8, %arg9) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x16xf32>, %arg15: tensor<16x128xf32>, %arg16: tensor<16xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x16xf32>, tensor<16x128xf32>) -> tensor<32x128xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x128xf32>) -> tensor<32x16xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : (tensor<16xf32>) -> tensor<1x16xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x16xf32>) -> tensor<32x16xf32> + %20 = stablehlo.add %17, %19 : tensor<32x16xf32> + sdy.return %20 : tensor<32x16xf32> + } : (tensor<32x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<32x128xf32> + %9 = call @relu(%8) : (tensor<32x128xf32>) -> tensor<32x128xf32> + %10 = sdy.manual_computation(%9, %arg10, %arg11) in_shardings=[<@mesh, [{}, {"y"}]>, <@mesh, [{"y"}, {}]>, <@mesh, [{"y"}]>] out_shardings=[<@mesh, [{}, {"y"}]>] manual_axes={"x", "y"} (%arg14: tensor<32x16xf32>, %arg15: tensor<16x8xf32>, %arg16: tensor<1xf32>) { + %16 = stablehlo.dot_general %arg14, %arg15, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : 
(tensor<32x16xf32>, tensor<16x8xf32>) -> tensor<32x8xf32> + %17 = "stablehlo.reduce_scatter"(%16) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ + ^bb0(%arg17: tensor, %arg18: tensor): + %21 = stablehlo.add %arg17, %arg18 : tensor + stablehlo.return %21 : tensor + }) : (tensor<32x8xf32>) -> tensor<32x1xf32> + %18 = stablehlo.broadcast_in_dim %arg16, dims = [1] : (tensor<1xf32>) -> tensor<1x1xf32> + %19 = stablehlo.broadcast_in_dim %18, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<32x1xf32> + %20 = stablehlo.add %17, %19 : tensor<32x1xf32> + sdy.return %20 : tensor<32x1xf32> + } : (tensor<32x128xf32>, tensor<128x8xf32>, tensor<8xf32>) -> tensor<32x8xf32> + %11 = stablehlo.subtract %10, %arg13 : tensor<32x8xf32> + %12 = stablehlo.multiply %11, %11 : tensor<32x8xf32> + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %13 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<32x8xf32>, tensor) -> tensor<32xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor + %14 = stablehlo.reduce(%13 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<32xf32>, tensor) -> tensor + %cst_1 = stablehlo.constant dense<3.200000e+01> : tensor + %15 = stablehlo.divide %14, %cst_1 : tensor + return %15 : tensor + } + func.func private @relu(%arg0: tensor<32x128xf32>) -> tensor<32x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<32x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<32x128xf32> + return %1 : tensor<32x128xf32> + } +} + +// CHECK-LABEL: @main diff --git a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir index ab9fab7ab5..9f688b1ec5 100644 --- a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir +++ b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir @@ -1,10
+1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -module attributes {} { - func.func @forward(%arg0: tensor<8192x784xf32>) -> tensor<4096x196xf32> { - %0 = tensor.empty() : tensor<4096x196xf32> - %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<4096x196xf32>) -> tensor<4096x196xf32> - return %1 : tensor<4096x196xf32> +module @mesh_shard_test attributes {tt.meshes = #tt.meshes<[<"mesh" = 1x1>]>} { + func.func @forward(%arg0: tensor<8192x784xf32>) -> tensor<8192x392xf32> { + %0 = tensor.empty() : tensor<8192x392xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<8192x392xf32>) -> tensor<8192x392xf32> + return %1 : tensor<8192x392xf32> } } @@ -12,5 +12,5 @@ module attributes {} { // CHECK-NEXT: [[REG:.*]] = "ttnn.mesh_shard"([[ARG:.*]], [[DEVICE]]) // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_shape = array // CHECK-SAME: shard_type = #tt.shard_type