diff --git a/lib/Optimizer/Transforms/LiftArrayAlloc.cpp b/lib/Optimizer/Transforms/LiftArrayAlloc.cpp index 5a1d003214..0d24253d5f 100644 --- a/lib/Optimizer/Transforms/LiftArrayAlloc.cpp +++ b/lib/Optimizer/Transforms/LiftArrayAlloc.cpp @@ -40,8 +40,10 @@ class AllocaPattern : public OpRewritePattern { return failure(); LLVM_DEBUG(llvm::dbgs() << "Candidate was found\n"); - auto eleTy = alloc.getElementType(); - auto arrTy = cast(eleTy); + auto allocTy = alloc.getElementType(); + auto arrTy = cast(allocTy); + auto eleTy = arrTy.getElementType(); + SmallVector values; // Every element of `stores` must be a cc::StoreOp with a ConstantOp as the @@ -89,12 +91,16 @@ class AllocaPattern : public OpRewritePattern { cannotEraseAlloc = isLive = true; } else { for (auto *useuser : user->getUsers()) { + if (!useuser) + continue; if (auto load = dyn_cast(useuser)) { rewriter.setInsertionPointAfter(useuser); LLVM_DEBUG(llvm::dbgs() << "replaced load\n"); - rewriter.replaceOpWithNewOp( - load, eleTy, conArr, + auto extractValue = rewriter.create( + loc, eleTy, conArr, ArrayRef{offset}); + rewriter.replaceAllUsesWith(load, extractValue); + insertOpToErase(load); continue; } if (isa(useuser)) { @@ -152,6 +158,7 @@ class AllocaPattern : public OpRewritePattern { SmallVector toGlobalUses; SmallVector> loadSets(size); + SmallVector> storeSets(size); auto getWriteOp = [&](auto op, std::int32_t index) -> Operation * { Operation *theStore = nullptr; @@ -160,6 +167,7 @@ class AllocaPattern : public OpRewritePattern { if (!u) return nullptr; if (auto store = dyn_cast(u)) { + storeSets[index].insert(u); if (op.getOperation() == store.getPtrvalue().getDefiningOp() && isa_and_present( store.getValue().getDefiningOp())) { @@ -255,6 +263,14 @@ class AllocaPattern : public OpRewritePattern { << " doesn't dominate load: " << *load << '\n'); return false; } + + for (auto *store : storeSets[i]) + if (scoreboard[i] != store && dom.dominates(scoreboard[i], store)) { + LLVM_DEBUG(llvm::dbgs() + << "store " << scoreboard[i] + << " dominates another store: " << *store << '\n'); + return false; + } } // For all global uses, all of the stores must dominate every use. diff --git a/test/Quake/lift_array.qke b/test/Quake/lift_array.qke index 00574f8e45..9faafef7f5 100644 --- a/test/Quake/lift_array.qke +++ b/test/Quake/lift_array.qke @@ -125,3 +125,50 @@ func.func @test2() -> !quake.veq<2> { // GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_test_complex_constant_array._Z27test_complex_constant_arrayv.rodata_{{[0-9]+}} (dense<[(0.707106769,0.000000e+00), (0.707106769,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<4xcomplex>) : !cc.array x 4> // GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_custom_h_generator_1._Z20custom_h_generator_1v.rodata_{{[0-9]+}} (dense<[(0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (-0.70710678118654757,0.000000e+00)]> : tensor<4xcomplex>) : !cc.array x 4> // GLOBAL-DAG: cc.global constant private @test2.rodata_{{[0-9]+}} (dense<[1.000000e+00, 2.000000e+00, 6.000000e+00, 9.000000e+00]>" : tensor<4xf64>) : !cc.array + +func.func @test3() { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + + // qubits = cudaq.qvector(2) + %0 = quake.alloca !quake.veq<2> + + // arr1 = [1] + %1 = cc.alloca !cc.array + %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + cc.store %c1_i64, %2 : !cc.ptr + + // t = arr1[0] + %3 = cc.load %2 : !cc.ptr + + // arr2 = [0] + %4 = cc.alloca !cc.array + %5 = cc.cast %4 : (!cc.ptr>) -> !cc.ptr + cc.store %c0_i64, %5 : !cc.ptr // Dominates the next store, don't lift + + // arr2[0] = t + cc.store %3, %5 : !cc.ptr + + // b = arr2[0] + %6 = cc.load %5 : !cc.ptr + + // x(qubits[b]) + %7 = quake.extract_ref %0[%6] : (!quake.veq<2>, i64) -> !quake.ref + quake.x %7 : (!quake.ref) -> () + return +} + +// CHECK-LABEL: func.func @test3() { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array +// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array) -> i64 +// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array +// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr>) -> !cc.ptr +// CHECK: cc.store %[[VAL_0]], %[[VAL_5]] : !cc.ptr +// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr +// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr +// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref +// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> () +// CHECK: return +// CHECK: } \ No newline at end of file