diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp index 676a7c3ae0058..2d9c4a98864f5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp @@ -67,7 +67,7 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType, static std::optional getLegalizedType(Type t) { if (auto shapedType = llvm::dyn_cast(t)) { std::optional legalizedElementType = - legalizeTensorStorageElementType(shapedType); + legalizeStorageElementType(shapedType); if (!legalizedElementType) return std::nullopt; return RankedTensorType::get(shapedType.getShape(), @@ -121,7 +121,7 @@ struct ConstantOpTypeConversion constantOp, "expected attribute type to be shaped type"); } std::optional legalizedElementType = - legalizeTensorStorageElementType(attrType); + legalizeStorageElementType(attrType); if (!legalizedElementType) { return rewriter.notifyMatchFailure(constantOp, "cannot legalize elementType"); @@ -230,7 +230,7 @@ struct GenericOpTypePropagation auto inputOperandType = llvm::cast(genericOp->getOperandTypes()[index]); std::optional legalizedArgType = - legalizeTensorStorageElementType(inputOperandType); + legalizeStorageElementType(inputOperandType); if (!legalizedArgType) { return genericOp.emitOpError("failed to get legalized type for arg ") << index; @@ -260,8 +260,8 @@ struct GenericOpTypePropagation modifyYield = true; OpOperand *yieldOperand = modifiedOp.getMatchingYieldValue(modifiedOpOperand); - std::optional legalizedType = legalizeTensorStorageElementType( - modifiedOpOperand->get().getType()); + std::optional legalizedType = + legalizeStorageElementType(modifiedOpOperand->get().getType()); if (!legalizedType) { return genericOp.emitOpError( "failed to get legalized type for yield value"); @@ -291,7 +291,7 @@ struct LinalgFillTypePropagation ConversionPatternRewriter &rewriter) const final { Value value = adaptor.getInputs().front(); std::optional legalizedElementType = - legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType()); + legalizeStorageElementType(adaptor.getOutputs()[0].getType()); if (!legalizedElementType) { return fillOp.emitOpError("failed to get legalized type for value"); } @@ -358,7 +358,7 @@ struct IREELinalgExtScatterTypePropagation TypeConverter::SignatureConversion signatureConverter( modifiedOpRegion.getNumArguments()); std::optional legalizedArgType = - legalizeTensorStorageElementType(inputType); + legalizeStorageElementType(inputType); if (!legalizedArgType) { return scatterOp.emitOpError("failed to get legalized type for argument"); } @@ -425,7 +425,7 @@ struct IREELinalgExtSortTypePropagation auto convertType = index % 2 == 0 ? sortOp->getOperandTypes()[index / 2] : sortOp->getResultTypes()[index / 2]; std::optional legalizedArgType = - legalizeTensorStorageElementType(convertType); + legalizeStorageElementType(convertType); if (!legalizedArgType) { return sortOp.emitOpError("failed to get legalized type for argument"); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp index 7d3d0f43562d4..1d3c72c344b10 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp @@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType, // Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width. static RankedTensorType alignTensorType(RankedTensorType originalType) { Type elementType = originalType.getElementType(); - Type alignedType = legalizeTensorStorageElementType(originalType); + Type alignedType = legalizeStorageElementType(originalType); if (alignedType == elementType) return originalType; return RankedTensorType::get(originalType.getShape(), alignedType, @@ -621,7 +621,7 @@ static IREE::Flow::DispatchTensorType alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) { Type elementType = originalType.getBoundElementType(); Type alignedType = - legalizeTensorStorageElementType(originalType.asRankedTensorType()); + legalizeStorageElementType(originalType.asRankedTensorType()); if (alignedType == elementType) return originalType; return IREE::Flow::DispatchTensorType::get( diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp index 00ba6fc4be781..425ddeac82647 100644 --- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp @@ -37,7 +37,7 @@ bool needToPackSubByteElements(Type type) { return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1; } -Type legalizeTensorStorageElementType(Type type) { +Type legalizeStorageElementType(Type type) { auto tensorType = llvm::cast(type); auto elementType = tensorType.getElementType(); @@ -72,7 +72,7 @@ Value calculateStorageElementCountInBytes(Location loc, loc, builder, shapedType, dynamicDims); } - Type alignedElementType = legalizeTensorStorageElementType(shapedType); + Type alignedElementType = legalizeStorageElementType(shapedType); unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType); // Calculate all static dims first, if any. @@ -113,7 +113,7 @@ Value calculateStorageElementOffsetInBytes(Location loc, RankedTensorType originalType, Value linearizedIndex, OpBuilder &builder) { - Type alignedElementType = legalizeTensorStorageElementType(originalType); + Type alignedElementType = legalizeStorageElementType(originalType); unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType); // Sub-byte packing requires putting multiple elements in the same byte. diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.h b/compiler/src/iree/compiler/Utils/ElementPackingUtils.h index 4078643d832d1..8318e2328d9ec 100644 --- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.h +++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.h @@ -18,15 +18,13 @@ namespace mlir::iree_compiler { /// together. bool needToPackSubByteElements(Type type); -/// Legalizes the given |elementType| for storage. +/// Legalizes the given |tensorType|'s element type for storage. /// /// In IREE, if compiling from the same source model, we control both the /// runtime and kernel. For such cases, we perform tight packing for supported /// sub-byte elements, and expand to the next power-of-two bit width for other /// cases. -Type legalizeStorageElementType(Type elementType); - -Type legalizeTensorStorageElementType(Type tensorType); +Type legalizeStorageElementType(Type tensorType); /// Emits IR with the given |builder| to calculate the total number of bytes /// required for the given |shapedType| in storage. Returns the value for the