Skip to content

Commit

Permalink
refactor postloop preparation (#310)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckrause authored Jan 3, 2024
1 parent c52c4af commit 0349bdd
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 35 deletions.
76 changes: 41 additions & 35 deletions src/form/formula_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ bool FormulaGenerator::generateSingle(const Program& p) {

// initialize expressions for memory cells
initFormula(numCells, false);
std::map<int64_t, Expression> preloop_exprs;
std::map<int64_t, Expression> preloopExprs;
if (useIncEval) {
// TODO: remove this limitation
if (incEval.getInputDependentCells().size() > 1 &&
Expand All @@ -227,7 +227,7 @@ bool FormulaGenerator::generateSingle(const Program& p) {
for (auto cell : incEval.getInputDependentCells()) {
auto op = Operand(Operand::Type::DIRECT, Number(cell));
auto param = operandToExpression(op);
preloop_exprs[cell] = formula.entries[param];
preloopExprs[cell] = formula.entries[param];
}
initFormula(numCells, true);
}
Expand Down Expand Up @@ -287,37 +287,7 @@ bool FormulaGenerator::generateSingle(const Program& p) {
}

// prepare post-loop processing
auto preloop_counter = preloop_exprs.at(incEval.getSimpleLoop().counter);
for (int64_t cell = 0; cell < numCells; cell++) {
auto name = newName();
auto left = ExpressionUtil::newFunction(name);
Expression right;
if (cell == incEval.getSimpleLoop().counter) {
auto last = ExpressionUtil::newConstant(0);
if (incEval.getLoopCounterDecrement() > 1) {
auto loop_dec =
ExpressionUtil::newConstant(incEval.getLoopCounterDecrement());
last = Expression(Expression::Type::MODULUS, "",
{preloop_counter, loop_dec});
}
right = Expression(Expression::Type::FUNCTION, "min",
{preloop_counter, last});
} else if (incEval.getInputDependentCells().find(cell) !=
incEval.getInputDependentCells().end()) {
right = preloop_exprs.at(cell);
} else {
auto safe_param = preloop_counter;
if (ExpressionUtil::canBeNegative(safe_param)) {
auto tmp = safe_param;
safe_param = Expression(Expression::Type::FUNCTION, "max",
{tmp, ExpressionUtil::newConstant(0)});
}
right = Expression(Expression::Type::FUNCTION, getCellName(cell),
{safe_param});
}
formula.entries[left] = right;
cellNames[cell] = name;
}
prepareForPostLoop(numCells, preloopExprs);
Log::get().debug("Prepared post-loop: " + formula.toString());

// handle post-loop code
Expand All @@ -327,11 +297,11 @@ bool FormulaGenerator::generateSingle(const Program& p) {
Log::get().debug("Processed post-loop: " + formula.toString());
}

// resolve linear functions
// resolve simple recursions
FormulaUtil::resolveSimpleRecursions(formula);
Log::get().debug("Resolved simple recursions: " + formula.toString());

// resolve linear functions
// resolve simple functions
FormulaUtil::resolveSimpleFunctions(formula);
Log::get().debug("Resolved simple functions: " + formula.toString());

Expand All @@ -349,6 +319,42 @@ bool FormulaGenerator::generateSingle(const Program& p) {
return true;
}

void FormulaGenerator::prepareForPostLoop(
int64_t numCells, const std::map<int64_t, Expression> preloopExprs) {
// prepare post-loop processing
auto preloopCounter = preloopExprs.at(incEval.getSimpleLoop().counter);
for (int64_t cell = 0; cell < numCells; cell++) {
auto name = newName();
auto left = ExpressionUtil::newFunction(name);
Expression right;
if (cell == incEval.getSimpleLoop().counter) {
auto last = ExpressionUtil::newConstant(0);
if (incEval.getLoopCounterDecrement() > 1) {
auto loop_dec =
ExpressionUtil::newConstant(incEval.getLoopCounterDecrement());
last = Expression(Expression::Type::MODULUS, "",
{preloopCounter, loop_dec});
}
right =
Expression(Expression::Type::FUNCTION, "min", {preloopCounter, last});
} else if (incEval.getInputDependentCells().find(cell) !=
incEval.getInputDependentCells().end()) {
right = preloopExprs.at(cell);
} else {
auto safe_param = preloopCounter;
if (ExpressionUtil::canBeNegative(safe_param)) {
auto tmp = safe_param;
safe_param = Expression(Expression::Type::FUNCTION, "max",
{tmp, ExpressionUtil::newConstant(0)});
}
right = Expression(Expression::Type::FUNCTION, getCellName(cell),
{safe_param});
}
formula.entries[left] = right;
cellNames[cell] = name;
}
}

void FormulaGenerator::simplifyFunctionNames() {
std::set<std::string> names;
for (auto& e : formula.entries) {
Expand Down
3 changes: 3 additions & 0 deletions src/form/formula_gen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class FormulaGenerator {

bool update(const Program& p);

void prepareForPostLoop(int64_t numCells,
const std::map<int64_t, Expression> preloopExprs);

std::string newName();

std::string getCellName(int64_t cell) const;
Expand Down

0 comments on commit 0349bdd

Please sign in to comment.