Skip to content

Commit

Permalink
address comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Jan 14, 2025
1 parent d4463b3 commit c1711b1
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType,
static std::optional<Type> getLegalizedType(Type t) {
if (auto shapedType = llvm::dyn_cast<RankedTensorType>(t)) {
std::optional<Type> legalizedElementType =
legalizeTensorStorageElementType(shapedType);
legalizeStorageElementType(shapedType);
if (!legalizedElementType)
return std::nullopt;
return RankedTensorType::get(shapedType.getShape(),
Expand Down Expand Up @@ -121,7 +121,7 @@ struct ConstantOpTypeConversion
constantOp, "expected attribute type to be shaped type");
}
std::optional<Type> legalizedElementType =
legalizeTensorStorageElementType(attrType);
legalizeStorageElementType(attrType);
if (!legalizedElementType) {
return rewriter.notifyMatchFailure(constantOp,
"cannot legalize elementType");
Expand Down Expand Up @@ -230,7 +230,7 @@ struct GenericOpTypePropagation
auto inputOperandType =
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(inputOperandType);
legalizeStorageElementType(inputOperandType);
if (!legalizedArgType) {
return genericOp.emitOpError("failed to get legalized type for arg ")
<< index;
Expand Down Expand Up @@ -260,8 +260,8 @@ struct GenericOpTypePropagation
modifyYield = true;
OpOperand *yieldOperand =
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
std::optional<Type> legalizedType = legalizeTensorStorageElementType(
modifiedOpOperand->get().getType());
std::optional<Type> legalizedType =
legalizeStorageElementType(modifiedOpOperand->get().getType());
if (!legalizedType) {
return genericOp.emitOpError(
"failed to get legalized type for yield value");
Expand Down Expand Up @@ -291,7 +291,7 @@ struct LinalgFillTypePropagation
ConversionPatternRewriter &rewriter) const final {
Value value = adaptor.getInputs().front();
std::optional<Type> legalizedElementType =
legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType());
legalizeStorageElementType(adaptor.getOutputs()[0].getType());
if (!legalizedElementType) {
return fillOp.emitOpError("failed to get legalized type for value");
}
Expand Down Expand Up @@ -358,7 +358,7 @@ struct IREELinalgExtScatterTypePropagation
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(inputType);
legalizeStorageElementType(inputType);
if (!legalizedArgType) {
return scatterOp.emitOpError("failed to get legalized type for argument");
}
Expand Down Expand Up @@ -425,7 +425,7 @@ struct IREELinalgExtSortTypePropagation
auto convertType = index % 2 == 0 ? sortOp->getOperandTypes()[index / 2]
: sortOp->getResultTypes()[index / 2];
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(convertType);
legalizeStorageElementType(convertType);
if (!legalizedArgType) {
return sortOp.emitOpError("failed to get legalized type for argument");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
// Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width.
static RankedTensorType alignTensorType(RankedTensorType originalType) {
Type elementType = originalType.getElementType();
Type alignedType = legalizeTensorStorageElementType(originalType);
Type alignedType = legalizeStorageElementType(originalType);
if (alignedType == elementType)
return originalType;
return RankedTensorType::get(originalType.getShape(), alignedType,
Expand Down Expand Up @@ -621,7 +621,7 @@ static IREE::Flow::DispatchTensorType
alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) {
Type elementType = originalType.getBoundElementType();
Type alignedType =
legalizeTensorStorageElementType(originalType.asRankedTensorType());
legalizeStorageElementType(originalType.asRankedTensorType());
if (alignedType == elementType)
return originalType;
return IREE::Flow::DispatchTensorType::get(
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ bool needToPackSubByteElements(Type type) {
return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
}

Type legalizeTensorStorageElementType(Type type) {
Type legalizeStorageElementType(Type type) {
auto tensorType = llvm::cast<TensorType>(type);
auto elementType = tensorType.getElementType();

Expand Down Expand Up @@ -72,7 +72,7 @@ Value calculateStorageElementCountInBytes(Location loc,
loc, builder, shapedType, dynamicDims);
}

Type alignedElementType = legalizeTensorStorageElementType(shapedType);
Type alignedElementType = legalizeStorageElementType(shapedType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Calculate all static dims first, if any.
Expand Down Expand Up @@ -113,7 +113,7 @@ Value calculateStorageElementOffsetInBytes(Location loc,
RankedTensorType originalType,
Value linearizedIndex,
OpBuilder &builder) {
Type alignedElementType = legalizeTensorStorageElementType(originalType);
Type alignedElementType = legalizeStorageElementType(originalType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Sub-byte packing requires putting multiple elements in the same byte.
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ namespace mlir::iree_compiler {
/// together.
bool needToPackSubByteElements(Type type);

/// Legalizes the given |elementType| for storage.
/// Legalizes the given |tensorType|'s element type for storage.
///
/// In IREE, if compiling from the same source model, we control both the
/// runtime and kernel. For such cases, we perform tight packing for supported
/// sub-byte elements, and expand to the next power-of-two bit width for other
/// cases.
Type legalizeStorageElementType(Type elementType);

Type legalizeTensorStorageElementType(Type tensorType);
Type legalizeStorageElementType(Type tensorType);

/// Emits IR with the given |builder| to calculate the total number of bytes
/// required for the given |shapedType| in storage. Returns the value for the
Expand Down

0 comments on commit c1711b1

Please sign in to comment.