Skip to content

Commit

Permalink
#2065: Updated all reduce code to handle 0 or 1 cluster axis and clea…
Browse files Browse the repository at this point in the history
…ned up dialect representations of all reduce in ttir and ttnn. Update algorithms for calculating gather and scatter dimensions
  • Loading branch information
tapspatel committed Feb 26, 2025
1 parent f27fb78 commit ca39af3
Show file tree
Hide file tree
Showing 30 changed files with 791 additions and 290 deletions.
8 changes: 2 additions & 6 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1985,12 +1985,8 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I64ElementsAttr:$replica_groups,
SI32Attr:$dim,
OptionalAttr<SI32Attr>:$channel_handle,
UnitAttr:$use_global_device_ids,
TT_ReduceTypeAttr:$reduce_type
);
TT_ReduceTypeAttr:$reduce_type,
UI32Attr:$cluster_axis);

let results = (outs AnyRankedTensor:$result);

Expand Down
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1424,9 +1424,10 @@ def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_split_dim,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);
TT_ReduceTypeAttr:$reduce_type,
SI32Attr:$scatter_dim,
UI32Attr:$cluster_axis,
DefaultValuedAttr<UI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1441,10 +1442,9 @@ def TTNN_AllReduceOp: TTNN_Op<"all_reduce"> {

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_dim,
SI32Attr:$scatter_num,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);
TT_ReduceTypeAttr:$reduce_type,
UI32Attr:$cluster_axis,
DefaultValuedAttr<UI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

Expand Down
5 changes: 3 additions & 2 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ table ReduceScatterOp {
in: tt.target.ttnn.TensorRef;
out: tt.target.ttnn.TensorRef;
device: tt.target.DeviceRef;
scatter_split_dim: uint32;
math_op: uint32;
scatter_dim: int32;
reduce_type: uint32;
cluster_axis: uint32;
num_links: uint32;
}

Expand Down
176 changes: 66 additions & 110 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1506,21 +1506,59 @@ LogicalResult getReduceType(SrcOpTy &srcOp, ReduceType &reduceType) {
return failure();
}

// StableHLO spec.md defines the following channel types for CCL
// (collective communication) ops; this mirrors that spec so incoming
// channel_handle attributes can be validated.
enum StableHLOChannelType {
  // CHANNEL_TYPE_INVALID = 0 : Invalid primitive type to serve as
  // default.
  kChannelTypeInvalid = 0,
  // DEVICE_TO_DEVICE = 1 : A channel for sending data between
  // devices.
  kChannelTypeDeviceToDevice = 1,
  // DEVICE_TO_HOST = 2 : A channel for sending data from the
  // device to the host. Can only be used with a Send operation.
  kChannelTypeDeviceToHost = 2,
  // HOST_TO_DEVICE = 3 : A channel for sending data from the host to
  // the device. Can only be used with a Recv operation.
  kChannelTypeHostToDevice = 3,
};
// Determine which mesh (cluster) axis a collective communication op runs
// on, based on its StableHLO replica_groups attribute.
//
// replica_groups is a 2D attribute: each row lists the device ids that
// perform the collective communication operation with each other. For the
// supported 2D device meshes, devices adjacent along mesh dimension 1
// have consecutive ids, so if the first two ids of a group are
// consecutive the groups span mesh[1] (cluster_axis = 1); otherwise they
// span mesh[0] (cluster_axis = 0). Currently only 2D meshes are
// supported, but this algorithm can be expanded for ND.
//
// ex.
//   mesh = [2, 4]
//   replica_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
//     0 1 2 3
//     4 5 6 7
//   the op happens on (0, 1, 2, 3) and (4, 5, 6, 7), so
//   cluster_axis = 1 (mesh[1])
//
//   mesh = [2, 4]
//   replica_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
//     0 1 2 3
//     4 5 6 7
//   the op happens on (0, 4), (1, 5), (2, 6), (3, 7), so
//   cluster_axis = 0 (mesh[0])
//
// Returns failure() when replica_groups is not a non-empty 2D attribute,
// since no cluster axis can be derived in that case.
static LogicalResult
determineClusterAxis(::mlir::DenseIntElementsAttr replicaGroups,
                     uint32_t &clusterAxis) {
  auto replicaGroupsShape = replicaGroups.getType().getShape();

  // replica_groups must be a non-empty 2D attribute (one row per group).
  // A rank-0/rank-1 attribute means we are not performing the collective
  // communication operation across any device, and an empty attribute has
  // no elements to inspect; bail out before the replicaGroupsShape[1] and
  // iterator accesses below would be out of bounds.
  if (replicaGroupsShape.size() != 2 || replicaGroups.getNumElements() == 0) {
    return failure();
  }

  // Groups of size 1 mean each device performs the collective
  // communication operation against itself (which should be optimized
  // away). We also assume we are only using our constrained mesh types
  // (ie 1x8, 1x32 etc) and cannot have (32x1, 8x1), so that degenerate
  // case falls through to cluster axis 0 below.
  if (replicaGroupsShape[1] != 1) {
    auto firstElementIt = replicaGroups.begin();
    auto secondElementIt = firstElementIt + 1;

    // Consecutive device ids within a group => the group runs along the
    // fastest-varying mesh dimension, mesh[1].
    clusterAxis = (((*firstElementIt) + 1) == *secondElementIt);
    return success();
  }

  // Default to cluster axis 0.
  clusterAxis = 0;
  return success();
}

namespace {
class StableHLOToTTIRAllReduceOpConversionPattern
Expand All @@ -1540,64 +1578,31 @@ class StableHLOToTTIRAllReduceOpConversionPattern
return err;
}

IntegerAttr channelHandleAttr;
if (auto srcChannelHandleAttr = adaptor.getChannelHandleAttr()) {
// channelType is supposed to be DEVICE_TO_DEVICE or Invalid for CCL ops.
// Currently, we verify that it is DEVICE_TO_DEVICE communication.
// Consider preserving this information in the future if the attribute
// is non-DEVICE_TO_DEVICE values.
auto channelType = static_cast<int32_t>(srcChannelHandleAttr.getType());
if (channelType != kChannelTypeDeviceToDevice &&
channelType != kChannelTypeInvalid) {
return failure();
}

channelHandleAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(srcChannelHandleAttr.getHandle()));
// Determine cluster axis based on replica groups
uint32_t clusterAxis;
if (failed(determineClusterAxis(adaptor.getReplicaGroups(), clusterAxis))) {
return rewriter.notifyMatchFailure(
srcOp, "AllReduceOp cannot specify cluster axis.");
}
mlir::DenseIntElementsAttr replicaGroupsAttr =
adaptor.getReplicaGroupsAttr();
bool useGlobalDeviceIds = adaptor.getUseGlobalDeviceIds();

// Parse computation in region and add it to ttirAttrs
// Convert reduceType shlo attribute into ttir attribute
ReduceType reduceType;
if (failed(getReduceType(srcOp, reduceType))) {
return rewriter.notifyMatchFailure(
srcOp, "AllReduceOp cannot specify reduce type.");
}

// stablehlo all_reduce op has no dimension defined in the op. Thus, we
// estimate possible all reduce dimension. Current algorithm is to search
// for first non-one dimension of input tensor from back.
auto estimateDim = [](mlir::RankedTensorType inputType) -> int32_t {
if (inputType.getRank() == 1) {
return 0;
}
auto inputShape = inputType.getShape();
auto nonOneIt = std::find_if(inputShape.rbegin(), inputShape.rend(),
[](int64_t s) { return s != 1; });
int32_t dim = inputType.getRank() - 1 -
std::distance(inputShape.rbegin(), nonOneIt);
// all one shape, then select the deepest dim
if (dim < 0) {
dim = inputType.getRank() - 1;
}
return dim;
};

// Handle variadic input/output pairs by creating multiple AllReduceOps.
llvm::SmallVector<mlir::Value> allReduceOpResults;
for (auto [inputOperand, resultOperand] :
llvm::zip_equal(adaptor.getOperands(), srcOp->getResults())) {
auto inputType = mlir::cast<RankedTensorType>(inputOperand.getType());
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(resultOperand.getType()));

auto allReduceOp =
ttmlir::utils::createDPSOp<mlir::tt::ttir::AllReduceOp>(
rewriter, srcOp.getLoc(), outputType, inputOperand,
replicaGroupsAttr, estimateDim(inputType), channelHandleAttr,
useGlobalDeviceIds, reduceType);
rewriter, srcOp.getLoc(), outputType, inputOperand, reduceType,
clusterAxis);

allReduceOpResults.push_back(allReduceOp.getResult());
}
Expand Down Expand Up @@ -1645,60 +1650,11 @@ class StableHLOToTTIRAllGatherOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

SmallVector<Type> ttirTypes;
if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(),
ttirTypes))) {
return failure();
}

auto ttirOperands = srcOp.getOperandsMutable();
ttirOperands.append(ValueRange(outputTensor));

/*
We need to figure out what the cluster axis is based on replica_groups.
Replica groups define which device axis we are performing all_gather on.
It is a 2D vector. Each element in replica_groups contains a list of devices
that will perform all_gather with each other. Currently we only support 2D
meshes, but this algorithm can be expanded for ND.
ex.
mesh = [2, 4]
replica_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
0 1 2 3
4 5 6 7
all_gather happens on (0, 1, 2, 3) and (4, 5, 6, 7) so cluster_axis = 1
(mesh[1])
mesh = [2, 4]
replica_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
0 1 2 3
4 5 6 7
all_gather happens on (0, 4), (1, 5), (2, 6), (3, 7) so cluster_axis = 0
(mesh[0])
*/

uint32_t clusterAxis = 0;
auto replicaGroups = adaptor.getReplicaGroups();
auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();

if (replicaGroupsShape.size() == 0) {
// Cannot have replicas of size 0, this means we are not performing the
// all_gather across any device.
return failure();
}

// Case where we have single devices in each replica_group (ie perform
// all_gather against itself which should be optimized away).
// We also assume we are only using our constrained mesh types (ie 1x8, 1x32
// etc) and cannot have (32x1, 8x1).
if (replicaGroupsShape[1] != 1) {
auto firstElementIt = replicaGroups.begin();
auto secondElementIt = firstElementIt + 1;

clusterAxis = (((*firstElementIt) + 1) == *secondElementIt);
// Determine cluster axis based on replica groups
uint32_t clusterAxis;
if (failed(determineClusterAxis(adaptor.getReplicaGroups(), clusterAxis))) {
return rewriter.notifyMatchFailure(
srcOp, "AllGather cannot specify cluster axis.");
}

rewriter.replaceOpWithNewOp<mlir::tt::ttir::AllGatherOp>(
Expand Down
12 changes: 3 additions & 9 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,18 +1340,12 @@ class AllReduceOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllReduceOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
size_t scatter_dim = adaptor.getDim();
// scatter_num is needed when determining the output shape of workaround
// pass of reduce_scatter output and all_gather input
int32_t scatter_num =
replicaGroupsShape[scatter_dim % replicaGroupsShape.size()];
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, srcOp);

rewriter.replaceOpWithNewOp<ttnn::AllReduceOp>(
srcOp, this->getTypeConverter()->convertType(srcOp.getType()),
adaptor.getInput(), device, scatter_dim, scatter_num,
adaptor.getReduceType());
adaptor.getInput(), device, adaptor.getReduceType(),
adaptor.getClusterAxis());

return success();
}
Expand Down
15 changes: 10 additions & 5 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2102,11 +2102,16 @@ ::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() {

// AllReduceOp verification
::mlir::LogicalResult mlir::tt::ttir::AllReduceOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getDim();

if (dim >= inputType.getRank()) {
return emitOpError("Invalid dimension for all_reduce op.");
::mlir::tt::ReduceType reduceType = getReduceType();

// Currently TTIR only supports the following reduce types.
if (reduceType != ::mlir::tt::ReduceType::Sum &&
reduceType != ::mlir::tt::ReduceType::Mean &&
reduceType != ::mlir::tt::ReduceType::Max &&
reduceType != ::mlir::tt::ReduceType::Min &&
reduceType != ::mlir::tt::ReduceType::Std &&
reduceType != ::mlir::tt::ReduceType::Var) {
return emitOpError("Invalid reduction op for all reduce op.");
}

return success();
Expand Down
55 changes: 29 additions & 26 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,10 @@ ::mlir::LogicalResult AllGatherOp::verify() {
int32_t gatherDim = getAllGatherDim();

if (gatherDim >= inputType.getRank() || gatherDim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all gather op.");
return emitOpError("Invalid gather dimension for all reduce op. Gather "
"dimension must be >= to input tensor rank or < -input "
"tensor rank, got gather_dim = ")
<< gatherDim;
}

return success();
Expand All @@ -1519,19 +1522,23 @@ ::mlir::LogicalResult AllGatherOp::verify() {

::mlir::LogicalResult ReduceScatterOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t scatterSplitDim = getScatterSplitDim();
auto mathOp = getMathOp();

if (scatterSplitDim >= inputType.getRank() ||
scatterSplitDim < -inputType.getRank()) {
return emitOpError("Invalid dimension for reduce scatter op.");
}

// Check reduction op that we currently support in tt_nn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for reduce scatter op.");
int32_t scatterDim = getScatterDim();
::mlir::tt::ReduceType reduceType = getReduceType();

if (scatterDim >= inputType.getRank() || scatterDim < -inputType.getRank()) {
return emitOpError("Invalid scatter dimension for all reduce op. Scatter "
"dimension must be >= to input tensor rank or < -input "
"tensor rank, got scatter_dim = ")
<< scatterDim;
}

// Currently TTNN only supports the following reduce types. Compiler is able
// to model the full ReduceType list but only the following can be lowered
// into TTNN.
if (reduceType != ::mlir::tt::ReduceType::Sum &&
reduceType != ::mlir::tt::ReduceType::Max &&
reduceType != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for all reduce op.");
}

return success();
Expand All @@ -1542,18 +1549,14 @@ ::mlir::LogicalResult ReduceScatterOp::verify() {
//===----------------------------------------------------------------------===//

::mlir::LogicalResult AllReduceOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getScatterDim();
auto mathOp = getMathOp();

if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all reduce op.");
}

// Check reduction op that we currently support in tt_nn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
::mlir::tt::ReduceType reduceType = getReduceType();

// Currently TTNN only supports the following reduce types. Compiler is able
// to model the full ReduceType list but only the following can be lowered
// into TTNN.
if (reduceType != ::mlir::tt::ReduceType::Sum &&
reduceType != ::mlir::tt::ReduceType::Max &&
reduceType != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for all reduce op.");
}

Expand Down
Loading

0 comments on commit ca39af3

Please sign in to comment.