Skip to content

Commit

Permalink
Merge branch 'main' into ch-use.llvm.adts
Browse files Browse the repository at this point in the history
  • Loading branch information
schweitzpgi authored Aug 12, 2024
2 parents 666aae0 + 73a11a4 commit 22737c1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 18 deletions.
78 changes: 60 additions & 18 deletions lib/Optimizer/Transforms/FactorQuantumAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Factory.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -42,21 +43,47 @@ class AllocaPat : public OpRewritePattern<quake::AllocaOp> {
for (std::size_t i = 0; i < size; ++i)
newAllocs.emplace_back(rewriter.create<quake::AllocaOp>(loc, refTy));

// 2. Visit all users and replace them accordingly.
for (auto *user : allocOp->getUsers()) {
if (auto dealloc = dyn_cast<quake::DeallocOp>(user)) {
rewriter.setInsertionPoint(dealloc);
auto deloc = dealloc.getLoc();
for (std::size_t i = 0; i < size - 1; ++i)
rewriter.create<quake::DeallocOp>(deloc, newAllocs[i]);
rewriter.replaceOpWithNewOp<quake::DeallocOp>(dealloc,
newAllocs[size - 1]);
continue;
std::function<LogicalResult(Operation *, std::int64_t)> rewriteOpAndUsers =
[&](Operation *op, std::int64_t start) -> LogicalResult {
// First handle the users. Note that this can recurse.
for (auto *user : op->getUsers()) {
if (auto dealloc = dyn_cast<quake::DeallocOp>(user)) {
rewriter.setInsertionPoint(dealloc);
auto deloc = dealloc.getLoc();
for (std::size_t i = 0; i < size - 1; ++i)
rewriter.create<quake::DeallocOp>(deloc, newAllocs[i]);
rewriter.replaceOpWithNewOp<quake::DeallocOp>(dealloc,
newAllocs[size - 1]);
continue;
}
if (auto subveq = dyn_cast<quake::SubVeqOp>(user)) {
auto lowInt = cudaq::opt::factory::getIntIfConstant(subveq.getLow());
if (!lowInt)
return failure();
for (auto *subUser : subveq->getUsers())
if (failed(rewriteOpAndUsers(subUser, *lowInt)))
return failure();
rewriter.eraseOp(subveq);
continue;
}
if (auto ext = dyn_cast<quake::ExtractRefOp>(user)) {
auto index = ext.getConstantIndex();
rewriter.replaceOp(ext, newAllocs[start + index].getResult());
}
}
auto ext = cast<quake::ExtractRefOp>(user);
auto index = ext.getConstantIndex();
rewriter.replaceOp(ext, newAllocs[index].getResult());
}
// Now handle the base operation.
if (isa<quake::SubVeqOp>(op))
rewriter.eraseOp(op);
else if (auto ext = dyn_cast<quake::ExtractRefOp>(op)) {
auto index = ext.getConstantIndex();
rewriter.replaceOp(ext, newAllocs[start + index].getResult());
}
return success();
};

// 2. Visit all users and replace them accordingly.
if (failed(rewriteOpAndUsers(allocOp, 0)))
return failure();

// 3. Remove the original alloca operation.
rewriter.eraseOp(allocOp);
Expand Down Expand Up @@ -165,17 +192,32 @@ class FactorQuantumAllocationsPass

LogicalResult runAnalysis(SmallVector<quake::AllocaOp> &allocations) {
auto func = getOperation();
std::function<bool(Operation *)> isUseConvertible =
[&](Operation *op) -> bool {
if (isa<quake::DeallocOp>(op))
return true;
if (auto ext = dyn_cast<quake::ExtractRefOp>(op))
if (ext.hasConstantIndex())
return true;
if (auto sub = dyn_cast<quake::SubVeqOp>(op)) {
if (!cudaq::opt::factory::getIntIfConstant(sub.getLow()) ||
!cudaq::opt::factory::getIntIfConstant(sub.getHigh()))
return false;
for (auto *subUser : sub->getUsers())
if (!isUseConvertible(subUser))
return false;
return true;
}
return false;
};
func.walk([&](quake::AllocaOp alloc) {
if (!allocaOfVeq(alloc) || allocaOfUnspecifiedSize(alloc) ||
alloc.hasInitializedState())
return;
bool usesAreConvertible = [&]() {
for (auto *users : alloc->getUsers()) {
if (isa<quake::DeallocOp>(users))
if (isUseConvertible(users))
continue;
if (auto ext = dyn_cast<quake::ExtractRefOp>(users))
if (ext.hasConstantIndex())
continue;
return false;
}
return true;
Expand Down
9 changes: 9 additions & 0 deletions test/Quake/combine.qke
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ func.func @c() {
// CHECK: return
// CHECK: }

// FACTOR-LABEL: func.func @c() {
// FACTOR: %0 = quake.alloca !quake.ref
// FACTOR: %1 = quake.alloca !quake.ref
// FACTOR: %2 = quake.alloca !quake.ref
// FACTOR: %3 = quake.alloca !quake.ref
// FACTOR: quake.x %3 : (!quake.ref) -> ()
// FACTOR: return
// FACTOR: }

func.func @d(%c2: i64, %c3: i64, %c1: i64, %off: i16) {
%1 = quake.alloca !quake.veq<4>
%2 = quake.subveq %1, %c2, %c3 : (!quake.veq<4>, i64, i64) -> !quake.veq<2>
Expand Down

0 comments on commit 22737c1

Please sign in to comment.