Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][LLVM] Fix memory explosion when converting global variable bodies in ModuleTranslation (#82708) #12

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>

#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
Expand Down Expand Up @@ -1042,17 +1046,80 @@ LogicalResult ModuleTranslation::convertGlobals() {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
if (Block *initializer = op.getInitializerBlock()) {
llvm::IRBuilder<> builder(llvmModule->getContext());

int numConstantsHit = 0;
int numConstantsErased = 0;
DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;

for (auto &op : initializer->without_terminator()) {
if (failed(convertOperation(op, builder)) ||
!isa<llvm::Constant>(lookupValue(op.getResult(0))))
if (failed(convertOperation(op, builder)))
return emitError(op.getLoc(), "fail to convert global initializer");
auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
if (!cst)
return emitError(op.getLoc(), "unemittable constant value");

// When emitting an LLVM constant, a new constant is created and the old
// constant may become dangling and take space. We should remove the
// dangling constants to avoid memory explosion especially for constant
// arrays whose number of elements is large.
// Because multiple operations may refer to the same constant, we need
// to count the number of uses of each constant array and remove it only
// when the count becomes zero.
if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
numConstantsHit++;
Value result = op.getResult(0);
int numUsers = std::distance(result.use_begin(), result.use_end());
auto [iterator, inserted] =
constantAggregateUseMap.try_emplace(agg, numUsers);
if (!inserted) {
// Key already exists, update the value
iterator->second += numUsers;
}
}
// Scan the operands of the operation to decrement the use count of
// constants. Erase the constant if the use count becomes zero.
for (Value v : op.getOperands()) {
auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
if (!cst)
continue;
auto iter = constantAggregateUseMap.find(cst);
assert(iter != constantAggregateUseMap.end() && "constant not found");
iter->second--;
if (iter->second == 0) {
// NOTE: cannot call removeDeadConstantUsers() here because it may
// remove a constant whose remaining uses have not been converted yet.
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
constantAggregateUseMap.erase(iter);
}
}
}

ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
llvm::Constant *cst =
cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
global->setInitializer(cst);

// Try to remove the dangling constants again after all operations are
// converted.
for (auto it : constantAggregateUseMap) {
auto cst = it.first;
cst->removeDeadConstantUsers();
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
}

LLVM_DEBUG(llvm::dbgs()
<< "Convert initializer for " << op.getName() << "\n";
llvm::dbgs() << numConstantsHit << " new constants hit\n";
llvm::dbgs()
<< numConstantsErased << " dangling constants erased\n";);
}
}

Expand Down
72 changes: 72 additions & 0 deletions mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s

// CHECK: Convert initializer for dup_const
// CHECK: 6 new constants hit
// CHECK: 3 dangling constants erased
// CHECK: Convert initializer for unique_const
// CHECK: 6 new constants hit
// CHECK: 5 dangling constants erased


// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }

// Global whose initializer assembles a struct of three [2 x f64] arrays, each
// filled with the same two constants %c0 and %c1.
llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
// The two scalar values shared by all three arrays.
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

// Start each array from its own undef value and set element 0 to %c0.
%empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
%a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>

%empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
%a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>

%empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
%a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>

// NOTE: %a00, %a10 and %a20 are all the same ConstantAggregate, which has not
// been used at this point. It must not be deleted until all of its uses have
// been converted.

// Set element 1 of each array to %c1, completing the three arrays.
%a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
%a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
%a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
// Insert the three completed arrays into fields 0, 1 and 2 of the result
// struct, starting from an undef struct.
%empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

// The fully populated struct is the initializer value of the global.
llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}

// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }

// Global whose initializer assembles a struct of three [2 x f64] arrays, each
// built from a different pair of constants.
llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
// Element values for the first array.
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

// Element values for the second array.
%c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
%c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64

// Element values for the third array.
%c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
%c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64

// Undef struct that the completed arrays are inserted into.
%2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

// Single undef array reused as the starting value for all three arrays.
%3 = llvm.mlir.undef : !llvm.array<2 x f64>

// First array: elements %c0 and %c1.
%4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
%5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>

// Store the first array in field 0 of the struct.
%6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

// Second array: elements %c2 and %c3.
%7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
%8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>

// Store the second array in field 1 of the struct.
%9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

// Third array: elements %c4 and %c5.
%10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
%11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>

// Store the third array in field 2 of the struct.
%12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

// The fully populated struct is the initializer value of the global.
llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}