Skip to content

Commit

Permalink
Add RemoveUselessStores pass
Browse files Browse the repository at this point in the history
Signed-off-by: Anna Gringauze <agringauze@nvidia.com>
  • Loading branch information
annagrin committed Feb 1, 2025
1 parent d893046 commit 9b55600
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 0 deletions.
23 changes: 23 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,29 @@ def RegToMem : Pass<"regtomem", "mlir::func::FuncOp"> {
let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"];
}

def RemoveUselessStores : Pass<"remove-useless-stores", "mlir::func::FuncOp"> {
let summary = "Remove stores that are overriden by other stores.";
let description = [{

Example:
```mlir
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
// nothing using %2 until the next instruction
cc.store %c1_i64, %2 : !cc.ptr<i64>
```

would be converted to

```mlir
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
```
}];
}

// UnitarySynthesis is a module pass because it may modify the `ModuleOp` by
// adding new `FuncOp`(s).
def UnitarySynthesis : Pass<"unitary-synthesis", "mlir::ModuleOp"> {
Expand Down
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ add_cudaq_library(OptTransforms
QuakeSynthesizer.cpp
RefToVeqAlloc.cpp
RegToMem.cpp
RemoveUselessStores.cpp
StatePreparation.cpp
UnitarySynthesis.cpp
WiresToWiresets.cpp
Expand Down
174 changes: 174 additions & 0 deletions lib/Optimizer/Transforms/RemoveUselessStores.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*******************************************************************************
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace cudaq::opt {
#define GEN_PASS_DEF_REMOVEUSELESSSTORES
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "remove-useless-stores"

using namespace mlir;

namespace {
/// Remove stores followed by a store to the same pointer
/// if the pointer is not used in between.
/// ```
/// cc.store %c0_i64, %1 : !cc.ptr<i64>
/// // no use of %1 until next line
/// cc.store %0, %1 : !cc.ptr<i64>
/// ───────────────────────────────────────────
/// cc.store %0, %1 : !cc.ptr<i64>
/// ```
class RemoveUselessStorePattern : public OpRewritePattern<cudaq::cc::StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;

explicit RemoveUselessStorePattern(MLIRContext *ctx, DominanceInfo &di)
: OpRewritePattern(ctx), dom(di) {}

LogicalResult matchAndRewrite(cudaq::cc::StoreOp store,
PatternRewriter &rewriter) const override {
if (isUselessStore(store)) {
rewriter.eraseOp(store);
return success();
}
return failure();
}

private:
/// Detect if the current store is overriden by another store in the same
/// block.
bool isUselessStore(cudaq::cc::StoreOp store) const {
Value currentPtr;

if (!isStoreToStack(store))
return false;

auto block = store.getOperation()->getBlock();
for (auto &op : *block) {
if (auto currentStore = dyn_cast<cudaq::cc::StoreOp>(&op)) {
auto nextPtr = currentStore.getPtrvalue();
if (store == currentStore) {
// Start searching from the current store
currentPtr = nextPtr;
} else {
// Found an overriding store, the current store is useless
if (currentPtr == nextPtr)
return isReplacement(currentPtr, store, currentStore);

// // Found a use for a current ptr before the overriding store
// if (currentPtr && isUsed(currentPtr, &op))
// return false;
}
}
// } else {
// // Found a use for a current ptr before the overriding store
// if (currentPtr && isUsed(currentPtr, &op))
// return false;
// }
}
// No multiple stores to the same location found
return false;
}

/// Detect stores to stack locations, for example:
/// ```
/// %1 = cc.alloca !cc.array<i64 x 2>
///
/// %2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
/// cc.store %c0_i64, %2 : !cc.ptr<i64>
///
/// %3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
/// cc.store %c0_i64, %3 : !cc.ptr<i64>
/// ```
static bool isStoreToStack(cudaq::cc::StoreOp store) {
auto ptrOp = store.getPtrvalue();
if (auto cast = ptrOp.getDefiningOp<cudaq::cc::CastOp>())
ptrOp = cast.getOperand();

if (auto computePtr = ptrOp.getDefiningOp<cudaq::cc::ComputePtrOp>())
ptrOp = computePtr.getBase();

if (auto alloca = ptrOp.getDefiningOp<cudaq::cc::AllocaOp>())
return true;

return false;
}

/// Detect if value is used in the op or its nested blocks.
bool isReplacement(Value ptr, cudaq::cc::StoreOp store,
cudaq::cc::StoreOp replacement) const {
// Check that there are no stores dominated by the store and not dominated
// by the replacement (i.e. used in between the store and the replacement)
for (auto *user : ptr.getUsers()) {
if (user != store && user != replacement) {
if (dom.dominates(store, user) && !dom.dominates(replacement, user)) {
LLVM_DEBUG(llvm::dbgs() << "store " << replacement
<< " is used before: " << store << '\n');
return false;
}
}
}
return true;
}

// /// Detect if value is used in the op or its nested blocks.
// static bool isUsed(Value v, Operation *op) {
// for (auto opnd : op->getOperands())
// if (opnd == v)
// return true;

// for (auto &region : op->getRegions())
// for (auto &b : region)
// for (auto &innerOp : b)
// if (isUsed(v, &innerOp))
// return true;

// return false;
// }

DominanceInfo &dom;
};

class RemoveUselessStoresPass
: public cudaq::opt::impl::RemoveUselessStoresBase<
RemoveUselessStoresPass> {
public:
using RemoveUselessStoresBase::RemoveUselessStoresBase;

void runOnOperation() override {
auto *ctx = &getContext();
auto func = getOperation();
DominanceInfo domInfo(func);
RewritePatternSet patterns(ctx);
patterns.insert<RemoveUselessStorePattern>(ctx, domInfo);

LLVM_DEBUG(llvm::dbgs()
<< "Before removing useless stores: " << func << '\n');

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
signalPassFailure();

LLVM_DEBUG(llvm::dbgs()
<< "After removing useless stores: " << func << '\n');
}
};
} // namespace
84 changes: 84 additions & 0 deletions test/Quake/remove_useless_stores.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// ========================================================================== //
// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt -remove-useless-stores %s | FileCheck %s

func.func @test_two_stores_same_pointer() {
%c0_i64 = arith.constant 0 : i64
%0 = quake.alloca !quake.veq<2>
%1 = cc.const_array [1] : !cc.array<i64 x 1>
%2 = cc.extract_value %1[0] : (!cc.array<i64 x 1>) -> i64
%3 = cc.alloca !cc.array<i64 x 1>
%4 = cc.cast %3 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
cc.store %c0_i64, %4 : !cc.ptr<i64>
cc.store %2, %4 : !cc.ptr<i64>
%5 = cc.load %4 : !cc.ptr<i64>
%6 = quake.extract_ref %0[%5] : (!quake.veq<2>, i64) -> !quake.ref
quake.x %6 : (!quake.ref) -> ()
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer() {
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array<i64 x 1>
// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array<i64 x 1>) -> i64
// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr<i64>
// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref
// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> ()
// CHECK: return
// CHECK: }

func.func @test_two_stores_different_pointers() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%0 = quake.alloca !quake.veq<2>
%1 = cc.alloca !cc.array<i64 x 1>
%2 = cc.alloca i64
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.alloca i64
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_different_pointers() {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK: %[[VAL_4:.*]] = cc.alloca i64
// CHECK: cc.store %[[VAL_0]], %[[VAL_4]] : !cc.ptr<i64>
// CHECK: %[[VAL_5:.*]] = cc.alloca i64
// CHECK: cc.store %[[VAL_1]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

func.func @test_two_stores_same_pointer_interleaving() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%1 = cc.alloca !cc.array<i64 x 2>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %3 : !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer_interleaving() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array<i64 x 2>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

0 comments on commit 9b55600

Please sign in to comment.