Skip to content

Commit

Permalink
Implement convolution via indexing maps.
Browse files Browse the repository at this point in the history
Reusing the dot emitter code, which is (almost) the same as the convolution code.
Note that adding the tests uncovered an issue with convolution indexing analysis - it was using an incorrect divisor when `feature_group_count` attribute is used.

PiperOrigin-RevId: 621769303
  • Loading branch information
sergeykozub authored and copybara-github committed Apr 4, 2024
1 parent 55cdde9 commit ff468cb
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 28 deletions.
70 changes: 48 additions & 22 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
Expand Down Expand Up @@ -166,8 +167,8 @@ static auto& kUnsupportedOps =
HloOpcode::kStochasticConvert,
HloOpcode::kCall};

static auto& kUnimplementedOps = *new absl::flat_hash_set<HloOpcode>{
HloOpcode::kConvolution, HloOpcode::kMap};
static auto& kUnimplementedOps =
*new absl::flat_hash_set<HloOpcode>{HloOpcode::kMap};

bool IsUnsupportedConstant(const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kConstant &&
Expand Down Expand Up @@ -593,27 +594,10 @@ absl::StatusOr<Value> EmitMulAdd(Value lhs, Value rhs, Value accumulator,
b.create<arith::MulIOp>(lhs, rhs));
}

absl::StatusOr<SmallVector<Value>> EmitDot(
absl::StatusOr<SmallVector<Value>> EmitDotLoop(
const HloInstruction* instr, mlir::Type result_element_type,
ValueRange indices, const OperandProvider& operand_provider,
ImplicitLocOpBuilder& b) {
VLOG(1) << "EmitDot: " << instr->ToString() << " "
<< llvm_ir::DumpToString(result_element_type);

if (!algorithm_util::IsSupportedByElementalIrEmitter(
instr->precision_config().algorithm())) {
return absl::InvalidArgumentError(
absl::StrFormat("Algorithm not supported by the ElementalIrEmitter: %s",
PrecisionConfig::Algorithm_Name(
instr->precision_config().algorithm())));
}
auto* dot = DynCast<HloDotInstruction>(instr);
TF_RET_CHECK(dot != nullptr);
if (dot->sparse_operands()) {
return absl::UnimplementedError(
"Sparse dot is supported by Triton emitter only.");
}

HloInstructionIndexing indexing =
ComputeOutputToInputIndexing(instr, /*output_id=*/0, b.getContext());
const IndexingMap& lhs_indexing_map = *indexing.indexing_maps.at(0).begin();
Expand All @@ -624,13 +608,18 @@ absl::StatusOr<SmallVector<Value>> EmitDot(
Value accum_init_value =
b.create<ConstantOp>(b.getZeroAttr(accumulator_type)).getResult();

// For convolutions with `batch_group_count` > 1, there is an additional
// symbol for LHS (group id) - ignore it for RHS.
size_t rhs_symbol_count = rhs_indexing_map.GetSymbolCount();

auto body =
[&](ValueRange iter_args, ValueRange dim_values,
ValueRange symbol_values) -> absl::StatusOr<SmallVector<Value>> {
llvm::SmallVector<Value> lhs_indices = ApplyAffineMap(
lhs_indexing_map.GetAffineMap(), dim_values, symbol_values, b);
llvm::SmallVector<Value> rhs_indices = ApplyAffineMap(
rhs_indexing_map.GetAffineMap(), dim_values, symbol_values, b);
llvm::SmallVector<Value> rhs_indices =
ApplyAffineMap(rhs_indexing_map.GetAffineMap(), dim_values,
symbol_values.take_front(rhs_symbol_count), b);

TF_ASSIGN_OR_RETURN(Value lhs_value, GetSingleOperandValue(
operand_provider, instr,
Expand All @@ -655,6 +644,40 @@ absl::StatusOr<SmallVector<Value>> EmitDot(
return results;
}

absl::StatusOr<SmallVector<Value>> EmitDot(
const HloInstruction* instr, mlir::Type result_element_type,
ValueRange indices, const OperandProvider& operand_provider,
ImplicitLocOpBuilder& b) {
VLOG(1) << "EmitDot: " << instr->ToString() << " "
<< llvm_ir::DumpToString(result_element_type);

if (!algorithm_util::IsSupportedByElementalIrEmitter(
instr->precision_config().algorithm())) {
return absl::InvalidArgumentError(
absl::StrFormat("Algorithm not supported by the ElementalIrEmitter: %s",
PrecisionConfig::Algorithm_Name(
instr->precision_config().algorithm())));
}
auto* dot = DynCast<HloDotInstruction>(instr);
TF_RET_CHECK(dot != nullptr);
if (dot->sparse_operands()) {
return absl::UnimplementedError(
"Sparse dot is supported by Triton emitter only.");
}

return EmitDotLoop(instr, result_element_type, indices, operand_provider, b);
}

absl::StatusOr<SmallVector<Value>> EmitConvolution(
const HloInstruction* instr, mlir::Type result_element_type,
ValueRange indices, const OperandProvider& operand_provider,
ImplicitLocOpBuilder& b) {
VLOG(1) << "EmitConvolution: " << instr->ToString() << " "
<< llvm_ir::DumpToString(result_element_type);

return EmitDotLoop(instr, result_element_type, indices, operand_provider, b);
}

absl::StatusOr<SmallVector<Value>> EmitParameter(const HloInstruction* instr,
mlir::func::FuncOp this_fn,
ValueRange indices,
Expand Down Expand Up @@ -786,6 +809,9 @@ absl::StatusOr<SmallVector<Value>> HloToMlir(
}
return absl::UnimplementedError(
absl::StrCat("Unimplemented: ", instr->ToShortString()));
case HloOpcode::kConvolution:
return EmitConvolution(instr, result_element_type, indices,
operand_provider, builder);
case HloOpcode::kDynamicSlice:
return EmitDynamicSlice(instr, indices, operand_provider, builder);
case HloOpcode::kDynamicUpdateSlice:
Expand Down
Loading

0 comments on commit ff468cb

Please sign in to comment.