Skip to content

Commit

Permalink
Remove option and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Dec 29, 2024
1 parent 419ee6b commit e27a766
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 65 deletions.
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,10 @@ RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

/// Returns `type` unchanged unless it carries the packed storage attribute, in
/// which case a tensor type with the same shape and element type but no
/// encoding is returned.
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type) {
  bool carriesPackedStorage = IREE::Encoding::hasPackedStorageAttr(type);
  return carriesPackedStorage
             ? RankedTensorType::get(type.getShape(), type.getElementType())
             : type;
}

} // namespace mlir::iree_compiler
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -79,6 +80,9 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Returns the RankedTensorType without encodings.
RankedTensorType dropEncoding(RankedTensorType type);

/// Returns the RankedTensorType without packed storage encoding (if any).
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type);

/// Utility method to convert from `set_encoding` op to `pack` operation.
/// NOTE: `source` could be returned when packing is not needed.
FailureOr<Value> lowerSetEncodingOpToPackOp(
Expand Down
24 changes: 14 additions & 10 deletions compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand Down Expand Up @@ -65,9 +66,8 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType,
/// std::nullopt.
static std::optional<Type> getLegalizedType(Type t) {
if (auto shapedType = llvm::dyn_cast<RankedTensorType>(t)) {
Type elementType = shapedType.getElementType();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(elementType);
legalizeTensorStorageElementType(shapedType);
if (!legalizedElementType)
return std::nullopt;
return RankedTensorType::get(shapedType.getShape(),
Expand Down Expand Up @@ -114,7 +114,7 @@ struct ConstantOpTypeConversion
constantOp, "expected attribute type to be shaped type");
}
std::optional<Type> legalizedElementType =
legalizeStorageElementType(attrType.getElementType());
legalizeTensorStorageElementType(attrType);
if (!legalizedElementType) {
return rewriter.notifyMatchFailure(constantOp,
"cannot legalize elementType");
Expand Down Expand Up @@ -220,8 +220,10 @@ struct GenericOpTypePropagation
signatureConverter.addInputs(index, argType);
continue;
}
auto inputOperandType =
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
std::optional<Type> legalizedArgType =
legalizeStorageElementType(argType);
legalizeTensorStorageElementType(inputOperandType);
if (!legalizedArgType) {
return genericOp.emitOpError("failed to get legalized type for arg ")
<< index;
Expand Down Expand Up @@ -251,8 +253,8 @@ struct GenericOpTypePropagation
modifyYield = true;
OpOperand *yieldOperand =
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
std::optional<Type> legalizedType =
legalizeStorageElementType(yieldOperand->get().getType());
std::optional<Type> legalizedType = legalizeTensorStorageElementType(
modifiedOpOperand->get().getType());
if (!legalizedType) {
return genericOp.emitOpError(
"failed to get legalized type for yield value");
Expand Down Expand Up @@ -282,7 +284,7 @@ struct LinalgFillTypePropagation
ConversionPatternRewriter &rewriter) const final {
Value value = adaptor.getInputs().front();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(value.getType());
legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType());
if (!legalizedElementType) {
return fillOp.emitOpError("failed to get legalized type for value");
}
Expand Down Expand Up @@ -348,8 +350,8 @@ struct IREELinalgExtScatterTypePropagation
// type.
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
Type argType = modifiedOpRegion.getArguments()[0].getType();
std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(inputType);
if (!legalizedArgType) {
return scatterOp.emitOpError("failed to get legalized type for argument");
}
Expand Down Expand Up @@ -411,8 +413,10 @@ struct IREELinalgExtSortTypePropagation
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
for (auto [index, arg] : llvm::enumerate(modifiedOpRegion.getArguments())) {
// Refer to input types of the original operation to determine the
// corresponding legal arg type.
std::optional<Type> legalizedArgType =
legalizeStorageElementType(arg.getType());
legalizeTensorStorageElementType(sortOp->getOperandTypes()[index]);
if (!legalizedArgType) {
return sortOp.emitOpError("failed to get legalized type for argument");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,12 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

bool hasPackedStorageAttr(RankedTensorType type) {
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
/// Reports whether `type` is a ranked tensor whose encoding attribute is a
/// PackedStorageAttr. Any other kind of type, or a ranked tensor with a
/// different (or no) encoding, yields false.
bool hasPackedStorageAttr(Type type) {
  auto rankedTensor = dyn_cast<RankedTensorType>(type);
  if (!rankedTensor) {
    return false;
  }
  auto packedAttr =
      dyn_cast_or_null<PackedStorageAttr>(rankedTensor.getEncoding());
  return packedAttr != nullptr;
}

FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace mlir::iree_compiler::IREE::Encoding {
EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns true if the type contains packed_storage attribute.
bool hasPackedStorageAttr(RankedTensorType type);
bool hasPackedStorageAttr(Type type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.h"

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/Analysis/Captures.h"
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
Expand Down Expand Up @@ -478,7 +480,8 @@ struct TensorExportBufferViewOpPattern
}

auto loc = exportOp.getLoc();
auto tensorType = llvm::cast<RankedTensorType>(adaptor.getSourceEncoding());
auto tensorType = dropPackedStorageEncodingIfAny(
llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto dynamicDims = adaptor.getSourceEncodingDims();

// NOTE: we should have verified supported encodings/types at entry into the
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand All @@ -27,6 +28,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"

// Pull IREE::Encoding::getEncodingAttr into mlir::iree_compiler so that code
// in the nested IREE::Stream namespace below (e.g. TensorCloneOp::verify) can
// call it unqualified.
namespace mlir::iree_compiler {
using IREE::Encoding::getEncodingAttr;
}

namespace mlir::iree_compiler::IREE::Stream {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1512,7 +1517,7 @@ LogicalResult TensorCloneOp::verify() {
// information.
auto sourceEncoding = llvm::cast<RankedTensorType>(op.getSourceEncoding());
auto resultEncoding = llvm::cast<RankedTensorType>(op.getResultEncoding());
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
if (getEncodingAttr(sourceEncoding) != getEncodingAttr(resultEncoding)) {
return op.emitOpError() << "clones changing tensor encoding from "
<< sourceEncoding.getEncoding() << " to "
<< resultEncoding.getEncoding() << "; not allowed";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
Expand All @@ -22,6 +23,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -247,6 +249,12 @@ struct ConvertToStreamPass final
if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
}
if (auto rankedType = llvm::dyn_cast_or_null<RankedTensorType>(type)) {
if (IREE::Encoding::hasPackedStorageAttr(rankedType)) {
return RankedTensorType::get(rankedType.getShape(),
rankedType.getElementType());
}
}
return !llvm::isa<TensorType>(type) ? type : Type{};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
// Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width.
static RankedTensorType alignTensorType(RankedTensorType originalType) {
Type elementType = originalType.getElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType = legalizeTensorStorageElementType(originalType);
if (alignedType == elementType)
return originalType;
return RankedTensorType::get(originalType.getShape(), alignedType,
Expand Down Expand Up @@ -620,7 +620,8 @@ struct EncodeHostTensorsPass
static IREE::Flow::DispatchTensorType
alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) {
Type elementType = originalType.getBoundElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType =
legalizeTensorStorageElementType(originalType.asRankedTensorType());
if (alignedType == elementType)
return originalType;
return IREE::Flow::DispatchTensorType::get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@ func.func @aligned_i1_size() -> index {
// CHECK: func @aligned_i1_size() -> index {
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index

// -----

// Passes a tensor<16xi1> tagged with #iree_encoding.packed_storage straight
// through as both the function argument and the result.
// NOTE(review): no CHECK lines accompany this case, so as written it appears
// to verify only that the input is accepted without error -- confirm whether
// explicit CHECKs on the resulting IR were intended.
#packed = #iree_encoding.packed_storage
func.func @packed_i1_input_output(%input : tensor<16xi1, #packed>) -> tensor<16xi1, #packed> {
return %input : tensor<16xi1, #packed>
}
76 changes: 28 additions & 48 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,23 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"

llvm::cl::opt<bool> clEnableI1Support(
"iree-experimental-packed-i1-storage",
llvm::cl::desc(
"Experimental feature: force to use packed storage for i1 tensors."
"Turning on this option will see i1 tensors as if it has "
"#iree_encoding.packed_storage attribute."
"This is to allow an alternative way to test the packed storage "
"feature before frontend can emit packed i1 tensors."
"This option can be dropped once the frontend can emit packed i1 "
"tensors."),
llvm::cl::init(false));

namespace mlir::iree_compiler {

static bool needToPackSubByteElementBitWidthImpl(unsigned bitWidth,
bool isPackedStorage) {
// Enable i1 support if requested.
if (isPackedStorage && bitWidth == 1) {
return true;
}
/// Returns true when elements of the given `bitWidth` are packed several to a
/// byte based on the width alone: the width must be a power of two that is
/// smaller than 8 and different from 1.
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
  // Boolean values are disallowed for now--they may require separate
  // interface choices.
  if (bitWidth == 1) {
    return false;
  }
  // Elements that already occupy a full byte or more are not sub-byte.
  if (bitWidth >= 8) {
    return false;
  }
  // Only power-of-two widths are accepted for now, which avoids the
  // trickiness and weirdness of packing and cross-byte access.
  return llvm::isPowerOf2_32(bitWidth);
}

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
return needToPackSubByteElementBitWidthImpl(
bitWidth, /*isPackedStorage=*/clEnableI1Support);
}

/// Returns true if the elements of `shapedType` are stored packed, i.e. with
/// multiple elements sharing one byte.
///
/// i1 elements are packed only when the tensor type carries the packed storage
/// attribute; every other width defers to needToPackSubByteElementBitWidth().
bool needToPackSubByteElements(RankedTensorType shapedType) {
  unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
  // i1 with packed memory layout does not need to be extended: it is stored
  // packed. This matches the `isI1WithPackedStorage` handling used by the
  // storage size/offset calculations in this file. (The previous condition
  // tested for the *absence* of the attribute, which reported plain i1 tensors
  // as packed and packed i1 tensors as unpacked.)
  if (bitWidth == 1 && IREE::Encoding::hasPackedStorageAttr(shapedType)) {
    return true;
  }
  return needToPackSubByteElementBitWidth(bitWidth);
}

static Type legalizeStorageElementTypeImpl(Type elementType,
Expand All @@ -64,9 +41,13 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
if (!intType)
return elementType;

// For sub-byte elements, default to pack them into bytes.
unsigned bitWidth = intType.getWidth();
if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage))
if (bitWidth == 1 && !isPackedStorage) {
return elementType;
}

// For sub-byte elements, default to pack them into bytes.
if (needToPackSubByteElementBitWidth(bitWidth))
return elementType;

// Otherwise, extend them to the next power-of-two bit width.
Expand All @@ -78,10 +59,10 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
intType.getSignedness());
}

Type legalizeStorageElementType(Type elementType) {
// Consider packed storage for i1 tensors if cl opt is set.
return legalizeStorageElementTypeImpl(elementType,
/*isPackedStorage=*/clEnableI1Support);
/// Returns the legalized storage element type for `type`.
///
/// When `type` is a shaped type, its element type is legalized; packed storage
/// is taken into account only if `type` is a RankedTensorType carrying the
/// packed storage attribute (see IREE::Encoding::hasPackedStorageAttr). A
/// non-shaped `type` is treated as a bare element type without packed storage.
Type legalizeTensorStorageElementType(Type type) {
  // Callers hand in whole tensor types and use the result as an element type,
  // so the element type has to be extracted before legalization. Forwarding
  // the tensor type itself would make the integer-only implementation return
  // it unchanged.
  Type elementType = type;
  if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) {
    elementType = shapedType.getElementType();
  }
  return legalizeStorageElementTypeImpl(
      elementType, IREE::Encoding::hasPackedStorageAttr(type));
}

Value calculateStorageElementCountInBytes(Location loc,
Expand All @@ -96,16 +77,16 @@ Value calculateStorageElementCountInBytes(Location loc,
loc, builder, shapedType, dynamicDims);
}

// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
Type alignedElementType = legalizeStorageElementTypeImpl(
shapedType.getElementType(), isPackedStorage);
Type alignedElementType = legalizeTensorStorageElementType(shapedType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Calculate all static dims first, if any.
int64_t staticCount = 1;
if (!needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
if (!isI1WithPackedStorage &&
!needToPackSubByteElementBitWidth(elementBits)) {
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
}

Expand All @@ -120,7 +101,7 @@ Value calculateStorageElementCountInBytes(Location loc,
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
assert(8 % elementBits == 0);
unsigned byteElements = 8 / elementBits;
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand All @@ -140,15 +121,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
RankedTensorType originalType,
Value linearizedIndex,
OpBuilder &builder) {
// TODO: remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(originalType) || clEnableI1Support;
Type alignedElementType = legalizeStorageElementTypeImpl(
originalType.getElementType(), isPackedStorage);
Type alignedElementType = legalizeTensorStorageElementType(originalType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(originalType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
Value byteElements =
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ bool needToPackSubByteElements(RankedTensorType shapedType);
/// cases.
Type legalizeStorageElementType(Type elementType);

/// Variant of legalizeStorageElementType that additionally consults the packed
/// storage encoding attribute of |tensorType| when it is a RankedTensorType.
Type legalizeTensorStorageElementType(Type tensorType);

/// Emits IR with the given |builder| to calculate the total number of bytes
/// required for the given |shapedType| in storage. Returns the value for the
/// final count on success; returns nullptr on failure. Dynamic dimensions in
Expand Down
Loading

0 comments on commit e27a766

Please sign in to comment.