diff --git a/lib/Optimizer/Transforms/FactorQuantumAlloc.cpp b/lib/Optimizer/Transforms/FactorQuantumAlloc.cpp index 8983af56ad..b413aec5b0 100644 --- a/lib/Optimizer/Transforms/FactorQuantumAlloc.cpp +++ b/lib/Optimizer/Transforms/FactorQuantumAlloc.cpp @@ -7,6 +7,7 @@ ******************************************************************************/ #include "PassDetails.h" +#include "cudaq/Optimizer/Builder/Factory.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Optimizer/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" @@ -42,21 +43,47 @@ class AllocaPat : public OpRewritePattern { for (std::size_t i = 0; i < size; ++i) newAllocs.emplace_back(rewriter.create(loc, refTy)); - // 2. Visit all users and replace them accordingly. - for (auto *user : allocOp->getUsers()) { - if (auto dealloc = dyn_cast(user)) { - rewriter.setInsertionPoint(dealloc); - auto deloc = dealloc.getLoc(); - for (std::size_t i = 0; i < size - 1; ++i) - rewriter.create(deloc, newAllocs[i]); - rewriter.replaceOpWithNewOp(dealloc, - newAllocs[size - 1]); - continue; + std::function rewriteOpAndUsers = + [&](Operation *op, std::int64_t start) -> LogicalResult { + // First handle the users. Note that this can recurse. + for (auto *user : op->getUsers()) { + if (auto dealloc = dyn_cast(user)) { + rewriter.setInsertionPoint(dealloc); + auto deloc = dealloc.getLoc(); + for (std::size_t i = 0; i < size - 1; ++i) + rewriter.create(deloc, newAllocs[i]); + rewriter.replaceOpWithNewOp(dealloc, + newAllocs[size - 1]); + continue; + } + if (auto subveq = dyn_cast(user)) { + auto lowInt = cudaq::opt::factory::getIntIfConstant(subveq.getLow()); + if (!lowInt) + return failure(); + for (auto *subUser : subveq->getUsers()) + if (failed(rewriteOpAndUsers(subUser, *lowInt))) + return failure(); + rewriter.eraseOp(subveq); + continue; + } + if (auto ext = dyn_cast(user)) { + auto index = ext.getConstantIndex(); + rewriter.replaceOp(ext, newAllocs[start + index].getResult()); + } } - auto ext = cast(user); - auto index = ext.getConstantIndex(); - rewriter.replaceOp(ext, newAllocs[index].getResult()); - } + // Now handle the base operation. + if (isa(op)) + rewriter.eraseOp(op); + else if (auto ext = dyn_cast(op)) { + auto index = ext.getConstantIndex(); + rewriter.replaceOp(ext, newAllocs[start + index].getResult()); + } + return success(); + }; + + // 2. Visit all users and replace them accordingly. + if (failed(rewriteOpAndUsers(allocOp, 0))) + return failure(); // 3. Remove the original alloca operation. rewriter.eraseOp(allocOp); @@ -165,17 +192,32 @@ class FactorQuantumAllocationsPass LogicalResult runAnalysis(SmallVector &allocations) { auto func = getOperation(); + std::function isUseConvertible = + [&](Operation *op) -> bool { + if (isa(op)) + return true; + if (auto ext = dyn_cast(op)) + if (ext.hasConstantIndex()) + return true; + if (auto sub = dyn_cast(op)) { + if (!cudaq::opt::factory::getIntIfConstant(sub.getLow()) || + !cudaq::opt::factory::getIntIfConstant(sub.getHigh())) + return false; + for (auto *subUser : sub->getUsers()) + if (!isUseConvertible(subUser)) + return false; + return true; + } + return false; + }; func.walk([&](quake::AllocaOp alloc) { if (!allocaOfVeq(alloc) || allocaOfUnspecifiedSize(alloc) || alloc.hasInitializedState()) return; bool usesAreConvertible = [&]() { for (auto *users : alloc->getUsers()) { - if (isa(users)) + if (isUseConvertible(users)) continue; - if (auto ext = dyn_cast(users)) - if (ext.hasConstantIndex()) - continue; return false; } return true; diff --git a/test/Quake/combine.qke b/test/Quake/combine.qke index eef81298f6..7d5139fd13 100644 --- a/test/Quake/combine.qke +++ b/test/Quake/combine.qke @@ -157,6 +157,15 @@ func.func @c() { // CHECK: return // CHECK: } +// FACTOR-LABEL: func.func @c() { +// FACTOR: %0 = quake.alloca !quake.ref +// FACTOR: %1 = quake.alloca !quake.ref +// FACTOR: %2 = quake.alloca !quake.ref +// FACTOR: %3 = quake.alloca !quake.ref +// FACTOR: quake.x %3 : (!quake.ref) -> () +// FACTOR: return +// FACTOR: } + func.func @d(%c2: i64, %c3: i64, %c1: i64, %off: i16) { %1 = quake.alloca !quake.veq<4> %2 = quake.subveq %1, %c2, %c3 : (!quake.veq<4>, i64, i64) -> !quake.veq<2>