Skip to content

Commit

Permalink
Address review comment : Value -> OpFoldResult for basis
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Jan 17, 2025
1 parent bbfb245 commit c9dc955
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,25 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter,
}
}

static SmallVector<Value> getDimValues(Location loc, PatternRewriter &rewriter,
MemRefType type,
ValueRange dynamicDims) {
SmallVector<Value> dims;
static SmallVector<OpFoldResult> getDimValues(Location loc,
PatternRewriter &rewriter,
MemRefType type,
ValueRange dynamicDims) {
SmallVector<OpFoldResult> dims;
auto shape = type.getShape();
int dynamicDimIndex = 0;
for (int i = 0; i < shape.size(); ++i) {
if (ShapedType::isDynamic(shape[i])) {
dims.push_back(dynamicDims[dynamicDimIndex++]);
} else {
dims.push_back(rewriter.create<arith::ConstantIndexOp>(loc, shape[i]));
dims.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, shape[i]).getResult());
}
}
return dims;
}

static FailureOr<SmallVector<Value>>
static FailureOr<SmallVector<OpFoldResult>>
getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) {
MemRefType sourceType = llvm::cast<MemRefType>(sourceValue.getType());
Operation *sourceOp = sourceValue.getDefiningOp();
Expand All @@ -128,10 +130,11 @@ getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) {
return getDimValues(loc, rewriter, sourceType, allocaOp.getDynamicSizes());
} else {
if (sourceType.hasStaticShape()) {
SmallVector<Value> dims;
SmallVector<OpFoldResult> dims;
dims.reserve(sourceType.getRank());
for (int64_t dim : sourceType.getShape()) {
dims.push_back(rewriter.create<arith::ConstantIndexOp>(loc, dim));
dims.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, dim).getResult());
}
return dims;
} else {
Expand Down Expand Up @@ -163,11 +166,11 @@ struct LinearizeMemrefAlloc : public OpRewritePattern<OpTy> {

SmallVector<Value> dynamicLinearizedSize;
if (!newTypeOfSourceMemref.hasStaticShape()) {
SmallVector<Value> basis = getDimValues(
SmallVector<OpFoldResult> basis = getDimValues(
loc, rewriter, currentTypeOfSourceMemref, allocOp.getDynamicSizes());
SmallVector<Value> multiIndices(
basis.size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = basis[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>(basis[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, basis, true);
dynamicLinearizedSize.push_back(linearizedSizes);
Expand Down Expand Up @@ -202,7 +205,7 @@ struct LinearizeMemrefLoad : public OpRewritePattern<memref::LoadOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, loadOp.getMemref());
if (failed(basis))
return failure();
Expand All @@ -211,7 +214,7 @@ struct LinearizeMemrefLoad : public OpRewritePattern<memref::LoadOp> {

SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedOperand =
Expand Down Expand Up @@ -241,15 +244,15 @@ struct LinearizeMemrefStore : public OpRewritePattern<memref::StoreOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, storeOp.getMemref());
if (failed(basis))
return failure();
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, storeOp.getIndices(), *basis, true);
SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedOperand =
Expand Down Expand Up @@ -280,13 +283,13 @@ struct LinearizeMemrefDealloc : public OpRewritePattern<memref::DeallocOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, deallocOp.getMemref());
if (failed(basis))
return failure();
SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedOperand =
Expand Down Expand Up @@ -318,13 +321,13 @@ struct LinearizeMemrefCopy : public OpRewritePattern<memref::CopyOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, copyOp.getSource());
if (failed(basis))
return failure();
SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedSource =
Expand All @@ -333,7 +336,7 @@ struct LinearizeMemrefCopy : public OpRewritePattern<memref::CopyOp> {
basis = getMixedOrigSize(loc, rewriter, copyOp.getTarget());
if (failed(basis))
return failure();
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>(((*basis)[0]));
linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedTarget =
Expand Down Expand Up @@ -362,15 +365,15 @@ struct LinearizeVectorLoad : public OpRewritePattern<vector::LoadOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, loadOp.getBase());
if (failed(basis))
return failure();
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, loadOp.getIndices(), *basis, true);
SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedOperand =
Expand Down Expand Up @@ -401,15 +404,15 @@ struct LinearizeVectorStore : public OpRewritePattern<vector::StoreOp> {
return failure();
MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref;

FailureOr<SmallVector<Value>> basis =
FailureOr<SmallVector<OpFoldResult>> basis =
getMixedOrigSize(loc, rewriter, storeOp.getBase());
if (failed(basis))
return failure();
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, storeOp.getIndices(), *basis, true);
SmallVector<Value> multiIndices(
(*basis).size(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
multiIndices[0] = (*basis)[0];
multiIndices[0] = llvm::dyn_cast_if_present<Value>((*basis)[0]);
Value linearizedSizes = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, multiIndices, *basis, true);
Value linearizedOperand =
Expand Down

0 comments on commit c9dc955

Please sign in to comment.