Skip to content

Commit

Permalink
additional check when converting to pari-vector
Browse files Browse the repository at this point in the history
  • Loading branch information
ckrause committed Oct 29, 2024
1 parent e350c29 commit 6003fea
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/form/expression_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,44 @@ bool ExpressionUtil::isInitialTerm(const Expression& e) {
return arg.type == Expression::Type::CONSTANT;
}

bool ExpressionUtil::isRecursionArgument(const Expression& e,
int64_t max_offset) {
Number offset;
if (e.type == Expression::Type::PARAMETER) {
offset = Number::ZERO;
} else if (e.type == Expression::Type::SUM && e.children.size() == 2 &&
e.children[0].type == Expression::Type::PARAMETER &&
e.children[1].type == Expression::Type::CONSTANT) {
offset = e.children[1].value;
} else {
return false;
}
return offset < Number(max_offset + 1);
}

bool ExpressionUtil::isNonRecursiveFunctionReference(
const Expression& e, const std::vector<std::string>& names,
int64_t max_offset) {
if (e.type != Expression::Type::FUNCTION || e.children.size() != 1 ||
std::find(names.begin(), names.end(), e.name) == names.end()) {
return false;
}
return !isRecursionArgument(e.children.front(), max_offset);
}

bool ExpressionUtil::hasNonRecursiveFunctionReference(
const Expression& e, const std::vector<std::string>& names,
int64_t max_offset) {
if (isNonRecursiveFunctionReference(e, names, max_offset)) {
return true;
} else {
return std::any_of(
e.children.begin(), e.children.end(), [&](const Expression& c) {
return hasNonRecursiveFunctionReference(c, names, max_offset);
});
}
}

bool ExpressionUtil::canBeNegative(const Expression& e) {
switch (e.type) {
case Expression::Type::CONSTANT:
Expand Down
10 changes: 10 additions & 0 deletions src/form/expression_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ class ExpressionUtil {

static bool isInitialTerm(const Expression& e);

static bool isRecursionArgument(const Expression& e, int64_t max_offset);

static bool isNonRecursiveFunctionReference(
const Expression& e, const std::vector<std::string>& names,
int64_t max_offset);

static bool hasNonRecursiveFunctionReference(
const Expression& e, const std::vector<std::string>& names,
int64_t max_offset);

static bool canBeNegative(const Expression& e);

static void collectNames(const Expression& e, Expression::Type type,
Expand Down
6 changes: 6 additions & 0 deletions src/form/pari.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,18 @@ bool PariFormula::convert(const Formula& formula, bool as_vector,
PariFormula& pari_formula) {
pari_formula = PariFormula();
pari_formula.as_vector = as_vector;
auto defs = FormulaUtil::getDefinitions(formula, Expression::Type::FUNCTION);
for (const auto& entry : formula.entries) {
auto left = entry.first;
auto right = entry.second;
if (as_vector && left.type == Expression::Type::FUNCTION) {
left.type = Expression::Type::VECTOR;
}
// TODO: remove this limitation
if (as_vector &&
ExpressionUtil::hasNonRecursiveFunctionReference(right, defs, 0)) {
return false;
}
if (!convertExprToPari(right, formula, as_vector)) {
return false;
}
Expand Down

0 comments on commit 6003fea

Please sign in to comment.