diff --git a/src/form/expression_util.cpp b/src/form/expression_util.cpp index 3af605ff..5d4a221a 100644 --- a/src/form/expression_util.cpp +++ b/src/form/expression_util.cpp @@ -290,6 +290,44 @@ bool ExpressionUtil::isInitialTerm(const Expression& e) { return arg.type == Expression::Type::CONSTANT; } +bool ExpressionUtil::isRecursionArgument(const Expression& e, + int64_t max_offset) { + Number offset; + if (e.type == Expression::Type::PARAMETER) { + offset = Number::ZERO; + } else if (e.type == Expression::Type::SUM && e.children.size() == 2 && + e.children[0].type == Expression::Type::PARAMETER && + e.children[1].type == Expression::Type::CONSTANT) { + offset = e.children[1].value; + } else { + return false; + } + return offset < Number(max_offset + 1); +} + +bool ExpressionUtil::isNonRecursiveFunctionReference( + const Expression& e, const std::vector& names, + int64_t max_offset) { + if (e.type != Expression::Type::FUNCTION || e.children.size() != 1 || + std::find(names.begin(), names.end(), e.name) == names.end()) { + return false; + } + return !isRecursionArgument(e.children.front(), max_offset); +} + +bool ExpressionUtil::hasNonRecursiveFunctionReference( + const Expression& e, const std::vector& names, + int64_t max_offset) { + if (isNonRecursiveFunctionReference(e, names, max_offset)) { + return true; + } else { + return std::any_of( + e.children.begin(), e.children.end(), [&](const Expression& c) { + return hasNonRecursiveFunctionReference(c, names, max_offset); + }); + } +} + bool ExpressionUtil::canBeNegative(const Expression& e) { switch (e.type) { case Expression::Type::CONSTANT: diff --git a/src/form/expression_util.hpp b/src/form/expression_util.hpp index 66eeb81f..7bef5ae9 100644 --- a/src/form/expression_util.hpp +++ b/src/form/expression_util.hpp @@ -22,6 +22,16 @@ class ExpressionUtil { static bool isInitialTerm(const Expression& e); + static bool isRecursionArgument(const Expression& e, int64_t max_offset); + + static bool isNonRecursiveFunctionReference( + const Expression& e, const std::vector& names, + int64_t max_offset); + + static bool hasNonRecursiveFunctionReference( + const Expression& e, const std::vector& names, + int64_t max_offset); + static bool canBeNegative(const Expression& e); static void collectNames(const Expression& e, Expression::Type type, diff --git a/src/form/pari.cpp b/src/form/pari.cpp index 68ec9525..5191bd22 100644 --- a/src/form/pari.cpp +++ b/src/form/pari.cpp @@ -74,12 +74,18 @@ bool PariFormula::convert(const Formula& formula, bool as_vector, PariFormula& pari_formula) { pari_formula = PariFormula(); pari_formula.as_vector = as_vector; + auto defs = FormulaUtil::getDefinitions(formula, Expression::Type::FUNCTION); for (const auto& entry : formula.entries) { auto left = entry.first; auto right = entry.second; if (as_vector && left.type == Expression::Type::FUNCTION) { left.type = Expression::Type::VECTOR; } + // TODO: remove this limitation + if (as_vector && + ExpressionUtil::hasNonRecursiveFunctionReference(right, defs, 0)) { + return false; + } if (!convertExprToPari(right, formula, as_vector)) { return false; }