[Optimizer] Refactor legal grid analysis (#1363)
* Rename to LegalLayoutAnalysis
* Add generation of both TILE and ROW_MAJOR layouts
* Keep ROW_MAJOR disabled by default and add flag to enable when needed
* Explicitly generate all layout-related parameters (except for data format).
* Adjust overrides to support a partial output layout instead of requiring every param (see the example below)
* Clean up optimizer mlir tests
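To illustrate the last two bullets with the option documentation from this diff: a full override pins every parameter, e.g. "op1=2x2:dram:interleaved:tile:fp32", while a partial override such as "op1=2x2:block_sharded" pins only the grid and tensor memory layout and leaves the remaining parameters for LegalLayoutAnalysis to infer.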
odjuricicTT authored Dec 10, 2024
1 parent 0c7c2ca commit f2f2e97
Showing 38 changed files with 772 additions and 508 deletions.
@@ -2,8 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
-#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
@@ -12,46 +12,49 @@

 namespace mlir::tt::ttnn {

-struct LegalGridAnalysisInput {
+struct LegalLayoutAnalysisInput {
   ChipDescAttr chipDesc;
   GridAttr maxGrid;
   RankedTensorType tensorType;
   int64_t maxShardedGrids;
   llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides;
+  bool rowMajorEnabled;

-  LegalGridAnalysisInput()
+  LegalLayoutAnalysisInput()
       : chipDesc(nullptr), maxGrid(nullptr), tensorType(nullptr),
         outputLayoutOverrides(nullptr) {}

-  LegalGridAnalysisInput(
+  LegalLayoutAnalysisInput(
       ChipDescAttr chipDesc, GridAttr maxGrid, RankedTensorType tensorType,
       int64_t maxShardedGrids,
-      llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides)
+      llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides,
+      bool rowMajorEnabled)
       : chipDesc(chipDesc), maxGrid(maxGrid), tensorType(tensorType),
         maxShardedGrids(maxShardedGrids),
-        outputLayoutOverrides(outputLayoutOverrides) {}
+        outputLayoutOverrides(outputLayoutOverrides),
+        rowMajorEnabled(rowMajorEnabled) {}

-  bool operator==(const LegalGridAnalysisInput &rhs) const {
+  bool operator==(const LegalLayoutAnalysisInput &rhs) const {
     return chipDesc == rhs.chipDesc && maxGrid == rhs.maxGrid &&
            tensorType == rhs.tensorType &&
            outputLayoutOverrides == rhs.outputLayoutOverrides;
   }

-  bool operator!=(const LegalGridAnalysisInput &rhs) const {
+  bool operator!=(const LegalLayoutAnalysisInput &rhs) const {
     return !(*this == rhs);
   }
 };

-class LegalGridAnalysis
-    : public TTNNAnalysis<LegalGridAnalysisInput, std::vector<TTNNLayoutAttr>> {
+class LegalLayoutAnalysis : public TTNNAnalysis<LegalLayoutAnalysisInput,
+                                                std::vector<TTNNLayoutAttr>> {
 private:
   void analysisImplementation() override;
   bool applyOverrides() override;

 public:
-  LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {}
+  LegalLayoutAnalysis(Operation *op) : TTNNAnalysis(op) {}
 };

 } // namespace mlir::tt::ttnn

-#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
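A minimal usage sketch for the renamed analysis (a hypothetical call site: init/getResult are assumed from the TTNNAnalysis base class, and the input values are placeholders):

    // Hypothetical wiring inside an optimizer pass; values are illustrative.
    LegalLayoutAnalysis analysis(op);
    analysis.init(LegalLayoutAnalysisInput(chipDesc, maxGrid, tensorType,
                                           /*maxShardedGrids=*/64,
                                           &outputLayoutOverrides,
                                           /*rowMajorEnabled=*/false));
    std::vector<TTNNLayoutAttr> layouts = analysis.getResult();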
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h
@@ -11,22 +11,22 @@
 namespace mlir::tt::ttnn {

 struct OpConfigAnalysisInput {
-  llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalGrids;
+  llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;

-  OpConfigAnalysisInput() : legalGrids() {}
+  OpConfigAnalysisInput() : legalLayouts() {}

   OpConfigAnalysisInput(
       const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
-          &&legalGrids)
-      : legalGrids(std::move(legalGrids)) {}
+          &&legalLayouts)
+      : legalLayouts(std::move(legalLayouts)) {}

   OpConfigAnalysisInput(
       const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
-          &legalGrids)
-      : legalGrids(legalGrids) {}
+          &legalLayouts)
+      : legalLayouts(legalLayouts) {}

   bool operator==(const OpConfigAnalysisInput &rhs) const {
-    return legalGrids == rhs.legalGrids;
+    return legalLayouts == rhs.legalLayouts;
   }

   bool operator!=(const OpConfigAnalysisInput &rhs) const {
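A sketch of how per-op results from LegalLayoutAnalysis might feed this input (assumed wiring; the actual plumbing lives in the optimizer pass):

    // One entry per analyzable op, produced by LegalLayoutAnalysis.
    llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
    legalLayouts[op] = analysis.getResult();
    OpConfigAnalysisInput input(std::move(legalLayouts));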
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -158,6 +158,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
     Layout getLayout() const;
     std::optional<TensorMemoryLayout> getMemLayoutOpt() const;
     Type getElementType() const;
+    Type getScalarElementType() const;
     uint64_t getShardSizeInBytes() const;
     BufferType getBufferType() const;
     DataType getDataType() const;
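A hedged illustration of what the new accessor distinguishes (the concrete types and the helper are assumptions, not taken from this diff): for a tiled layout, getElementType() returns the tile type, while getScalarElementType() returns the underlying scalar.

    TTNNLayoutAttr layout = getTiledBF16Layout();    // hypothetical helper
    Type elem = layout.getElementType();             // e.g. !tt.tile<32x32, bf16>
    Type scalar = layout.getScalarElementType();     // e.g. bf16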
31 changes: 22 additions & 9 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -24,7 +24,7 @@ struct TTIRToTTNNBackendPipelineOptions
       llvm::cl::desc("Determine and set max valid grid for Op execution."),
       llvm::cl::init(false)};

-  // Option to manually insert TTIR_ToLayoutOp for specific op's operand.
+  // Option to manually insert TTNN_ToLayoutOp for specific op's operand.
   // The format is a comma separated list of op names and operand index
   // separated by ':' separator.
   //
@@ -43,9 +43,13 @@
           "Manually insert memory reconfig op for specific op's operand."),
       llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};

-  // Option to override output layout for specific ops.
-  // The format is a comma separated list of op names equal to the output layout
-  // params separated by ":"
+  // Option to override output layout for specific operations. You can
+  // override any number or combination of layout parameters. If not all are
+  // overridden, the remaining ones will be inferred with all possible
+  // combinations generated in LegalLayoutAnalysis. The format is a
+  // comma-separated list of operation names followed by the output layout
+  // parameters, separated by :. The order of parameters does not matter; the
+  // parser will deduce which one is being overridden based on its value.
   //
   // op_name=grid_size:memory_space:tensor_memory_layout:memory_layout:data_type
   //
@@ -58,7 +62,9 @@
   // bfp_bf2, u32, u16, u8
   //
   // Full Example:
-  // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16"
+  // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:f16"
+  // Partial Example:
+  // "op1=2x2:block_sharded"
   //
   //
   // Note: This option is only valid if optimizerPassEnabled is true.
@@ -101,19 +107,26 @@
           "Pass in a system descriptor flatbuffer to compile against."),
       llvm::cl::init("")};

-  // Option to override maximum number of legal layouts for grid analysis
+  // Option to override maximum number of sharded layouts to be generated in
+  // legal layout analysis.
   //
   Option<int64_t> maxLegalLayouts{
       *this, OptionNames::maxLegalLayouts,
-      llvm::cl::desc(
-          "Override maximum number of legal layouts for grid analysis."),
+      llvm::cl::desc("Override maximum number of sharded layouts for legal "
+                     "layout analysis."),
       llvm::cl::init(64)};

   ListOption<int64_t> meshShape{
       *this, OptionNames::meshShape,
       llvm::cl::desc("Set the multi-device mesh shape.")};

-  // Options to enable/disable the workaround pass.
+  Option<bool> rowMajorEnabled{
+      *this, "row-major-enabled",
+      llvm::cl::desc(
+          "Enable row major layout generation in legal layout analysis."),
+      llvm::cl::init(false)};
+
+  // Option to enable/disable the workaround pass.
   //
   Option<bool> layouotWorkaroundsEnabled{
       *this, "enable-layout-workaround-pass",
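A hedged configuration sketch for the new knobs (the option names come from this diff; direct assignment works through MLIR's PassOptions::Option, but the surrounding setup is illustrative):

    TTIRToTTNNBackendPipelineOptions options;
    options.optimizerPassEnabled = true; // overrides only apply with the optimizer on
    options.rowMajorEnabled = true;      // also generate ROW_MAJOR layouts (off by default)
    options.maxLegalLayouts = 32;        // cap sharded layouts generated per op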
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -23,6 +23,7 @@ struct TTNNOptimizerOptions
       MemoryLayoutAnalysisPolicyType::DFSharding;
   bool memReconfigEnabled = false;
   int64_t maxLegalLayouts = 64;
+  bool rowMajorEnabled = false;
 };

 std::unique_ptr<::mlir::Pass> createTTNNOptimizer();
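For the standalone pass, the same switch would be flipped through TTNNOptimizerOptions (a sketch assuming a createTTNNOptimizer overload that accepts the options struct, the usual convention next to the no-argument factory shown above):

    TTNNOptimizerOptions options;
    options.rowMajorEnabled = true; // opt in to ROW_MAJOR layout generation
    std::unique_ptr<mlir::Pass> pass = createTTNNOptimizer(options);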
79 changes: 69 additions & 10 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
@@ -31,18 +31,77 @@ struct OptionNames {
 };

 struct OutputLayoutOverrideParams {
-
-  SmallVector<int64_t, 2> grid;
-  BufferType bufferType;
-  TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
-  Layout memoryLayout;                   // ROW_MAJOR / TILE
-  mlir::tt::DataType dataType;
+  std::optional<SmallVector<int64_t, 2>> grid;
+  std::optional<BufferType> bufferType;
+  std::optional<TensorMemoryLayout>
+      tensorMemoryLayout;                // INTERLEAVED / SHARDED etc...
+  std::optional<Layout> memoryLayout;    // ROW_MAJOR / TILE
+  std::optional<tt::DataType> dataType;

+  // Check if all layout parameters that are generated in LegalLayoutAnalysis
+  // are overridden. DataType is the only one that is not.
+  bool fullLayoutOverride() const {
+    return grid.has_value() && bufferType.has_value() &&
+           tensorMemoryLayout.has_value() && memoryLayout.has_value();
+  }
+
   bool operator==(const OutputLayoutOverrideParams rhs) const {
-    return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] &&
-           bufferType == rhs.bufferType &&
-           tensorMemoryLayout == rhs.tensorMemoryLayout &&
-           memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
+    if (grid.has_value() != rhs.grid.has_value()) {
+      return false;
+    }
+
+    if (grid.has_value() && rhs.grid.has_value()) {
+      if (grid.value().size() != rhs.grid.value().size()) {
+        return false;
+      }
+      for (std::size_t i = 0; i < grid.value().size(); i++) {
+        if (grid.value()[i] != rhs.grid.value()[i]) {
+          return false;
+        }
+      }
+    }
+
+    if (bufferType.has_value() != rhs.bufferType.has_value()) {
+      return false;
+    }
+
+    if (bufferType.has_value() && rhs.bufferType.has_value()) {
+      if (bufferType.value() != rhs.bufferType.value()) {
+        return false;
+      }
+    }
+
+    if (tensorMemoryLayout.has_value() != rhs.tensorMemoryLayout.has_value()) {
+      return false;
+    }
+
+    if (tensorMemoryLayout.has_value() && rhs.tensorMemoryLayout.has_value()) {
+      if (tensorMemoryLayout.value() != rhs.tensorMemoryLayout.value()) {
+        return false;
+      }
+    }
+
+    if (memoryLayout.has_value() != rhs.memoryLayout.has_value()) {
+      return false;
+    }
+
+    if (memoryLayout.has_value() && rhs.memoryLayout.has_value()) {
+      if (memoryLayout.value() != rhs.memoryLayout.value()) {
+        return false;
+      }
+    }
+
+    if (dataType.has_value() != rhs.dataType.has_value()) {
+      return false;
+    }
+
+    if (dataType.has_value() && rhs.dataType.has_value()) {
+      if (dataType.value() != rhs.dataType.value()) {
+        return false;
+      }
+    }
+
+    return true;
   }

   bool operator!=(const OutputLayoutOverrideParams &rhs) const {
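Two hedged notes on this struct. First, std::optional's operator== already returns false when exactly one side holds a value, and llvm::SmallVector compares element-wise, so the field-by-field comparison above could collapse to a single return. Second, under the parser behavior described in TTNNPipelines.h, a partial override populates only the named fields. A sketch of both (the enum spelling is an assumption):

    // Equivalent comparison via std::optional / SmallVector operator==:
    // return grid == rhs.grid && bufferType == rhs.bufferType &&
    //        tensorMemoryLayout == rhs.tensorMemoryLayout &&
    //        memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;

    // Assumed parse result of the partial override "op1=2x2:block_sharded":
    OutputLayoutOverrideParams p;
    p.grid = SmallVector<int64_t, 2>{2, 2};
    p.tensorMemoryLayout = TensorMemoryLayout::BlockSharded;
    assert(!p.fullLayoutOverride()); // bufferType and memoryLayout still unset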
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRTTNNAnalysis
-  LegalGridAnalysis.cpp
+  LegalLayoutAnalysis.cpp
   OpConfigAnalysis.cpp
   MemoryLayoutAnalysis.cpp
   L1ChainConfig.cpp
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -309,7 +309,7 @@ void L1InterleavedPolicy::run() {
 }

 bool L1InterleavedPolicy::isAnalyzable(Operation *op) {
-  // Skip operations that are not analyzed by the LegalGridAnalysis.
+  // Skip operations that are not analyzed by the LegalLayoutAnalysis.
   //
   if (legalLayouts.count(op) > 0) {
     // Skip operations that are filtered out by the MemoryLayoutAnalysis.