# Rewriting elementwise binary ops in TTNN dialect to be non-dps (#2233)
### Ticket
Closes #2231 

### Problem description
We have decided not to use DPS (destination-passing style) in the TTNN dialect
until we can model it properly.

### What's changed
This PR rewrites eltwise ops in the TTNN dialect in non-DPS style.

Note: Most of the changes are rewrites of existing tests; hence, don't
fear the number of changed files. :)
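
At the IR level the rewrite looks roughly as follows; this is a minimal sketch, with SSA names and tensor shapes invented for illustration:

```mlir
// Before (DPS): %out is a pre-allocated destination threaded through the op.
%sum = "ttnn.add"(%lhs, %rhs, %out)
       : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>)
       -> tensor<32x32xbf16>

// After (non-DPS): only the real inputs remain; the op produces its result.
%sum = "ttnn.add"(%lhs, %rhs)
       : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
```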

### Checklist
- [x] New/Existing tests provide coverage for changes
sdjordjevicTT authored Feb 24, 2025
1 parent a575fee commit 91dd3d6
Showing 167 changed files with 640 additions and 745 deletions.
42 changes: 25 additions & 17 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
```diff
@@ -127,57 +127,68 @@ class TTNN_NamedDPSOp<string mnemonic, list<Trait> traits = []> :
 }
 
 class TTNN_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
-    TTNN_NamedDPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments])> {
+    TTNN_Op<mnemonic, traits> {
 
-  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       Variadic<AnyRankedTensor>:$outputs);
+  let arguments = (ins Variadic<AnyRankedTensor>:$inputs);
   let results = (outs Variadic<AnyRankedTensor>:$results);
 }
 
 class TTNN_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    TTNN_ElementwiseOp<mnemonic, traits> {
+    TTNN_ElementwiseOp<mnemonic, !listconcat([OneOperand], traits)> {
   let summary = "Eltwise unary op.";
   let description = [{
     Eltwise unary op.
   }];
 
   let builders =
     [
-      OpBuilder<(ins "Value": $in, "Value": $out),
+      OpBuilder<(ins "Value": $in, "Type": $outputType),
+      [{
+        build($_builder, $_state, {outputType}, in);
+      }]>,
+      OpBuilder<(ins "Value": $in),
       [{
-        build($_builder, $_state, {out.getType()}, in, out);
+        build($_builder, $_state, in, in.getType());
       }]>
     ];
 }
 
 class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    TTNN_ElementwiseOp<mnemonic, traits> {
+    TTNN_ElementwiseOp<mnemonic, !listconcat([TwoOperands], traits)> {
   let summary = "Eltwise binary op.";
   let description = [{
     Eltwise binary op.
   }];
 
   let builders =
     [
-      OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out),
+      OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Type": $outputType),
+      [{
+        build($_builder, $_state, {outputType}, {lhs, rhs});
+      }]>,
+      OpBuilder<(ins "Value": $lhs, "Value": $rhs),
       [{
-        build($_builder, $_state, {out.getType()}, {lhs, rhs}, out);
+        build($_builder, $_state, lhs, rhs, lhs.getType());
       }]>
     ];
 }
 
 class TTNN_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
-    TTNN_ElementwiseOp<mnemonic, traits> {
+    TTNN_ElementwiseOp<mnemonic, !listconcat([ThreeOperands], traits)> {
   let summary = "Eltwise ternary op.";
   let description = [{
     Eltwise ternary op.
   }];
 
   let builders =
     [
-      OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
+      OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Type": $outputType),
       [{
-        build($_builder, $_state, {out.getType()}, {first, second, third}, out);
+        build($_builder, $_state, {outputType}, {first, second, third});
+      }]>,
+      OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third),
+      [{
+        build($_builder, $_state, first, second, third, first.getType());
       }]>
     ];
 }
@@ -189,8 +200,6 @@ def TTNN_WhereOp : TTNN_ElementwiseTernaryOp<"where"> {
   }];
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
-
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       ::mlir::Operation::operand_range inputs = getInputs();
       return
@@ -391,14 +400,13 @@ class TTNN_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
   }];
 
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       Variadic<AnyRankedTensor>:$outputs,
                        F32Attr:$parameter);
 
   let builders =
     [
-      OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter),
+      OpBuilder<(ins "Value": $in, "FloatAttr":$parameter),
       [{
-        build($_builder, $_state, {out.getType()}, {in}, {out}, parameter);
+        build($_builder, $_state, {in.getType()}, {in}, parameter);
       }]>
     ];
 }
```
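With the builders above, conversion code constructs an eltwise op from its input values plus an explicit result type instead of threading a destination tensor through. A minimal usage sketch, assuming `rewriter`, `loc`, `lhs`, `rhs`, and `resultType` are already in scope:

```cpp
// First builder: pass the result type explicitly.
ttnn::AddOp add = rewriter.create<ttnn::AddOp>(loc, lhs, rhs, resultType);

// Second builder: the result type defaults to lhs.getType().
ttnn::AddOp add2 = rewriter.create<ttnn::AddOp>(loc, lhs, rhs);
```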
18 changes: 18 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNTraits.h
```diff
@@ -70,6 +70,24 @@ class HasMemoryConfigTrait
     return mlir::success();
   }
 };
+
+// Trait to verify that operations have exactly N operands.
+template <unsigned N>
+class NOperandTTNN {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public mlir::OpTrait::TraitBase<ConcreteType, NOperandTTNN<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (op->getNumOperands() != N) {
+        return op->emitOpError() << "Operation " << op->getName()
+                                 << " must have exactly " << N << " operands.";
+      }
+      return success();
+    }
+  };
+};
 
 } // namespace mlir::tt::ttnn
 
 #endif
```
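Operand arity, previously implied by the DPS operand structure, is now enforced by this verifier. For example, a malformed binary op like the following hypothetical IR would be rejected, with a diagnostic along these lines:

```mlir
// Hypothetical invalid IR: a binary eltwise op given a single operand.
%bad = "ttnn.add"(%lhs) : (tensor<32x32xbf16>) -> tensor<32x32xbf16>
// error: 'ttnn.add' op Operation ttnn.add must have exactly 2 operands.
```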
22 changes: 22 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNTraits.td
```diff
@@ -7,10 +7,32 @@
 
 include "mlir/IR/OpBase.td"
 
+//===----------------------------------------------------------------------===//
+// TTNN traits definition.
+//===----------------------------------------------------------------------===//
+
+// Trait for ops that have a memory config attribute.
 def HasMemoryConfigTrait : NativeOpTrait<"HasMemoryConfigTrait">
 {
   let cppNamespace = "mlir::tt::ttnn";
 }
 
+// Traits for ops with variadic operands that fix the exact number of operands.
+//
+// Trait for ops with one operand.
+def OneOperand : ParamNativeOpTrait<"NOperandTTNN", "1">
+{
+  let cppNamespace = "mlir::tt::ttnn";
+}
+// Trait for ops with two operands.
+def TwoOperands : ParamNativeOpTrait<"NOperandTTNN", "2">
+{
+  let cppNamespace = "mlir::tt::ttnn";
+}
+// Trait for ops with three operands.
+def ThreeOperands : ParamNativeOpTrait<"NOperandTTNN", "3">
+{
+  let cppNamespace = "mlir::tt::ttnn";
+}
+
 #endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNTRAITS_TD
```
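For reference, `ParamNativeOpTrait<"NOperandTTNN", "2">` expands to the C++ trait `mlir::tt::ttnn::NOperandTTNN<2>::Impl` defined in TTNNTraits.h above, so attaching `TwoOperands` (via the `!listconcat` in `TTNN_ElementwiseBinaryOp`) is what wires the operand-count verifier into every binary eltwise op.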
18 changes: 8 additions & 10 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
```diff
@@ -302,8 +302,7 @@ class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
       return failure();
     }
 
-    rewriter.replaceOpWithNewOp<TTNNOpTy>(op, resultTypes, adaptor.getInputs(),
-                                          adaptor.getOutputs());
+    rewriter.replaceOpWithNewOp<TTNNOpTy>(op, resultTypes, adaptor.getInputs());
     return success();
   }
 };
@@ -633,7 +632,7 @@ class ElementwiseUnaryWithFloatParameterOpConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<TTNNOpTy>(
         op, this->getTypeConverter()->convertType(op.getType(0)),
-        adaptor.getInputs(), adaptor.getOutputs(), adaptor.getParameter());
+        adaptor.getInputs(), adaptor.getParameter());
     return success();
   }
 };
@@ -1272,8 +1271,7 @@ class SubtractOpConversionPattern
 
     if (lhsType.getShape() == rhsType.getShape()) {
       rewriter.replaceOpWithNewOp<ttnn::SubtractOp>(
-          srcOp, adaptor.getInputs().front(), adaptor.getInputs().back(),
-          adaptor.getOutputs().front());
+          srcOp, adaptor.getInputs().front(), adaptor.getInputs().back());
 
       // Broadcast for rhs operand require the operation to be commutative to
       // allow switching the order of operands. To allow this conversion, the
@@ -1282,11 +1280,10 @@
 
     } else {
       ttnn::NegOp negOp = ttmlir::utils::createDPSOp<ttnn::NegOp>(
-          rewriter, srcOp.getLoc(), rhsType, adaptor.getInputs().back());
+          rewriter, srcOp.getLoc(), rhsType);
 
       rewriter.replaceOpWithNewOp<ttnn::AddOp>(
-          srcOp, adaptor.getInputs().front(), negOp.getResults().front(),
-          adaptor.getOutputs().front());
+          srcOp, adaptor.getInputs().front(), negOp.getResults().front());
     }
 
     return success();
@@ -1415,8 +1412,9 @@ class ScatterOpConversionPattern : public OpConversionPattern<ttir::ScatterOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // The ttnn interface has the inverse inputs of the TTIR dialect op (which
     // matches torch ops).
-    rewriter.replaceOpWithNewOp<ttnn::ScatterOp>(
-        op, adaptor.getUpdate(), adaptor.getInput(), adaptor.getOutput());
+    rewriter.replaceOpWithNewOp<ttnn::ScatterOp>(op, adaptor.getUpdate(),
+                                                 adaptor.getInput(),
+                                                 adaptor.getOutput().getType());
 
     return success();
   }
```
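Note the ScatterOp pattern: where the old code forwarded `adaptor.getOutput()` as a destination operand, the new code passes only `adaptor.getOutput().getType()`, so the shape and layout information is preserved as the result type while no destination buffer is threaded through the op.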
5 changes: 1 addition & 4 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
```diff
@@ -136,7 +136,6 @@ class EltwiseUnaryOpConversionPattern
     llvm::SmallVector<Attribute, 5> attrs;
     attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0));
     attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter));
-    attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1));
 
     ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs);
 
@@ -173,8 +172,7 @@ class EltwiseUnaryWithFastAndApproximateModeOpConversionPattern
         {mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
          ttnn_to_emitc::utils::convertBoolAttr(
              rewriter, BoolAttr::get(rewriter.getContext(), false)),
-         ttnn_to_emitc::utils::createStdNullopt(rewriter),
-         mlir::IntegerAttr::get(rewriter.getIndexType(), 1)});
+         ttnn_to_emitc::utils::createStdNullopt(rewriter)});
 
     rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
         srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)),
@@ -243,7 +241,6 @@ class EltwiseBinaryOpConversionPattern
     attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1));
     attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter));
     attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter));
-    attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 2));
 
     ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs);
 
```
10 changes: 6 additions & 4 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
```diff
@@ -531,10 +531,12 @@ bool ShardSolver::checkShardCompatible(
 
   uint32_t numOperands = consumerOp->getNumOperands();
 
-  // Some ops have multiple operands; and some ops have output also an
-  // operand. TBD if there is a more robust way to get real number of inputs.
-  // TODO(odjuricic): cast to DPSop?
-  numOperands = (numOperands > 1) ? numOperands - 1 : numOperands;
+  // DPS ops have an additional operand for the destination, hence we need to
+  // subtract it from the total number of operands.
+  if (llvm::isa<DestinationStyleOpInterface>(consumerOp)) {
+    numOperands = numOperands - 1;
+  }
 
   std::vector<TTNNLayoutAttr> inputLayouts;
 
   auto inputUnderCheck =
```
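The old heuristic subtracted one operand from every multi-operand op, which would over-trim genuine multi-input ops now that eltwise ops no longer carry a destination. Gating the subtraction on `DestinationStyleOpInterface` means only ops that still declare a DPS destination lose an operand from the count; the non-DPS eltwise ops are counted as-is.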
6 changes: 2 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
```diff
@@ -36,8 +36,7 @@ ReluOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
   assert(inputs.size() == 1);
 
   const auto inputShape =
-      mlir::cast<RankedTensorType>(getDpsInputOperand(0)->get().getType())
-          .getShape();
+      mlir::cast<RankedTensorType>(getOperand(0).getType()).getShape();
 
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResults().front().getType()).getShape();
@@ -58,8 +57,7 @@ ReluOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
   assert(inputs.size() == 1);
 
   const auto inputShape =
-      mlir::cast<RankedTensorType>(getDpsInputOperand(0)->get().getType())
-          .getShape();
+      mlir::cast<RankedTensorType>(getOperand(0).getType()).getShape();
 
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResults().front().getType()).getShape();
```
1 change: 0 additions & 1 deletion lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
```diff
@@ -355,7 +355,6 @@ TTNNOperandsWorkaroundsFactory::createWhereOpOperandsWorkarounds(
       .addInputOperandWorkaround(typeWorkaround)
       .addInputOperandWorkaround(typeWorkaround)
      .addInputOperandWorkaround(typeWorkaround)
-      .addInputOperandWorkaround(typeWorkaround)
       .addOutputOperandWorkaround(typeWorkaround);
 }
 } // namespace mlir::tt::ttnn::wa
```
27 changes: 20 additions & 7 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
```diff
@@ -8,6 +8,7 @@
 #include "ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h"
 #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"
 #include "ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
@@ -377,7 +378,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
   void extractReshardEdges(ModuleOp &moduleOp,
                            std::unordered_set<Edge> &overrideReshardEdges) {
     moduleOp->walk([&](Operation *op) {
-      if (!isa<DestinationStyleOpInterface>(op)) {
+      if (isa<ToLayoutOp>(op)) {
         return;
       }
 
@@ -409,16 +410,28 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
   }
 
   mlir::TypedValue<mlir::tt::DeviceType>
-  getDeviceOpValue(Operation *contextOp) {
+  getOrCreateDeviceOpValue(Operation *contextOp, OpBuilder &builder) {
     Block *block = contextOp->getBlock();
-    mlir::TypedValue<mlir::tt::DeviceType> deviceOpResult;
     for (auto &op : block->getOperations()) {
       if (GetDeviceOp deviceOp = dyn_cast<GetDeviceOp>(op)) {
-        deviceOpResult = deviceOp.getResult();
-        break;
+        return deviceOp.getResult();
       }
     }
-    return deviceOpResult;
+
+    // Device op does not exist in the block, hence we need to create it.
+    DeviceAttr deviceAttr = getCurrentScopeDevice(contextOp);
+    auto currentInsertionPoint = builder.saveInsertionPoint();
+    builder.setInsertionPoint(block, block->begin());
+    llvm::SmallVector<int64_t> meshShape{deviceAttr.getMeshShape()};
+    if (meshShape.empty()) {
+      meshShape = llvm::SmallVector<int64_t, 2>{1, 1};
+    }
+    auto deviceOp = builder.create<ttnn::GetDeviceOp>(
+        contextOp->getLoc(), builder.getType<DeviceType>(deviceAttr),
+        ttnn::MeshShapeAttr::get(contextOp->getContext(), meshShape[0],
+                                 meshShape[1]));
+    builder.restoreInsertionPoint(currentInsertionPoint);
+    return deviceOp;
   }
 
   void
@@ -492,7 +505,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
           Operation *memoryReconfigOp = builder.create<ToLayoutOp>(
               consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
               outputLayout, outputDataType, outputMemConfigAttr,
-              getDeviceOpValue(consumerOp));
+              getOrCreateDeviceOpValue(consumerOp, builder));
 
           consumerOp->setOperand(edge.operandIndex,
                                  memoryReconfigOp->getResult(0));
```
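Two details in `getOrCreateDeviceOpValue` are worth calling out: the newly created `GetDeviceOp` is inserted at the top of the block so that its result dominates any use the optimizer may rewire to it, and when the device attribute carries no mesh shape the code falls back to a 1x1 mesh, presumably the single-device default.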