Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

formula variants #301

Closed
wants to merge 13 commits into from
197 changes: 125 additions & 72 deletions src/form/formula_alt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,154 @@
#include "form/expression_util.hpp"
#include "sys/log.hpp"

bool resolve(const Alternatives& alt, const Expression& left,
Expression& right) {
if (right.type == Expression::Type::FUNCTION) {
auto lookup = ExpressionUtil::newFunction(right.name);
if (lookup != left) {
auto range = alt.equal_range(lookup);
for (auto it = range.first; it != range.second; it++) {
auto replacement = it->second;
replacement.replaceAll(ExpressionUtil::newParameter(),
*right.children[0]);
ExpressionUtil::normalize(replacement);
auto range2 = alt.equal_range(left);
bool exists = false;
for (auto it2 = range2.first; it2 != range2.second; it2++) {
if (it2->second == replacement) {
exists = true;
break;
}
}
if (!exists) {
right = replacement;
return true; // must stop here
}
// Builds the initial variant table from `formula`: every entry whose left-hand
// side is accepted by ExpressionUtil::isSimpleFunction contributes one variant
// under that function's name.
VariantsManager::VariantsManager(const Formula& formula) {
  // Pass 1: register all function names first. collectUsedFuncs() only
  // reports names that are already keys of `variants`, so the table must be
  // fully populated before any definition is scanned.
  for (const auto& e : formula.entries) {
    if (!ExpressionUtil::isSimpleFunction(e.first, true)) {
      continue;
    }
    variants[e.first.name] = {};
  }
  // Pass 2: store each definition as the initial variant of its function,
  // together with the set of managed functions it refers to.
  for (const auto& e : formula.entries) {
    if (!ExpressionUtil::isSimpleFunction(e.first, true)) {
      continue;
    }
    Variant initial;
    initial.definition = e.second;
    collectUsedFuncs(initial.definition, initial.used_funcs);
    variants[e.first.name].push_back(initial);
  }
}

// Offers `expr` as a candidate definition for the function named `func`.
//
// Returns true only if a new variant was appended. Returns false when:
//  - `expr` references more managed functions than the allowed maximum, or
//  - a variant with the same set of used functions already exists. In that
//    case the existing variant is replaced by `expr` if `expr` has fewer
//    terms, but this is deliberately not reported as a new variant.
bool VariantsManager::update(const std::string& func, const Expression& expr) {
  // Maximum number of distinct managed functions a variant may reference.
  // Heuristic limit (previously an inline magic number).
  constexpr size_t kMaxUsedFuncs = 3;

  Variant new_variant;
  new_variant.definition = expr;
  collectUsedFuncs(expr, new_variant.used_funcs);
  if (new_variant.used_funcs.size() > kMaxUsedFuncs) {
    return false;
  }
  auto& vs = variants[func];
  for (auto& existing : vs) {
    if (existing.used_funcs != new_variant.used_funcs) {
      continue;
    }
    if (expr.numTerms() < existing.definition.numTerms()) {
      // update existing variant but don't report as new
      existing.definition = expr;
      Log::get().debug("Updated variant to " +
                       ExpressionUtil::newFunction(func).toString() + " = " +
                       expr.toString());
    }
    return false;
  }
  // add new variant
  Log::get().debug("Found variant " +
                   ExpressionUtil::newFunction(func).toString() + " = " +
                   expr.toString());
  vs.push_back(new_variant);
  return true;
}

// Recursively walks `expr` and inserts into `used_funcs` the name of every
// FUNCTION node whose name is a key of `variants`. Names already present in
// `used_funcs` are kept; the set is only ever extended.
void VariantsManager::collectUsedFuncs(
    const Expression& expr, std::set<std::string>& used_funcs) const {
  if (expr.type == Expression::Type::FUNCTION &&
      variants.find(expr.name) != variants.end()) {
    used_funcs.insert(expr.name);
  }
  // Bind children by const reference so the child handle is not copied on
  // every iteration; only the pointee is needed for the recursive call.
  for (const auto& child : expr.children) {
    collectUsedFuncs(*child, used_funcs);
  }
}

// Returns the total number of stored variants, summed over all functions.
size_t VariantsManager::numVariants() const {
  size_t total = 0;
  for (auto it = variants.begin(); it != variants.end(); ++it) {
    total += it->second.size();
  }
  return total;
}

bool resolve(const std::string& lookup_name, const Expression& lookup_def,
const std::string& target_name, Expression& target_def) {
if (target_def.type == Expression::Type::FUNCTION) {
if (target_def.name != target_name && target_def.name == lookup_name) {
auto replacement = lookup_def; // copy
replacement.replaceAll(ExpressionUtil::newParameter(),
*target_def.children[0]);
ExpressionUtil::normalize(replacement);
target_def = replacement;
return true; // must stop here
}
}
bool resolved = false;
for (auto c : right.children) {
if (resolve(alt, left, *c)) {
for (auto c : target_def.children) {
if (resolve(lookup_name, lookup_def, target_name, *c)) {
resolved = true;
}
}
ExpressionUtil::normalize(right);
ExpressionUtil::normalize(target_def);
return resolved;
}

bool findAlternativesByResolve(Alternatives& alt) {
auto newAlt = alt; // copy
// Performs one round of variant discovery: for every (target variant, lookup
// variant) pair, tries to substitute the lookup definition into a copy of the
// target definition and offers the result to `manager`.
// Returns true if at least one new variant was added during this round.
bool findVariants(VariantsManager& manager) {
  // Work on a snapshot: manager.update() modifies manager.variants, which
  // must not happen to the container currently being iterated.
  const auto snapshot = manager.variants;
  bool changed = false;
  for (const auto& target : snapshot) {
    const std::string& target_name = target.first;
    for (const auto& target_variant : target.second) {
      for (const auto& lookup : snapshot) {
        for (const auto& lookup_variant : lookup.second) {
          // resolve() rewrites its last argument in place, so give it a copy.
          Expression candidate = target_variant.definition;
          if (resolve(lookup.first, lookup_variant.definition, target_name,
                      candidate) &&
              manager.update(target_name, candidate)) {
            changed = true;
          }
        }
      }
    }
  }
  return changed;
}

bool simplifyFormulaUsingAlternatives(Formula& formula) {
VariantsManager manager(formula);
bool found = false;
for (auto& e : alt) {
auto right = e.second; // copy
if (resolve(newAlt, e.first, right)) {
std::pair<Expression, Expression> p(e.first, right);
Log::get().debug("Found alternative " + p.first.toString() + " = " +
p.second.toString());
newAlt.insert(p);
for (size_t it = 1; it <= 10; it++) { // magic number
Log::get().debug("Finding variants in iteration " + std::to_string(it));
if (findVariants(manager)) {
found = true;
} else {
break;
}
}
if (found) {
alt = newAlt;
if (!found) {
return false;
}
return found;
}

bool applyAlternatives(const Alternatives& alt, Formula& f) {
Log::get().debug("Found " + std::to_string(manager.numVariants()) +
" variants");
bool applied = false;
for (auto& e : f.entries) {
auto range = alt.equal_range(e.first);
for (auto it = range.first; it != range.second; it++) {
if (it->second == e.second) {
for (auto& entry : formula.entries) {
if (!ExpressionUtil::isSimpleFunction(entry.first, true)) {
continue;
}
for (auto& variant : manager.variants[entry.first.name]) {
if (variant.definition == entry.second) {
continue;
}
Formula g = f; // copy
g.entries[e.first] = it->second;
auto depsOld = f.getFunctionDeps(true, true);
auto depsNew = g.getFunctionDeps(true, true);
std::string debugMsg =
" alternative " + e.first.toString() + " = " + it->second.toString();
if (depsNew.size() < depsOld.size()) {
e.second = it->second;
Formula copy = formula;
copy.entries[entry.first] = variant.definition;
auto deps_old = formula.getFunctionDeps(true, true);
auto deps_new = copy.getFunctionDeps(true, true);
std::string debugMsg = " variant " + entry.first.toString() + " = " +
variant.definition.toString();
if (deps_new.size() < deps_old.size()) {
entry.second = variant.definition;
applied = true;
Log::get().debug("Applied" + debugMsg);
} else {
Log::get().debug("Skipped" + debugMsg);
// Log::get().debug("Skipped" + debugMsg);
}
}
}
Log::get().debug("Updated formula: " + formula.toString());
return applied;
}

// Repeatedly searches for alternative function definitions and applies them
// to `formula` until either step makes no further progress.
// Returns true if the formula was changed at least once.
bool simplifyFormulaUsingAlternatives(Formula& formula) {
  // Seed the alternatives with the formula's current entries.
  Alternatives alt;
  alt.insert(formula.entries.begin(), formula.entries.end());

  bool updated = false;
  // Short-circuit keeps the original order: applying is only attempted when
  // new alternatives were found in this round.
  while (findAlternativesByResolve(alt) && applyAlternatives(alt, formula)) {
    updated = true;
    Log::get().debug("Updated formula: " + formula.toString());
  }
  return updated;
}
23 changes: 20 additions & 3 deletions src/form/formula_alt.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
#pragma once

#include <map>
#include <set>
#include <vector>

#include "form/formula.hpp"

typedef std::multimap<Expression, Expression> Alternatives;
// One candidate definition of a function, plus the set of managed functions
// that the definition refers to.
class Variant {
public:
// Expression defining the function (the right-hand side of a formula entry).
Expression definition;
// Names of managed functions referenced by `definition`; filled in by
// VariantsManager::collectUsedFuncs and compared to detect equivalent variants.
std::set<std::string> used_funcs;
};

bool findAlternativesByResolve(Alternatives& alt);
class VariantsManager {
public:
VariantsManager(const Formula& formula);

bool applyAlternatives(const Alternatives& alt, Formula& formula);
bool update(const std::string& func, const Expression& expr);

std::map<std::string, std::vector<Variant>> variants;

size_t numVariants() const;

private:
void collectUsedFuncs(const Expression& expr,
std::set<std::string>& used_funcs) const;
};

bool simplifyFormulaUsingAlternatives(Formula& formula);
9 changes: 4 additions & 5 deletions src/form/formula_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,9 @@ void FormulaUtil::resolveSimpleRecursions(Formula& formula) {
}
}

int64_t getRecursionDepthInExpr(const Expression& expr, const std::string& fname) {
int64_t getRecursionDepthInExpr(const Expression& expr) {
int64_t depth = 0;
if (expr.type == Expression::Type::FUNCTION && expr.name == fname &&
expr.children.size() == 1) {
if (expr.type == Expression::Type::FUNCTION && expr.children.size() == 1) {
const auto& arg = *expr.children[0];
if (arg.type == Expression::Type::SUM && arg.children.size() == 2 &&
arg.children[0]->type == Expression::Type::PARAMETER &&
Expand All @@ -190,7 +189,7 @@ int64_t getRecursionDepthInExpr(const Expression& expr, const std::string& fname
}
}
for (auto c : expr.children) {
depth = std::max<int64_t>(depth, getRecursionDepthInExpr(*c, fname));
depth = std::max<int64_t>(depth, getRecursionDepthInExpr(*c));
}
return depth;
}
Expand All @@ -202,7 +201,7 @@ int64_t FormulaUtil::getRecursionDepth(const Formula& formula,
if (left.type == Expression::Type::FUNCTION && left.name == fname &&
left.children.size() == 1 &&
left.children[0]->type == Expression::Type::PARAMETER) {
return getRecursionDepthInExpr(e.second, fname);
return getRecursionDepthInExpr(e.second);
}
}
return -1;
Expand Down
4 changes: 3 additions & 1 deletion tests/formula/program-formula.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ A001478: a(n) = -n-1
A001489: a(n) = -n
A001542: a(n) = 3*a(n-1)+2*b(n-1), a(1) = 2, a(0) = 0, b(n) = 4*a(n-1)+3*b(n-1), b(1) = 3, b(0) = 1
A001611: a(n) = A000045(n)+1, A000045(n) = A000045(n-1)+A000045(n-2), A000045(1) = 1, A000045(0) = 0
A001687: a(n) = a(n-2)+a(n-5), a(4) = 0, a(3) = 1, a(2) = 0, a(1) = 1, a(0) = 0
A001715: a(n) = b(n+3)/6, b(n) = n*b(n-1), b(0) = 1
A001911: a(n) = b(n)-2, b(n) = b(n-1)+b(n-2), b(1) = 3, b(0) = 2
A001923: a(n) = n^n+a(n-1), a(0) = 0
Expand All @@ -50,14 +51,15 @@ A005408: a(n) = 2*n+1
A005843: a(n) = 2*n
A007583: a(n) = 2*((4^n)/3)+1
A008785: a(n) = (n+4)^n
A008999: a(n) = 2*a(n-1)+a(n-4), a(3) = 8, a(2) = 4, a(1) = 2, a(0) = 1
A014731: a(n) = 4*b(n)^2, b(n) = 4*b(n-1)+b(n-2), b(1) = -2, b(0) = -1
A021019: a(n) = 6*min(n,1)
A022322: a(n) = b(n+2), b(n) = c(n-2)+1, b(2) = 1, b(1) = 6, b(0) = 0, c(n) = c(n-1)+c(n-2)+2, c(2) = 9, c(1) = 7, c(0) = 0
A022958: a(n) = -n+2
A037536: a(n) = b(n+1), b(n) = (3*c(n-1))/13, b(1) = 1, b(0) = 0, c(n) = 3*c(n-1), c(1) = 24, c(0) = 8
A048745: a(n) = 2*a(n-1)+a(n-2)+3, a(1) = 5, a(0) = 1
A062815: a(n) = b(n+1), b(n) = n*n^n+b(n-1), b(0) = 0
A078013: a(n) = -a(n-4)-a(n-5), a(4) = -1, a(3) = -1, a(2) = 0, a(1) = 0, a(0) = 1
A078013: a(n) = b(n-2), a(2) = 0, a(1) = 0, a(0) = 1, b(n) = -b(n-3)+b(n-1), b(2) = -1, b(1) = -1, b(0) = 0
A110158: a(n) = b(max(n-2,0)), b(n) = b(n-2)+A077847(max(n-2,0)), b(1) = 0, b(0) = 0, c(n) = c(n-1)+d(n-1)+e(n-1), c(3) = 9, c(2) = 3, c(1) = 1, c(0) = 0, d(n) = c(n-1)+d(n-1)+e(n-1)+1, d(3) = 10, d(2) = 4, d(1) = 2, d(0) = 0, e(n) = 2*d(n-2)+2*e(n-2), e(3) = 4, e(2) = 2, e(1) = 0, e(0) = 1, A077847(n) = c(n+1)
A112032: a(n) = 2*a(n-2), a(1) = 1, a(0) = 4
A123111: a(n) = ((n+1)^5+(n+1)^3+n+2)*(n+1)^2+1
Expand Down
15 changes: 15 additions & 0 deletions tests/programs/oeis/001/A001687.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
; A001687: a(n) = a(n-2) + a(n-5).
; 0,1,0,1,0,1,1,1,2,1,3,2,4,4,5,7,7,11,11,16,18,23,29,34,45,52,68,81,102,126,154,194,235,296,361,450,555,685,851,1046,1301,1601,1986,2452,3032,3753,4633,5739,7085,8771,10838,13404,16577,20489,25348,31327,38752,47904,59241,73252,90568,112004,138472,171245,211724,261813,323728,400285,494973,612009,756786,935737,1157071,1430710,1769080,2187496,2704817,3344567,4135527,5113647

; Computes a(n) by iterating the recurrence n times.
; At the start of iteration i the registers hold:
;   $4 = a(i-1), $6 = a(i-2), $3 = a(i-3), $1 = a(i-4), $2 = a(i-5)
; (for small i these are seed values; all registers other than $0 and $2
; are assumed to start at 0 - confirm against the LODA language spec).
mov $2,1 ; seed so that the first iteration yields a(1) = 1
lpb $0 ; loop driven by the counter $0 (one pass per unit of n)
sub $0,1 ; count down
mov $5,$4 ; stash a(i-1) before $4 is overwritten
mov $4,$2 ; $4 = a(i-5)
add $4,$6 ; $4 = a(i-5) + a(i-2) = a(i)
mov $2,$1 ; shift history: $2 <- a(i-4)
mov $1,$3 ; $1 <- a(i-3)
mov $3,$6 ; $3 <- a(i-2)
mov $6,$5 ; $6 <- a(i-1)
lpe
mov $0,$4 ; result: a(n)
14 changes: 14 additions & 0 deletions tests/programs/oeis/008/A008999.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; A008999: a(n) = 2*a(n-1) + a(n-4).
; 1,2,4,8,17,36,76,160,337,710,1496,3152,6641,13992,29480,62112,130865,275722,580924,1223960,2578785,5433292,11447508,24118976,50816737,107066766,225581040,475281056,1001378849,2109824464,4445229968,9365740992,19732860833,41575546130,87596322228,184558385448,388849631729,819274809588,1726145941404,3636850268256,7662550168241,16144375146070,34014896233544,71666642735344,150995835638929,318136046423928,670286989081400,1412240620898144,2975477077435217,6269090201294362,13208467391670124

; Computes a(n) by iterating the recurrence n times.
; At the start of iteration i the registers hold:
;   $5 = a(i-1), $1 = a(i-2), $2 = a(i-3), $3 = a(i-4)
; (terms with negative index are 0, assuming registers start at 0 -
; confirm against the LODA language spec).
mov $5,1 ; $5 = a(0) = 1
lpb $0 ; loop driven by the counter $0 (one pass per unit of n)
sub $0,1 ; count down
mov $4,$3 ; $4 <- a(i-4)
mov $3,$2 ; $3 <- a(i-3)
mov $2,$1 ; $2 <- a(i-2)
mov $1,$5 ; $1 <- a(i-1)
mul $5,2 ; $5 = 2*a(i-1)
add $5,$4 ; $5 = 2*a(i-1) + a(i-4) = a(i)
lpe
mov $0,$5 ; result: a(n)