[Preprocessing] Fix bug in TD dag matching op #19936

Closed
2 changes: 1 addition & 1 deletion compiler/plugins/target/CUDA/CUDATarget.cpp
@@ -545,7 +545,7 @@ class CUDATargetBackend final : public TargetBackend {
};

// Mark the entry point as a kernel.
setMetadataValueI32("kernel", 1);
llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
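// Note: llvm::CallingConv::PTX_Kernel is LLVM's dedicated calling
// convention for PTX kernel entry points; it supersedes the older
// metadata-based kernel marking removed above.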

// Set the maximum number of threads in the thread block (CTA).
auto exportOp = exportOpMap[funcOp.getName()];
@@ -269,7 +269,8 @@ static LogicalResult convertStreamableCall(StreamableFunc &streamableFunc,
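// Forward the original call's argument/result attributes; tied operands do
// not apply to the shape-calculation call, so an empty ArrayAttr is passed.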
auto calculateCallOp = builder.create<IREE::Util::CallOp>(
callOp.getLoc(), resultDimTypes,
streamableFunc.resultDimsFunc.getLeafReference().getValue(),
callOp.getOperands(), ArrayAttr{});
callOp.getOperands(), /*tied_operands=*/ArrayAttr{},
callOp.getArgAttrsAttr(), callOp.getResAttrsAttr());
llvm::append_range(resultDims, calculateCallOp.getResults());
} else {
// Get the shape dimensions from existing call arguments or tied operands.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -128,6 +128,7 @@ iree_compiler_cc_library(
"LowerUKernelsToCalls.cpp",
"MaterializeEncoding.cpp",
"MaterializeEncodingIntoNop.cpp",
"MaterializeEncodingIntoPadding.cpp",
"MaterializeEncodingPatterns.cpp",
"MaterializeTuningSpecsPass.cpp",
"MemrefCopyToLinalg.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -120,6 +120,7 @@ iree_cc_library(
"LowerUKernelsToCalls.cpp"
"MaterializeEncoding.cpp"
"MaterializeEncodingIntoNop.cpp"
"MaterializeEncodingIntoPadding.cpp"
"MaterializeEncodingPatterns.cpp"
"MaterializeTuningSpecsPass.cpp"
"MemrefCopyToLinalg.cpp"
20 changes: 15 additions & 5 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -8,12 +8,11 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"

#include <numeric>
#include <optional>

namespace mlir::iree_compiler {

@@ -101,14 +100,25 @@ getEncodingInfoFromLayouts(RankedTensorType type) {
if (!encodingAttr) {
return std::nullopt;
}
auto layoutsAttr = encodingAttr.getLayouts();
ArrayAttr layoutsAttr = encodingAttr.getLayouts();
if (!layoutsAttr) {
return std::nullopt;
}
ArrayRef<Attribute> layouts = layoutsAttr.getValue();
assert(layouts.size() == 1 && "only single layout is supported");
return cast<IREE::Codegen::LayoutAttrInterface>(layouts[0])
.getEncodingInfo(type);
if (auto layout = dyn_cast<IREE::Codegen::LayoutAttrInterface>(layouts[0])) {
return layout.getEncodingInfo(type);
}
return std::nullopt;
}

bool isNonZeroPadding(IREE::Encoding::PadEncodingLayoutAttr padLayout) {
if (!padLayout) {
return false;
}

return !llvm::all_of(padLayout.getPadding().asArrayRef(),
[](int32_t padValue) { return padValue == 0; });
}

} // namespace mlir::iree_compiler
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -97,13 +97,17 @@ FailureOr<Value> lowerUnsetEncodingToUnpackOp(
Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn);

/// Pouplates the set of patterns that lowers operations with encoding types to
/// Populates the set of patterns that lowers operations with encoding types to
/// operations without encodings.
void populateMaterializeEncodingPatterns(
RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn);

/// Returns true when `padLayout` adds non-zero padding to at least one
/// dimension.
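/// For example, a pad layout with padding [0, 0] reports no padding, while
/// [0, 2] (two extra elements along the innermost dimension) does.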
bool isNonZeroPadding(IREE::Encoding::PadEncodingLayoutAttr padLayout);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_
@@ -121,7 +121,8 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
operand = builder
.create<transform::IncludeOp>(
loc, anyOpType, symbol,
transform::FailurePropagationMode::Suppress, operand)
transform::FailurePropagationMode::Suppress, operand,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr)
.getResults()
.front();
}
266 changes: 266 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPadding.cpp
@@ -0,0 +1,266 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_MATERIALIZEENCODINGINTOPADDINGPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

using namespace IREE::Encoding;

namespace {

// Returns the pad encoding layout, or nullptr if this is not the only layout or
// if there's no encoding at all.
static PadEncodingLayoutAttr getPadLayout(RankedTensorType type) {
auto encoding =
dyn_cast_or_null<IREE::Encoding::EncodingAttr>(type.getEncoding());
if (!encoding) {
return nullptr;
}
ArrayAttr layouts = encoding.getLayouts();
if (!layouts || layouts.size() != 1) {
return nullptr;
}

return dyn_cast<PadEncodingLayoutAttr>(*layouts.begin());
}

// Returns a padded tensor type (without encoding) for tensor types with a
// non-zero pad encoding layout; for all other tensor types, returns the same
// type with any encoding dropped.
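// For example (illustrative shapes; assumed attribute shorthand), a pad
// layout of [0, 2] maps
//   tensor<64x64xf32, #enc> -> tensor<64x66xf32>
// i.e. each pad amount is added to the corresponding static dimension.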
static RankedTensorType getPaddedType(RankedTensorType type) {
PadEncodingLayoutAttr layout = getPadLayout(type);
if (!isNonZeroPadding(layout)) {
return dropEncoding(type);
}

ArrayRef<int32_t> padding = layout.getPadding().asArrayRef();
auto newShape = llvm::to_vector_of<int64_t>(type.getShape());
for (auto [newDim, padValue] : llvm::zip_equal(newShape, padding)) {
assert((padValue == 0 || !ShapedType::isDynamic(newDim)) &&
"Padding dynamic dims not supported");
newDim += padValue;
}

return RankedTensorType::get(newShape, type.getElementType());
}

static bool hasNonZeroPadding(RankedTensorType type) {
return isNonZeroPadding(getPadLayout(type));
}

struct MaterializePadEncodingTypeConverter final
: MaterializeEncodingTypeConverter {
MaterializePadEncodingTypeConverter(MLIRContext *ctx)
: MaterializeEncodingTypeConverter(
IREE::Codegen::EncodingNopLayoutAttr::get(ctx)) {
addConversion([](RankedTensorType type) -> std::optional<RankedTensorType> {
if (!getPadLayout(type)) {
// Return `nullopt` so that other conversion functions have a chance to
// handle this type.
return std::nullopt;
}
return getPaddedType(type);
});
}
};

/// Pattern to convert `flow.dispatch.tensor.load` operations when
/// materializing the encoding. We extract a smaller tensor from the padded
/// source. This way we do not create partial loads prematurely, which would be
/// difficult to undo later on.
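/// A sketch of the rewrite (illustrative shapes, assuming a [0, 2] pad
/// layout on a 64x64 source):
///
///   %0 = flow.dispatch.tensor.load %src, ...
///       : !flow.dispatch.tensor<readonly:tensor<64x64xf32, #enc>>
///
/// becomes a whole-tensor load of the padded source followed by a slice
/// extraction of the logical values:
///
///   %1 = flow.dispatch.tensor.load %src, ...
///       : !flow.dispatch.tensor<readonly:tensor<64x66xf32>>
///   %2 = tensor.extract_slice %1[0, 0] [64, 64] [1, 1]
///       : tensor<64x66xf32> to tensor<64x64xf32>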
struct MaterializeFlowDispatchTensorLoadOp final
: OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorLoadOp> {
using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern;

LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the load covers the entire
// `!flow.dispatch.tensor` type.
if (!loadOp.isLoadOfWholeSource()) {
return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads");
}

IREE::Flow::DispatchTensorType sourceType = loadOp.getSourceType();
auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType());
if (!hasNonZeroPadding(boundTensorType)) {
// Let the Nop pattern handle this.
return rewriter.notifyMatchFailure(loadOp, "no padding applied");
}

auto &typeConverter =
*getTypeConverter<MaterializePadEncodingTypeConverter>();
auto paddedType =
typeConverter.convertType<RankedTensorType>(boundTensorType);
assert(paddedType != boundTensorType && "Expected conversion with padding");

SmallVector<OpFoldResult> newMixedSizes =
getMixedValues(paddedType.getShape(), loadOp.getSourceDims(), rewriter);

SmallVector<OpFoldResult> newOffsets(newMixedSizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
rewriter.getIndexAttr(1));
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);

Location loc = loadOp.getLoc();
Value newLoad = rewriter.create<IREE::Flow::DispatchTensorLoadOp>(
loc, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes,
newStrides);
auto extractType = RankedTensorType::get(boundTensorType.getShape(),
boundTensorType.getElementType());
SmallVector<OpFoldResult> extractSizes = getMixedValues(
boundTensorType.getShape(), loadOp.getSourceDims(), rewriter);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
loadOp, extractType, newLoad, newOffsets, extractSizes, newStrides);
return success();
}
};

/// Pattern to convert `flow.dispatch.tensor.store` operation when
/// materializing the encoding. We create a larger empty tensor for the
/// destination and insert the value into it. This way we do not create partial
/// stores prematurely, which would be difficult to undo later on.
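/// A sketch of the rewrite (illustrative shapes, assuming a [0, 2] pad
/// layout on a 64x64 target):
///
///   flow.dispatch.tensor.store %v, %dst, ...
///       : tensor<64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x64xf32, #enc>>
///
/// becomes an insertion into a larger empty tensor followed by a
/// whole-tensor store:
///
///   %empty = tensor.empty() : tensor<64x66xf32>
///   %ins = tensor.insert_slice %v into %empty[0, 0] [64, 64] [1, 1]
///       : tensor<64x64xf32> into tensor<64x66xf32>
///   flow.dispatch.tensor.store %ins, %dst, ...
///       : tensor<64x66xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x66xf32>>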
struct MaterializeFlowDispatchTensorStoreOp final
: OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorStoreOp> {
using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern;

LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the store covers the entire
// `!flow.dispatch.tensor` type.
if (!storeOp.isStoreToWholeTarget()) {
return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores");
}

IREE::Flow::DispatchTensorType targetType = storeOp.getTargetType();
auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType());
if (!hasNonZeroPadding(boundTensorType)) {
// Let the Nop pattern handle this.
return rewriter.notifyMatchFailure(storeOp, "no padding applied");
}

auto &typeConverter =
*getTypeConverter<MaterializePadEncodingTypeConverter>();
auto paddedType =
typeConverter.convertType<RankedTensorType>(boundTensorType);
assert(paddedType != boundTensorType && "Expected conversion with padding");

Location loc = storeOp.getLoc();
SmallVector<Value> dynamicResultSizes{storeOp->getOperands()};
Value empty =
rewriter.create<tensor::EmptyOp>(loc, paddedType, dynamicResultSizes);

SmallVector<OpFoldResult> offsets(paddedType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(paddedType.getRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, adaptor.getValue());
Value insertOp = rewriter.create<tensor::InsertSliceOp>(
loc, adaptor.getValue(), empty, offsets, sizes, strides);

SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
paddedType.getShape(), storeOp.getTargetDims(), rewriter);
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);

rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
storeOp, insertOp, adaptor.getTarget(), newDynamicDims, offsets,
newMixedSizes, strides);
return success();
}
};

struct MaterializeEncodingIntoPaddingPass final
: impl::MaterializeEncodingIntoPaddingPassBase<
MaterializeEncodingIntoPaddingPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, tensor::TensorDialect,
IREE::Codegen::IREECodegenDialect>();
}

void runOnOperation() override {
MLIRContext *context = &getContext();
FunctionOpInterface operation = getOperation();

auto materializeEncodingValueFn =
[](RankedTensorType, OpBuilder &,
Location) -> FailureOr<MaterializeEncodingValueInfo> {
return failure();
};

RewritePatternSet materializeEncodingPattern(context);
MaterializePadEncodingTypeConverter typeConverter(context);
MaterializeEncodingConversionTarget target(*context);
populateMaterializeEncodingPatterns(materializeEncodingPattern, target,
typeConverter,
materializeEncodingValueFn);

// The majority of this conversion is based on the 'Nop' materialization,
// with the exception of a few ops that have to account for padding.
// We add custom patterns with much higher priority to run before the
// equivalent 'Nop' patterns.
materializeEncodingPattern.add<MaterializeFlowDispatchTensorLoadOp,
MaterializeFlowDispatchTensorStoreOp>(
context, typeConverter, materializeEncodingValueFn,
PatternBenefit{100});

if (failed(applyPartialConversion(operation, target,
std::move(materializeEncodingPattern)))) {
operation.emitOpError("materialization failed");
return signalPassFailure();
}

// Add patterns to resolve dims ops and cleanups.
{
RewritePatternSet patterns(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
context->getOrLoadDialect<tensor::TensorDialect>()
->getCanonicalizationPatterns(patterns);
// TODO: Drop these when we deprecate partial loads/stores.
IREE::Flow::populateTensorSliceOpWithDispatchTensorOpFoldingPatterns(
patterns, context);
if (failed(applyPatternsGreedily(operation, std::move(patterns)))) {
operation.emitOpError("folding patterns failed");
return signalPassFailure();
}
}
}
};
} // namespace

void addEncodingToPaddingPasses(FunctionLikeNest &passManager) {
passManager.addPass(createMaterializeEncodingIntoPaddingPass)
.addPass(createBufferizeCopyOnlyDispatchesPass)
.addPass(createCanonicalizerPass);
}
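// Example wiring (a sketch): parent pipelines typically wrap an
// OpPassManager in a FunctionLikeNest and then populate it:
//   FunctionLikeNest funcPassManager(modulePassManager);
//   addEncodingToPaddingPasses(funcPassManager);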

} // namespace mlir::iree_compiler
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -50,9 +50,12 @@ void addIREEComprehensiveBufferizePasses(

void addConstantBufferizePasses(OpPassManager &funcPassManager);

/// Populate Encoding to Nop pass and canonicalizer pass to the pipeline
/// Populate Encoding to Nop pass and canonicalizer pass to the pipeline.
void addEncodingToNopPasses(FunctionLikeNest &passManager);

/// Populate Encoding to padding pass and canonicalizer pass to the pipeline.
void addEncodingToPaddingPasses(FunctionLikeNest &passManager);

/// Links nested transform dialect tuning specs named sequences into a single
/// entry point. Returns the new named sequence op (inserted into the `module`)
/// that includes the nested tuning specs, or a null op when no nested named