[Stream] Enable batch affinity queries in SpecializeEncoding pass. (iree-org#19975)

The returned function (i.e., `ResolveLayoutAttrFn`) can be very
inefficient because each call may trigger additional data-flow analysis
runs. This revision updates the `ResolveLayoutAttrFn` API: it now
accepts a list of queries and stores the results in a map of
`SetVector<Attribute>`.
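
As a sketch (not part of the patch itself), a caller of the batched API now looks roughly like this; `resolveAllLayouts` and its shape are hypothetical, while `AffinityAndOpPair` and `ResolveLayoutAttrFn` come from the diff below:

```cpp
// Hypothetical caller of the batched API: gather all (affinity, op) pairs,
// resolve them in a single call, then read the per-query results.
static LogicalResult resolveAllLayouts(
    ArrayRef<IREE::Stream::AffinityAndOpPair> queries,
    IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
  llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
      layoutAttrs;
  // Any expensive data-flow analysis behind the resolver runs once for the
  // whole batch instead of once per query.
  if (failed(resolveLayoutAttr(queries, layoutAttrs))) {
    return failure();
  }
  for (const IREE::Stream::AffinityAndOpPair &key : queries) {
    for (Attribute layout : layoutAttrs[key]) {
      (void)layout; // ... consume the resolved layout attributes ...
    }
  }
  return success();
}
```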

The encoding specialization pass introduces a `StreamTensorOpUpdater`
class that works in two phases: it caches all the queries in `init()`
and updates all the encodings in `run()`. A separate `init` method is
used because initialization can fail, and constructors cannot signal
errors. See
https://google.github.io/styleguide/cppguide.html#Doing_Work_in_Constructors
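
The resulting call site in `runOnOperation()` (shown in the diff below) follows the usual two-phase shape:

```cpp
// Two-phase protocol: init() may fail, which a constructor cannot signal;
// run() performs the batched resolution and rewrites the encodings.
StreamTensorOpUpdater streamTensorOpUpdater(moduleOp);
if (failed(streamTensorOpUpdater.init())) {
  return signalPassFailure();
}
if (failed(streamTensorOpUpdater.run())) {
  return signalPassFailure();
}
```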

The pass gets a 440x speed-up for one of the SDXL compilations.

The lit test configuration change (i.e., running the pass via
`--pass-pipeline='builtin.module(iree-stream-specialize-encodings)'`) is
needed so that failures for unsupported encodings can be validated.
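
For example, a RUN line under the new configuration might look like the following; the exact flags are an assumption, since the test file is not shown on this page:

```
// RUN: iree-opt --split-input-file \
// RUN:   --pass-pipeline='builtin.module(iree-stream-specialize-encodings)' \
// RUN:   --verify-diagnostics %s | FileCheck %s
```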

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
hanhanW committed Feb 13, 2025
1 parent c4a2c4d commit d472796
Showing 4 changed files with 238 additions and 112 deletions.
45 changes: 29 additions & 16 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -123,30 +123,43 @@ class HALAffinityAnalysisDialectInterface
     : public IREE::Stream::AffinityAnalysisDialectInterface {
  public:
   using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface;
+
+  // Returns a function that gathers the corresponding
+  // EncodingLayoutAttrInterface attributes for each
+  // (IREE::Stream::Affinity, Operation) query. The attribute is extracted
+  // from the `encoding` field in the HAL::ExecutableTargetAttr configuration.
+  // If the `encoding` field is not present, the target attribute itself is
+  // returned.
   IREE::Stream::ResolveLayoutAttrFn
   makeLayoutAttrResolver(ModuleOp moduleOp) const {
-    return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
-               SetVector<Attribute> &layoutAttrs) -> LogicalResult {
-      // This needs to be in the lambda because the moduleOp could be modified..
+    return [=](ArrayRef<IREE::Stream::AffinityAndOpPair> batchQueries,
+               llvm::DenseMap<IREE::Stream::AffinityAndOpPair,
+                              SetVector<Attribute>> &layoutAttrs)
+               -> LogicalResult {
+      // This needs to be in the lambda because the moduleOp could be modified.
       IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
       if (failed(deviceAnalysis.run())) {
-        return op->emitError("failed to run DeviceAnalysis");
+        return moduleOp->emitError("failed to run DeviceAnalysis");
       }
-      SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
-      deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
-                                                     resultSet);
-      for (auto targetAttr : resultSet) {
-        Attribute result = targetAttr;
-        if (auto attr = targetAttr.getConfiguration().getNamed("encoding")) {
-          if (auto encodingLayoutAttr =
-                  dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(
-                      attr->getValue())) {
-            result = encodingLayoutAttr.cloneWithSimplifiedConfig(
-                targetAttr.getConfiguration());
-          }
-        }
-        layoutAttrs.insert(result);
-      }
+
+      for (IREE::Stream::AffinityAndOpPair key : batchQueries) {
+        auto [affinityAttr, op] = key;
+        SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
+        deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
+                                                       resultSet);
+        for (auto targetAttr : resultSet) {
+          Attribute result = targetAttr;
+          if (auto attr = targetAttr.getConfiguration().getNamed("encoding")) {
+            if (auto encodingLayoutAttr =
+                    dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(
+                        attr->getValue())) {
+              result = encodingLayoutAttr.cloneWithSimplifiedConfig(
+                  targetAttr.getConfiguration());
+            }
+          }
+          layoutAttrs[key].insert(result);
+        }
+      }
+
       return success();
     };
   }
StreamInterfaces.h
@@ -16,8 +16,13 @@

 namespace mlir::iree_compiler::IREE::Stream {

+using AffinityAndOpPair = std::pair<AffinityAttr, Operation *>;
+
+// The function could be slow if any data-flow analysis is involved, so the
+// API provides a batch mode.
 using ResolveLayoutAttrFn = std::function<LogicalResult(
-    AffinityAttr, Operation *, SetVector<Attribute> &)>;
+    ArrayRef<AffinityAndOpPair> batchQueries,
+    llvm::DenseMap<AffinityAndOpPair, SetVector<Attribute>> &layoutAttrs)>;

 class AffinityAnalysisDialectInterface
     : public DialectInterface::Base<AffinityAnalysisDialectInterface> {
SpecializeEncodings pass implementation
@@ -17,6 +17,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
@@ -57,6 +58,8 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
   return results;
 }

+} // namespace
+
 // Returns an updated encoding attribute if the type is a RankedTensorType
 // and an EncodingAttr is present. Otherwise, returns std::nullopt. The
 // method uses the EncodingLayoutAttrInterface from the EncodingAttr to
@@ -319,14 +322,140 @@ static RankedTensorType cloneWithEncoding(RankedTensorType type,
                                           encodingAttr);
 }

+/// Returns all the stream tensor ops that implement AffinityOpInterface,
+/// where a stream affinity indicates the kind of environment the ops are
+/// expected to run in.
+static SmallVector<IREE::Stream::AffinityOpInterface>
+collectStreamTensorOps(FunctionOpInterface funcOp) {
+  SmallVector<IREE::Stream::AffinityOpInterface> result;
+  funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
+    // Only need to update encoding types for ops that have TensorPhaseOp trait.
+    if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
+      return;
+    }
+    // Bail out if the operation does not have an affinity attribute.
+    auto affinityAttr = affinityOp.getAffinityAttr();
+    if (!affinityAttr) {
+      return;
+    }
+    result.push_back(affinityOp);
+  });
+  return result;
+}
+
+namespace {
+
+// Adds the resolved layouts to all tensor types on stream tensor ops, if
+// encodings are present. Most stream tensor ops implement AffinityOpInterface,
+// where a stream affinity indicates the kind of environment the ops are
+// expected to run in. When an encoding is present in the tensor type, the
+// updater resolves the layouts, strips outdated information, and adds the
+// resolved layouts to the encodings. The updated encodings should have enough
+// information for other lowering transformations.
+// TODO(hanchung): Add support for stream.tensor.load ops and
+// stream.tensor.store ops. They are not affinity ops, so additional analysis
+// will be needed for them.
+class StreamTensorOpUpdater {
+ public:
+  explicit StreamTensorOpUpdater(ModuleOp moduleOp) : moduleOp(moduleOp) {}
+  ~StreamTensorOpUpdater() {}
+
+  // Collects the stream tensor op candidates and prepares all the information
+  // needed for the update. This must be called once before calling `run`.
+  // Note that all the ops remain unmodified after it executes.
+  LogicalResult init();
+
+  // Adds the resolved layouts to all tensor types of `streamOps`, if
+  // encodings are present.
+  LogicalResult run();
+
+ private:
+  // Appends the queries for `affinityOp` to `queries`. Note that most
+  // operations only care about the execution affinity. There are outliers
+  // (e.g., the tensor dispatch op) that also need to resolve affinities for
+  // operand resources.
+  LogicalResult addQuery(IREE::Stream::AffinityAnalysis &affinityAnalysis,
+                         IREE::Stream::AffinityOpInterface affinityOp);
+
+  // The list of queries used for batch affinity resolution. The analysis can
+  // be very expensive because it may run whole-program data-flow analysis.
+  SmallVector<IREE::Stream::AffinityAndOpPair> queries;
+
+  // The resolved layouts for each query.
+  llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
+      cachedLayoutAttrs;
+
+  // Input moduleOp. The op is not expected to be updated during the query,
+  // because data-flow analysis can be involved. Modifying the IR invalidates
+  // the analysis state and may lead to crashes, as pointer references into
+  // the IR structure are retained.
+  ModuleOp moduleOp;
+
+  // The ops that need to be updated.
+  SmallVector<IREE::Stream::AffinityOpInterface> streamOps;
+
+  // The layout resolver function, which is used to resolve layouts for
+  // encodings. See StreamInterfaces.h for more details.
+  IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr;
+};
+
+} // namespace
+
+LogicalResult StreamTensorOpUpdater::init() {
+  auto usedDialects = gatherUsedDialectInterfaces<
+      IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
+  if (usedDialects.size() != 1) {
+    return moduleOp.emitError("expected only one dialect implementing "
+                              "AffinityAnalysisDialectInterface");
+  }
+  resolveLayoutAttr = usedDialects[0]->makeLayoutAttrResolver(moduleOp);
+
+  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+    streamOps.append(collectStreamTensorOps(funcOp));
+  }
+
+  return success();
+}
+
+LogicalResult StreamTensorOpUpdater::addQuery(
+    IREE::Stream::AffinityAnalysis &affinityAnalysis,
+    IREE::Stream::AffinityOpInterface affinityOp) {
+  queries.emplace_back(affinityOp.getAffinityAttr(), affinityOp);
+
+  if (auto dispatchOp =
+          dyn_cast<IREE::Stream::TensorDispatchOp>(affinityOp.getOperation())) {
+    for (auto [operand, typeAttr] :
+         llvm::zip_equal(dispatchOp.getMixedOperands(),
+                         dispatchOp.getOperandEncodings().getValue())) {
+      auto type = cast<TypeAttr>(typeAttr).getValue();
+      // Skip if the operand type is not an AffinityType.
+      if (!isa<IREE::Stream::AffinityTypeInterface>(type)) {
+        continue;
+      }
+      SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
+      if (!affinityAnalysis.tryLookupResourceAffinity(operand, affinityAttrs)) {
+        return failure();
+      }
+      for (auto affinity : affinityAttrs) {
+        queries.emplace_back(affinity, affinityOp);
+      }
+    }
+  }
+
+  return success();
+}
+
 /// Updates the operand encodings and result encodings for the `dispatchOp`
 /// with resolved layouts.
-static LogicalResult
-updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
-                       IREE::Stream::AffinityAnalysis &affinityAnalysis,
-                       IREE::Stream::TensorDispatchOp dispatchOp,
-                       const SetVector<Attribute> &resLayoutResolvers,
-                       IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
+static LogicalResult updateTensorDispatchOp(
+    RewriterBase &rewriter, ModuleOp moduleOp,
+    IREE::Stream::AffinityAnalysis &affinityAnalysis,
+    IREE::Stream::TensorDispatchOp dispatchOp,
+    const SetVector<Attribute> &resLayoutResolvers,
+    llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
+        &cachedLayoutAttrs) {
   SmallVector<Type> newOperandEncodings;
   for (auto [operand, typeAttr] :
        llvm::zip_equal(dispatchOp.getMixedOperands(),
@@ -344,11 +473,11 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
     if (affinityAttrs.size() != 1) {
      return failure();
     }
-    SetVector<Attribute> layoutResolvers;
-    if (failed(
-            resolveLayoutAttr(affinityAttrs[0], moduleOp, layoutResolvers))) {
-      return dispatchOp.emitError("failed on making layout resolvers");
-    }
+
+    IREE::Stream::AffinityAndOpPair key(affinityAttrs[0], dispatchOp);
+    assert(cachedLayoutAttrs.contains(key) &&
+           "the (affinity, dispatchOp) query is invalid");
+    const SetVector<Attribute> &layoutResolvers = cachedLayoutAttrs[key];

     std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
         getEncodingWithNewLayouts(type, layoutResolvers);
Expand All @@ -370,7 +499,6 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
newResultEncodings.push_back(type);
continue;
}

std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
getEncodingWithNewLayouts(type, resLayoutResolvers);
if (!encodingAttr) {
@@ -517,53 +645,34 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
   return success();
 }

-/// Adds the resolved layouts to all tensor types on stream tensor ops, if
-/// encodings are present. Most of stream tensor ops implement
-/// AffinityOpInterface, where a stream affinity indicates the kind of
-/// enviroment the ops are expected run in. When an encoding is present in the
-/// tensor type, the method resolves the layouts, strips outdated information,
-/// and adds the resolved layouts to the encodings. The updated encodings should
-/// have enough information for other lowering transformations.
-/// TODO(hanchung): Add support for stream.tensor.load ops and
-/// stream.tensor.store ops. They are not affinity ops, so additional analysis
-/// will be needed in the work.
-static LogicalResult addLayoutsToTensorPhaseOps(
-    ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
-    FunctionOpInterface funcOp,
-    IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
-  SmallVector<IREE::Stream::AffinityOpInterface> candidates;
-  funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
-    // Only need to update encoding types for ops that have TensorPhaseOp trait.
-    if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
-      return;
-    }
-    // Bail out if the operation does not have an affinity attribute.
-    auto affinityAttr = affinityOp.getAffinityAttr();
-    if (!affinityAttr) {
-      return;
-    }
-    candidates.push_back(affinityOp);
-  });
-
-  if (candidates.empty()) {
-    return success();
-  }
-
-  IRRewriter rewriter(funcOp.getContext());
-  for (auto affinityOp : candidates) {
-    auto affinityAttr = affinityOp.getAffinityAttr();
-    SetVector<Attribute> layoutResolvers;
-    if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layoutResolvers))) {
-      return affinityOp.emitError("failed on making layout resolvers");
-    }
+LogicalResult StreamTensorOpUpdater::run() {
+  IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+  if (failed(affinityAnalysis.run())) {
+    return moduleOp.emitError("failed on running affinity analysis");
+  }
+
+  for (auto op : streamOps) {
+    if (failed(addQuery(affinityAnalysis, op))) {
+      return failure();
+    }
+  }
+
+  if (failed(resolveLayoutAttr(queries, cachedLayoutAttrs))) {
+    return failure();
+  }
+
+  IRRewriter rewriter(moduleOp.getContext());
+  for (auto affinityOp : streamOps) {
+    const SetVector<Attribute> &layoutResolvers =
+        cachedLayoutAttrs[IREE::Stream::AffinityAndOpPair(
+            affinityOp.getAffinityAttr(), affinityOp)];
+
     LogicalResult result =
         TypeSwitch<Operation *, LogicalResult>(affinityOp)
             .Case<IREE::Stream::TensorDispatchOp>([&](auto op) {
               return updateTensorDispatchOp(rewriter, moduleOp,
                                             affinityAnalysis, op,
-                                            layoutResolvers, resolveLayoutAttr);
+                                            layoutResolvers, cachedLayoutAttrs);
             })
             .Case<IREE::Stream::TensorSizeOfOp>([&](auto op) {
               return updateTensorSizeOfOp(rewriter, op, layoutResolvers);
@@ -594,36 +703,26 @@ static LogicalResult addLayoutsToTensorPhaseOps(
   }
   return success();
 }
-} // namespace

 namespace {
 struct SpecializeEncodingsPass
     : public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
   void runOnOperation() override {
     ModuleOp moduleOp = getOperation();
-    auto usedDialects = gatherUsedDialectInterfaces<
-        IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
-    if (usedDialects.size() != 1) {
-      moduleOp.emitError("expected only one dialect implementing "
-                         "AffinityAnalysisDialectInterface");
+
+    StreamTensorOpUpdater streamTensorOpUpdater(moduleOp);
+    if (failed(streamTensorOpUpdater.init())) {
+      moduleOp.emitError("failed to initialize StreamTensorOpUpdater");
       return signalPassFailure();
     }

-    IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
-    if (failed(affinityAnalysis.run())) {
-      moduleOp.emitError("failed on running affinity analysis");
+    if (failed(streamTensorOpUpdater.run())) {
+      moduleOp.emitError(
+          "failed to add layouts to Stream::TensorPhaseOp with encodings");
       return signalPassFailure();
     }

     SymbolTable symbolTable(moduleOp);
-    IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
-        usedDialects[0]->makeLayoutAttrResolver(moduleOp);
     for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
-      if (failed(addLayoutsToTensorPhaseOps(moduleOp, affinityAnalysis, funcOp,
-                                            resolveLayoutAttr))) {
-        funcOp.emitError(
-            "failed on adding layouts to Stream::TensorPhaseOp with encodings");
-        return signalPassFailure();
-      }
       if (failed(duplicateExecutablesPerLayoutVariant(moduleOp, symbolTable,
                                                       funcOp))) {
         funcOp.emitError("failed on executable duplication");
@@ -632,5 +731,6 @@ struct SpecializeEncodingsPass
     }
   }
 };
+} // namespace

 } // namespace mlir::iree_compiler::IREE::Stream
[Fourth changed file not loaded in this view: the lit test configuration update described above, which runs the pass via `--pass-pipeline='builtin.module(iree-stream-specialize-encodings)'`.]
