-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### Ticket Shardy support (#2019) ### Problem description TT-xla with shardy partitioner requires shardy dialect conversion. ### What's changed - Added initial support of shardy dialect conversion - Parsed mesh info added to module attribute - Initial support of detecting pre-sharded input tensor from frontend - Allowed Conversion of variadic input/output pairs in stablehlo all_reduce op - Changed detecting mechanism of dim in stablehlo all_reduce op - Addressed pending code restructuring and return type change issues ### Checklist - [ o ] New/Existing tests provide coverage for changes
- Loading branch information
Showing
27 changed files
with
2,513 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H | ||
#define TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir::tt { | ||
|
||
#ifdef TTMLIR_ENABLE_STABLEHLO | ||
|
||
void populateShardyToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, | ||
TypeConverter &typeConverter); | ||
|
||
#endif | ||
|
||
} // namespace mlir::tt | ||
|
||
#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_SHARDYTOTTIR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TT_UTILS_MESH_H | ||
#define TTMLIR_DIALECT_TT_UTILS_MESH_H | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" | ||
|
||
#include "mlir/IR/BuiltinOps.h" | ||
#include "llvm/Support/Error.h" | ||
#include "llvm/Support/ErrorHandling.h" | ||
|
||
namespace mlir::tt::utils { | ||
|
||
// Add a new mesh info to module attribute. | ||
inline void addMeshToModuleAttribute(PatternRewriter &rewriter, | ||
mlir::ModuleOp module, StringAttr meshName, | ||
llvm::ArrayRef<int64_t> meshShape) { | ||
MLIRContext *context = rewriter.getContext(); | ||
llvm::SmallVector<tt::MeshAttr> meshes; | ||
if (auto meshesAttr = | ||
module->getAttrOfType<tt::MeshesAttr>(tt::MeshesAttr::name)) { | ||
meshes = llvm::SmallVector<tt::MeshAttr>(meshesAttr.getMeshes()); | ||
} | ||
// Avoid adding multiple meshes with the same name and shape as GSPMD may try | ||
// to add the same meshes. | ||
if (llvm::all_of(meshes, | ||
[&](tt::MeshAttr m) { return m.getName() != meshName; })) { | ||
meshes.push_back(mlir::tt::MeshAttr::get(context, meshName, meshShape)); | ||
rewriter.modifyOpInPlace(module, [&]() { | ||
module->setAttr(tt::MeshesAttr::name, | ||
tt::MeshesAttr::get(context, meshes)); | ||
}); | ||
} | ||
} | ||
|
||
// Determine hardware mesh config for DeviceAttr. | ||
// If none exists, the empty meshShape leads to single device config. | ||
// If either option.meshShape or meshes exists, use one of them. | ||
// If both exist, compare mesh and throw error if they are different. | ||
inline llvm::Expected<llvm::SmallVector<int64_t>> | ||
determineMeshShape(mlir::ModuleOp module, llvm::ArrayRef<int64_t> meshShape) { | ||
if (auto meshesAttr = | ||
module->getAttrOfType<tt::MeshesAttr>(tt::MeshesAttr::name)) { | ||
llvm::ArrayRef<MeshAttr> meshAttr = meshesAttr.getMeshes(); | ||
if (meshAttr.empty()) { | ||
return llvm::SmallVector<int64_t>(meshShape); | ||
} | ||
// For now, use the first meshShape. | ||
llvm::ArrayRef<int64_t> meshFromMeshes = meshAttr[0].getShape(); | ||
// If both meshes exist, they should be identical. Otherwise, throw error. | ||
if (!meshShape.empty() && !llvm::equal(meshShape, meshFromMeshes)) { | ||
return llvm::createStringError( | ||
std::errc::invalid_argument, | ||
"Option.meshShape and mesh info from graph should be identical."); | ||
} | ||
return llvm::SmallVector<int64_t>(meshFromMeshes); | ||
} | ||
return llvm::SmallVector<int64_t>(meshShape); | ||
} | ||
|
||
} // namespace mlir::tt::utils | ||
|
||
#endif // TTMLIR_DIALECT_TT_UTILS_MESH_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.