Skip to content

Commit

Permalink
Fix issues in lif-array-alloc
Browse files Browse the repository at this point in the history
  • Loading branch information
annagrin committed Jan 31, 2025
1 parent d893046 commit c1592b8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
24 changes: 20 additions & 4 deletions lib/Optimizer/Transforms/LiftArrayAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
return failure();

LLVM_DEBUG(llvm::dbgs() << "Candidate was found\n");
auto eleTy = alloc.getElementType();
auto arrTy = cast<cudaq::cc::ArrayType>(eleTy);
auto allocTy = alloc.getElementType();
auto arrTy = cast<cudaq::cc::ArrayType>(allocTy);
auto eleTy = arrTy.getElementType();

SmallVector<Attribute> values;

// Every element of `stores` must be a cc::StoreOp with a ConstantOp as the
Expand Down Expand Up @@ -89,12 +91,16 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
cannotEraseAlloc = isLive = true;
} else {
for (auto *useuser : user->getUsers()) {
if (!useuser)
continue;
if (auto load = dyn_cast<cudaq::cc::LoadOp>(useuser)) {
rewriter.setInsertionPointAfter(useuser);
LLVM_DEBUG(llvm::dbgs() << "replaced load\n");
rewriter.replaceOpWithNewOp<cudaq::cc::ExtractValueOp>(
load, eleTy, conArr,
auto extractValue = rewriter.create<cudaq::cc::ExtractValueOp>(
loc, eleTy, conArr,
ArrayRef<cudaq::cc::ExtractValueArg>{offset});
rewriter.replaceAllUsesWith(load, extractValue);
insertOpToErase(load);
continue;
}
if (isa<cudaq::cc::StoreOp>(useuser)) {
Expand Down Expand Up @@ -152,6 +158,7 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {

SmallVector<Operation *> toGlobalUses;
SmallVector<SmallPtrSet<Operation *, 2>> loadSets(size);
SmallVector<SmallPtrSet<Operation *, 2>> storeSets(size);

auto getWriteOp = [&](auto op, std::int32_t index) -> Operation * {
Operation *theStore = nullptr;
Expand All @@ -160,6 +167,7 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
if (!u)
return nullptr;
if (auto store = dyn_cast<cudaq::cc::StoreOp>(u)) {
storeSets[index].insert(u);
if (op.getOperation() == store.getPtrvalue().getDefiningOp() &&
isa_and_present<arith::ConstantOp, complex::ConstantOp>(
store.getValue().getDefiningOp())) {
Expand Down Expand Up @@ -255,6 +263,14 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
<< " doesn't dominate load: " << *load << '\n');
return false;
}

for (auto *store : storeSets[i])
if (scoreboard[i] != store && dom.dominates(scoreboard[i], store)) {
LLVM_DEBUG(llvm::dbgs()
<< "store " << scoreboard[i]
<< " dominates another store: " << *store << '\n');
return false;
}
}

// For all global uses, all of the stores must dominate every use.
Expand Down
47 changes: 47 additions & 0 deletions test/Quake/lift_array.qke
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,50 @@ func.func @test2() -> !quake.veq<2> {
// GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_test_complex_constant_array._Z27test_complex_constant_arrayv.rodata_{{[0-9]+}} (dense<[(0.707106769,0.000000e+00), (0.707106769,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<4xcomplex<f32>>) : !cc.array<complex<f32> x 4>
// GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_custom_h_generator_1._Z20custom_h_generator_1v.rodata_{{[0-9]+}} (dense<[(0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (-0.70710678118654757,0.000000e+00)]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
// GLOBAL-DAG: cc.global constant private @test2.rodata_{{[0-9]+}} (dense<[1.000000e+00, 2.000000e+00, 6.000000e+00, 9.000000e+00]>" : tensor<4xf64>) : !cc.array<f64 x 4>

func.func @test3() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64

// qubits = cudaq.qvector(2)
%0 = quake.alloca !quake.veq<2>

// arr1 = [1]
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>

// t = arr1[0]
%3 = cc.load %2 : !cc.ptr<i64>

// arr2 = [0]
%4 = cc.alloca !cc.array<i64 x 1>
%5 = cc.cast %4 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c0_i64, %5 : !cc.ptr<i64> // Dominates the next store, don't lift

// arr2[0] = t
cc.store %3, %5 : !cc.ptr<i64>

// b = arr2[0]
%6 = cc.load %5 : !cc.ptr<i64>

// x(qubits[b])
%7 = quake.extract_ref %0[%6] : (!quake.veq<2>, i64) -> !quake.ref
quake.x %7 : (!quake.ref) -> ()
return
}

// CHECK-LABEL: func.func @test3() {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array<i64 x 1>
// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array<i64 x 1>) -> i64
// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref
// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> ()
// CHECK: return
// CHECK: }

0 comments on commit c1592b8

Please sign in to comment.