Skip to content

Commit

Permalink
Initial Shardy integration. (#2149)
Browse files Browse the repository at this point in the history
### Ticket
Shardy support (#2019)

### Problem description
TT-xla with shardy partitioner requires shardy dialect conversion. 

### What's changed
- Added initial support of shardy dialect conversion
- Parsed mesh info added to module attribute
- Initial support of detecting pre-sharded input tensor from frontend
- Allowed Conversion of variadic input/output pairs in stablehlo
all_reduce op
- Changed the mechanism for detecting dim in the stablehlo all_reduce op
- Addressed pending code restructuring and return type change issues


### Checklist
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
wooseokTT authored Feb 21, 2025
1 parent 115db72 commit b10b497
Show file tree
Hide file tree
Showing 27 changed files with 2,513 additions and 282 deletions.
80 changes: 71 additions & 9 deletions include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,82 @@
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "mlir/IR/BuiltinOps.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir::tt::sharding_utils {

#if TTMLIR_ENABLE_STABLEHLO
struct MeshSharding {
mlir::tt::MeshShardDirection shardDirection;
mlir::tt::MeshShardType shardType;
bool lastTileDimReplicate;
llvm::SmallVector<int64_t> shardShape;
llvm::SmallVector<int64_t> shardDims;
llvm::SmallVector<int64_t> meshShape;

// Holds the sharding description needed to build a mesh_shard operation:
// shard direction, shard type, shard shape, shard dims and mesh shape.
// An instance is filled in from either a GSPMD sharding string or a Shardy
// (sdy) sharding attribute via one of the convert* entry points below.
//
// A default-constructed instance has direction ShardToFull, type Manual,
// every shape/dim/device vector set to {-1}, and lastTileDimReplicate false.
class MeshSharding {
public:
  // Rule of Zero: all members own their storage and carry default member
  // initializers, so the compiler-generated special members are correct.
  // Deliberately no user-declared destructor: declaring one would suppress
  // the implicit move constructor/assignment and turn every move of this
  // object into a deep copy of its SmallVector members.
  MeshSharding() = default;

  // Convert mhlo.sharding string to meshSharding.
  // Returns an llvm::Error (through llvm::Expected) on failure.
  // NOTE(review): the meaning of the bool payload on success is defined by
  // the out-of-line implementation - confirm there before relying on it.
  llvm::Expected<bool>
  convertGSPMDShardingToMeshSharding(StringRef shardingStr);

  // Convert sdy.sharding to meshSharding, given the sdy mesh the sharding
  // refers to and the shard direction to record.
  // Returns an llvm::Error (through llvm::Expected) on failure.
  llvm::Expected<bool>
  convertSdyShardingToMeshSharding(mlir::sdy::TensorShardingAttr sdySharding,
                                   mlir::sdy::MeshAttr mesh,
                                   mlir::tt::MeshShardDirection direction);

  // Force dummy sharding op by setting shard_type to manual. The mesh_shard op
  // will be ignored at runtime by simply copying input tensor to output.
  void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; }

  // Getter functions. The ArrayRef results are non-owning views into this
  // object's vectors; they stay valid only while this object is alive and the
  // corresponding vector is not modified.
  mlir::tt::MeshShardDirection getShardDirection() const {
    return shardDirection;
  }
  mlir::tt::MeshShardType getShardType() const { return shardType; }
  llvm::ArrayRef<int64_t> getShardShape() const { return shardShape; }
  llvm::ArrayRef<int64_t> getShardDims() const { return shardDims; }
  llvm::ArrayRef<int64_t> getMeshShape() const { return meshShape; }

private:
  // Parse GSPMD devices string and fill out MeshSharding info.
  llvm::Expected<bool> parseGSPMDDevicesStr(StringRef devicesStr);

  // Based on current MeshSharding info, finalize sharding dimensions.
  llvm::Expected<bool> determineGSPMDShardingDims();

  // Set shardType to anything other than Devices and reset the shape/dim
  // values to the placeholders used for non-device sharding.
  void setNonDevicesShardType(tt::MeshShardType targetShardType) {
    assert(targetShardType != tt::MeshShardType::Devices &&
           "setNonDevicesShardType must not be called with Devices");
    shardType = targetShardType;
    // Specific values are required to fill corresponding attributes in
    // mesh_shard operation.
    shardShape = llvm::SmallVector<int64_t>{1};
    shardDims = llvm::SmallVector<int64_t>{-1};
    meshShape = llvm::SmallVector<int64_t>{-1};
  }

  // State. Defaults describe a manual ShardToFull sharding with placeholder
  // (-1) shapes. deviceIds and lastTileDimReplicate have no public getter.
  mlir::tt::MeshShardDirection shardDirection =
      mlir::tt::MeshShardDirection::ShardToFull;
  mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Manual;
  llvm::SmallVector<int64_t> shardShape{-1};
  llvm::SmallVector<int64_t> shardDims{-1};
  llvm::SmallVector<int64_t> meshShape{-1};
  llvm::SmallVector<int64_t> deviceIds{-1};
  bool lastTileDimReplicate = false;
};

LogicalResult parseGSPMDShardingAttr(StringRef shardingStr,
MeshSharding &meshSharding);
// Sharding-related string constants taken from OpenXLA:
// https://github.com/openxla/xla/blob/main/xla/service/spmd/shardy/constants.h
// The values must stay byte-identical to the upstream definitions.

// Going by the names, the first three are custom-call target names.
inline constexpr llvm::StringRef kShardingCustomCallTargetName{"Sharding"};
inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName{
    "SPMDFullToShardShape"};
inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName{
    "SPMDShardToFullShape"};

// Attribute name under which the XLA (GSPMD) sharding string is stored.
inline constexpr llvm::StringRef kXlaShardingAttr{"mhlo.sharding"};

#endif

} // namespace mlir::tt::sharding_utils
Expand Down
22 changes: 22 additions & 0 deletions include/ttmlir/Conversion/StableHLOToTTIR/ShardyToTTIR.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H
#define TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::tt {

#ifdef TTMLIR_ENABLE_STABLEHLO

// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name and signature, this is expected to add
// the Shardy -> TTIR conversion patterns to `patterns`, using `ctx` and
// `typeConverter` when constructing them - confirm against the implementation.
void populateShardyToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter);

#endif

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H
19 changes: 19 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,25 @@ def TT_MeshShardTypeAttr : EnumAttr<TT_Dialect, TT_MeshShardType, "shard_type">
let assemblyFormat = "`<` $value `>`";
}

// A single named mesh: a StringAttr name plus an int64_t shape array.
def TT_MeshAttr : TT_Attr<"Mesh", "mesh", []> {
let summary = "Mesh reference attribute in TT dialect.";
let description = [{
Describes a mesh config including name and shape.
}];
// $name identifies the mesh; $shape holds its extents.
let parameters = (ins "StringAttr":$name,
ArrayRefParameter<"int64_t">:$shape);
// Textual form: `<` name `=` shape `>`, where the shape is handled by the
// custom DimensionList directive.
// NOTE(review): presumably rendered as e.g. 1x2 - confirm.
let assemblyFormat = "`<` $name `=` custom<DimensionList>($shape) `>`";
}

// A list of MeshAttr entries, printed as `<[ ... ]>`.
def TT_MeshesAttr : TT_Attr<"Meshes", "meshes"> {
let summary = "TT system meshes attribute";
let description = [{
TT system meshes attribute that can include multiple mesh configs used for networks.
}];
// $meshes is the array of MeshAttr entries held by this attribute.
let parameters = (ins ArrayRefParameter<"MeshAttr">:$meshes);
let assemblyFormat = "`<` `[` $meshes `]` `>`";
}

//===----------------------------------------------------------------------===//
// TT type definitions
//===----------------------------------------------------------------------===//
Expand Down
66 changes: 66 additions & 0 deletions include/ttmlir/Dialect/TT/Utils/Mesh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_MESH_H
#define TTMLIR_DIALECT_TT_UTILS_MESH_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir::tt::utils {

// Register a mesh (name + shape) in the module's tt::MeshesAttr attribute.
//
// The meshes already recorded on the module are kept. The new mesh is appended
// only if no recorded mesh has the same name; the comparison is by name only,
// so a second registration under an existing name is a no-op even when its
// shape differs. This lets callers such as the GSPMD path register the same
// mesh repeatedly. The attribute update goes through
// rewriter.modifyOpInPlace so the rewriter is notified of the change.
inline void addMeshToModuleAttribute(PatternRewriter &rewriter,
                                     mlir::ModuleOp module, StringAttr meshName,
                                     llvm::ArrayRef<int64_t> meshShape) {
  // Start from whatever the module already carries, if anything.
  llvm::SmallVector<tt::MeshAttr> knownMeshes;
  if (auto currentAttr =
          module->getAttrOfType<tt::MeshesAttr>(tt::MeshesAttr::name)) {
    llvm::ArrayRef<tt::MeshAttr> recorded = currentAttr.getMeshes();
    knownMeshes.append(recorded.begin(), recorded.end());
  }

  // Nothing to do when a mesh with this name is already registered.
  for (tt::MeshAttr recordedMesh : knownMeshes) {
    if (recordedMesh.getName() == meshName) {
      return;
    }
  }

  MLIRContext *ctx = rewriter.getContext();
  knownMeshes.push_back(mlir::tt::MeshAttr::get(ctx, meshName, meshShape));
  auto updatedAttr = tt::MeshesAttr::get(ctx, knownMeshes);
  rewriter.modifyOpInPlace(
      module, [&]() { module->setAttr(tt::MeshesAttr::name, updatedAttr); });
}

// Pick the hardware mesh shape to use for DeviceAttr.
//
// Inputs are the mesh shape passed by the caller (`meshShape`, possibly empty)
// and the meshes recorded on `module` under tt::MeshesAttr (possibly absent).
//  - No recorded meshes: return `meshShape` unchanged; an empty shape means a
//    single-device config.
//  - Recorded meshes, empty `meshShape`: return the shape of the first
//    recorded mesh (only the first one is considered for now).
//  - Both present: they must be equal, otherwise an invalid_argument error is
//    returned through llvm::Expected.
inline llvm::Expected<llvm::SmallVector<int64_t>>
determineMeshShape(mlir::ModuleOp module, llvm::ArrayRef<int64_t> meshShape) {
  auto recordedMeshes =
      module->getAttrOfType<tt::MeshesAttr>(tt::MeshesAttr::name);

  // Without any mesh info in the graph, the option value is all we have.
  if (!recordedMeshes || recordedMeshes.getMeshes().empty()) {
    return llvm::SmallVector<int64_t>(meshShape);
  }

  // For now, use the first meshShape.
  llvm::ArrayRef<int64_t> graphMeshShape =
      recordedMeshes.getMeshes().front().getShape();

  const bool optionDisagreesWithGraph =
      !meshShape.empty() && !llvm::equal(meshShape, graphMeshShape);
  if (optionDisagreesWithGraph) {
    return llvm::createStringError(
        std::errc::invalid_argument,
        "Option.meshShape and mesh info from graph should be identical.");
  }

  return llvm::SmallVector<int64_t>(graphMeshShape);
}

} // namespace mlir::tt::utils

#endif // TTMLIR_DIALECT_TT_UTILS_MESH_H
5 changes: 3 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -1922,7 +1923,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {
AllReduce op.
}];

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I64ElementsAttr:$replica_groups,
SI32Attr:$dim,
Expand All @@ -1931,7 +1932,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {
TT_ReduceTypeAttr:$reduce_type
);

let results = (outs Variadic<AnyRankedTensor>:$results);
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(TTMLIRStableHLOToTTIR
ShardingUtils.cpp
StableHLOToTTIRPass.cpp
StableHLOToTTIRPatterns.cpp
ShardyToTTIRPatterns.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR
Expand Down
Loading

0 comments on commit b10b497

Please sign in to comment.