diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 38189117f2..616d63e376 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -975,6 +975,29 @@ def RegToMem : Pass<"regtomem", "mlir::func::FuncOp"> { let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"]; } +def RemoveUselessStores : Pass<"remove-useless-stores", "mlir::func::FuncOp"> { + let summary = "Remove stores that are overriden by other stores."; + let description = [{ + + Example: + ```mlir + %1 = cc.alloca !cc.array + %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + cc.store %c0_i64, %2 : !cc.ptr + // nothing using %2 until the next instruction + cc.store %c1_i64, %2 : !cc.ptr + ``` + + would be converted to + + ```mlir + %1 = cc.alloca !cc.array + %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + cc.store %c1_i64, %2 : !cc.ptr + ``` + }]; +} + // UnitarySynthesis is a module pass because it may modify the `ModuleOp` by // adding new `FuncOp`(s). def UnitarySynthesis : Pass<"unitary-synthesis", "mlir::ModuleOp"> { diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index 0c7dad74c7..f159ce471a 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -52,6 +52,7 @@ add_cudaq_library(OptTransforms QuakeSynthesizer.cpp RefToVeqAlloc.cpp RegToMem.cpp + RemoveUselessStores.cpp StatePreparation.cpp UnitarySynthesis.cpp WiresToWiresets.cpp diff --git a/lib/Optimizer/Transforms/RemoveUselessStores.cpp b/lib/Optimizer/Transforms/RemoveUselessStores.cpp new file mode 100644 index 0000000000..fdacc9d4f6 --- /dev/null +++ b/lib/Optimizer/Transforms/RemoveUselessStores.cpp @@ -0,0 +1,174 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "PassDetails.h" +#include "cudaq/Optimizer/Builder/Intrinsics.h" +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +namespace cudaq::opt { +#define GEN_PASS_DEF_REMOVEUSELESSSTORES +#include "cudaq/Optimizer/Transforms/Passes.h.inc" +} // namespace cudaq::opt + +#define DEBUG_TYPE "remove-useless-stores" + +using namespace mlir; + +namespace { +/// Remove stores followed by a store to the same pointer +/// if the pointer is not used in between. +/// ``` +/// cc.store %c0_i64, %1 : !cc.ptr +/// // no use of %1 until next line +/// cc.store %0, %1 : !cc.ptr +/// ─────────────────────────────────────────── +/// cc.store %0, %1 : !cc.ptr +/// ``` +class RemoveUselessStorePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit RemoveUselessStorePattern(MLIRContext *ctx, DominanceInfo &di) + : OpRewritePattern(ctx), dom(di) {} + + LogicalResult matchAndRewrite(cudaq::cc::StoreOp store, + PatternRewriter &rewriter) const override { + if (isUselessStore(store)) { + rewriter.eraseOp(store); + return success(); + } + return failure(); + } + +private: + /// Detect if the current store is overriden by another store in the same + /// block. + bool isUselessStore(cudaq::cc::StoreOp store) const { + Value currentPtr; + + if (!isStoreToStack(store)) + return false; + + auto block = store.getOperation()->getBlock(); + for (auto &op : *block) { + if (auto currentStore = dyn_cast(&op)) { + auto nextPtr = currentStore.getPtrvalue(); + if (store == currentStore) { + // Start searching from the current store + currentPtr = nextPtr; + } else { + // Found an overriding store, the current store is useless + if (currentPtr == nextPtr) + return isReplacement(currentPtr, store, currentStore); + + // // Found a use for a current ptr before the overriding store + // if (currentPtr && isUsed(currentPtr, &op)) + // return false; + } + } + // } else { + // // Found a use for a current ptr before the overriding store + // if (currentPtr && isUsed(currentPtr, &op)) + // return false; + // } + } + // No multiple stores to the same location found + return false; + } + + /// Detect stores to stack locations, for example: + /// ``` + /// %1 = cc.alloca !cc.array + /// + /// %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + /// cc.store %c0_i64, %2 : !cc.ptr + /// + /// %3 = cc.compute_ptr %1[1] : (!cc.ptr>) -> !cc.ptr + /// cc.store %c0_i64, %3 : !cc.ptr + /// ``` + static bool isStoreToStack(cudaq::cc::StoreOp store) { + auto ptrOp = store.getPtrvalue(); + if (auto cast = ptrOp.getDefiningOp()) + ptrOp = cast.getOperand(); + + if (auto computePtr = ptrOp.getDefiningOp()) + ptrOp = computePtr.getBase(); + + if (auto alloca = ptrOp.getDefiningOp()) + return true; + + return false; + } + + /// Detect if value is used in the op or its nested blocks. + bool isReplacement(Value ptr, cudaq::cc::StoreOp store, + cudaq::cc::StoreOp replacement) const { + // Check that there are no stores dominated by the store and not dominated + // by the replacement (i.e. used in between the store and the replacement) + for (auto *user : ptr.getUsers()) { + if (user != store && user != replacement) { + if (dom.dominates(store, user) && !dom.dominates(replacement, user)) { + LLVM_DEBUG(llvm::dbgs() << "store " << replacement + << " is used before: " << store << '\n'); + return false; + } + } + } + return true; + } + + // /// Detect if value is used in the op or its nested blocks. + // static bool isUsed(Value v, Operation *op) { + // for (auto opnd : op->getOperands()) + // if (opnd == v) + // return true; + + // for (auto ®ion : op->getRegions()) + // for (auto &b : region) + // for (auto &innerOp : b) + // if (isUsed(v, &innerOp)) + // return true; + + // return false; + // } + + DominanceInfo &dom; +}; + +class RemoveUselessStoresPass + : public cudaq::opt::impl::RemoveUselessStoresBase< + RemoveUselessStoresPass> { +public: + using RemoveUselessStoresBase::RemoveUselessStoresBase; + + void runOnOperation() override { + auto *ctx = &getContext(); + auto func = getOperation(); + DominanceInfo domInfo(func); + RewritePatternSet patterns(ctx); + patterns.insert(ctx, domInfo); + + LLVM_DEBUG(llvm::dbgs() + << "Before removing useless stores: " << func << '\n'); + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + signalPassFailure(); + + LLVM_DEBUG(llvm::dbgs() + << "After removing useless stores: " << func << '\n'); + } +}; +} // namespace \ No newline at end of file diff --git a/test/Quake/remove_useless_stores.qke b/test/Quake/remove_useless_stores.qke new file mode 100644 index 0000000000..481bdafbd7 --- /dev/null +++ b/test/Quake/remove_useless_stores.qke @@ -0,0 +1,84 @@ +// ========================================================================== // +// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// RUN: cudaq-opt -remove-useless-stores %s | FileCheck %s + +func.func @test_two_stores_same_pointer() { + %c0_i64 = arith.constant 0 : i64 + %0 = quake.alloca !quake.veq<2> + %1 = cc.const_array [1] : !cc.array + %2 = cc.extract_value %1[0] : (!cc.array) -> i64 + %3 = cc.alloca !cc.array + %4 = cc.cast %3 : (!cc.ptr>) -> !cc.ptr + cc.store %c0_i64, %4 : !cc.ptr + cc.store %2, %4 : !cc.ptr + %5 = cc.load %4 : !cc.ptr + %6 = quake.extract_ref %0[%5] : (!quake.veq<2>, i64) -> !quake.ref + quake.x %6 : (!quake.ref) -> () + return +} + +// CHECK-LABEL: func.func @test_two_stores_same_pointer() { +// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array +// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array) -> i64 +// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array +// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr>) -> !cc.ptr +// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr +// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr +// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref +// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> () +// CHECK: return +// CHECK: } + +func.func @test_two_stores_different_pointers() { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = quake.alloca !quake.veq<2> + %1 = cc.alloca !cc.array + %2 = cc.alloca i64 + cc.store %c0_i64, %2 : !cc.ptr + %3 = cc.alloca i64 + cc.store %c1_i64, %3 : !cc.ptr + return +} + +// CHECK-LABEL: func.func @test_two_stores_different_pointers() { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<2> +// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array +// CHECK: %[[VAL_4:.*]] = cc.alloca i64 +// CHECK: cc.store %[[VAL_0]], %[[VAL_4]] : !cc.ptr +// CHECK: %[[VAL_5:.*]] = cc.alloca i64 +// CHECK: cc.store %[[VAL_1]], %[[VAL_5]] : !cc.ptr +// CHECK: return +// CHECK: } + +func.func @test_two_stores_same_pointer_interleaving() { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %1 = cc.alloca !cc.array + %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + cc.store %c0_i64, %2 : !cc.ptr + %3 = cc.compute_ptr %1[1] : (!cc.ptr>) -> !cc.ptr + cc.store %c0_i64, %3 : !cc.ptr + cc.store %c1_i64, %2 : !cc.ptr + cc.store %c1_i64, %3 : !cc.ptr + return +} + +// CHECK-LABEL: func.func @test_two_stores_same_pointer_interleaving() { +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array +// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr>) -> !cc.ptr +// CHECK: %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr>) -> !cc.ptr +// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr +// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr +// CHECK: return +// CHECK: }