From a566bcac4ff705cdc3a18ac1ee5f15bbf27ec9eb Mon Sep 17 00:00:00 2001 From: Christian Krause Date: Fri, 21 Jun 2024 19:02:37 +0200 Subject: [PATCH 1/2] pari vector evaluation --- .vscode/settings.json | 3 +- src/cmd/commands.cpp | 20 +++--- src/cmd/test.cpp | 67 +++++++++++++++++ src/cmd/test.hpp | 4 ++ src/form/formula.cpp | 49 +++++++++---- src/form/formula.hpp | 13 ++-- src/form/formula_util.cpp | 20 ++---- src/form/formula_util.hpp | 2 - src/form/pari.cpp | 92 ++++++++++++++++++------ src/form/pari.hpp | 23 ++++-- src/form/variant.cpp | 6 +- tests/formula/program-pari-recursive.txt | 4 ++ tests/formula/program-pari-vector.txt | 6 ++ tests/programs/oeis/000/A000058.asm | 12 ++++ 14 files changed, 247 insertions(+), 74 deletions(-) create mode 100644 tests/formula/program-pari-recursive.txt create mode 100644 tests/formula/program-pari-vector.txt create mode 100644 tests/programs/oeis/000/A000058.asm diff --git a/.vscode/settings.json b/.vscode/settings.json index b9c9939b..4dc0ded5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -80,7 +80,8 @@ "functional": "cpp", "__verbose_abort": "cpp", "bit": "cpp", - "charconv": "cpp" + "charconv": "cpp", + "execution": "cpp" }, "makefile.makeDirectory": "src", "c-cpp-flylint.clang.includePaths": [ diff --git a/src/cmd/commands.cpp b/src/cmd/commands.cpp index 80256641..a960a85b 100644 --- a/src/cmd/commands.cpp +++ b/src/cmd/commands.cpp @@ -211,6 +211,7 @@ void Commands::export_(const std::string& path) { Program program = OeisProgram::getProgramAndSeqId(path).first; const auto& format = settings.export_format; Formula formula; + PariFormula pari_formula; FormulaGenerator generator; if (format.empty() || format == "formula") { if (!generator.generate(program, -1, formula, settings.with_deps)) { @@ -219,16 +220,18 @@ void Commands::export_(const std::string& path) { std::cout << formula.toString() << std::endl; } else if (format == "pari") { if (!generator.generate(program, -1, formula, settings.with_deps) || - !Pari::convertToPari(formula, false)) { + !PariFormula::convert(formula, false, pari_formula)) { throwConversionError(format); } - std::cout << formula.toString("; ", true) << std::endl; + std::cout << pari_formula.toString() << std::endl; + // pari_formula.printEvalCode(10, std::cout); } else if (format == "pari-vector") { if (!generator.generate(program, -1, formula, settings.with_deps) || - !Pari::convertToPari(formula, true)) { + !PariFormula::convert(formula, true, pari_formula)) { throwConversionError(format); } - std::cout << formula.toString("; ", true) << std::endl; + std::cout << pari_formula.toString() << std::endl; + // pari_formula.printEvalCode(10, std::cout); } else if (format == "loda") { ProgramUtil::print(program, std::cout); } else { @@ -462,10 +465,12 @@ void Commands::testPari(const std::string& test_id) { // generate PARI code FormulaGenerator generator; Formula formula; + PariFormula pari_formula; + const bool as_vector = true; Sequence expSeq; try { if (!generator.generate(program, id, formula, true) || - !Pari::convertToPari(formula)) { + !PariFormula::convert(formula, as_vector, pari_formula)) { continue; } } catch (const std::exception& e) { @@ -483,7 +488,6 @@ void Commands::testPari(const std::string& test_id) { true); } } - auto pariCode = formula.toString("; ", true); // determine number of terms for testing size_t numTerms = seq.existingNumTerms(); @@ -513,7 +517,7 @@ void Commands::testPari(const std::string& test_id) { } } Log::get().info("Checking " + std::to_string(numTerms) + " terms of " + - seq.id_str() + ": " + pariCode); + seq.id_str() + ": " + pari_formula.toString()); if (numTerms == 0) { Log::get().warn("Skipping " + seq.id_str()); @@ -532,7 +536,7 @@ void Commands::testPari(const std::string& test_id) { } // evaluate PARI program - auto genSeq = Pari::eval(formula, 0, numTerms - 1); + auto genSeq = pari_formula.eval(numTerms); // compare results if (genSeq != expSeq) { diff --git a/src/cmd/test.cpp b/src/cmd/test.cpp index 3787ff76..699578da 100644 --- a/src/cmd/test.cpp +++ b/src/cmd/test.cpp @@ -7,6 +7,7 @@ #include #include "form/formula_gen.hpp" +#include "form/pari.hpp" #include "lang/big_number.hpp" #include "lang/comments.hpp" #include "lang/evaluator.hpp" @@ -60,6 +61,7 @@ void Test::all() { checkpoint(); knownPrograms(); formula(); + pariEval(); // slow tests number(); @@ -907,6 +909,71 @@ void Test::formula() { } } +void Test::pariEval() { + auto base_path = + std::string("tests") + FILE_SEP + std::string("formula") + FILE_SEP; + testPariEval(base_path + "program-pari-recursive.txt", false); + testPariEval(base_path + "program-pari-vector.txt", true); +} + +void testPariEvalCode(const std::string& seq_id, + const std::string& expected_eval_code, bool asVector) { + Log::get().info("Testing PARI/GP " + + std::string((asVector ? "vector" : "recursive")) + + " code for " + seq_id); + Parser parser; + FormulaGenerator generator; + OeisSequence seq(seq_id); + auto program = parser.parse(seq.getProgramPath()); + Formula f; + PariFormula pari; + if (!generator.generate(program, seq.id, f, true)) { + Log::get().error("Cannot generate formula from program", true); + } + if (!PariFormula::convert(f, asVector, pari)) { + Log::get().error("Cannot convert formula to PARI/GP", true); + } + std::stringstream eval_code; + pari.printEvalCode(10, eval_code); + if (eval_code.str() != expected_eval_code) { + Log::get().error("Unexpected PARI/GP code: " + eval_code.str(), true); + } +} + +void Test::testPariEval(const std::string& testFile, bool asVector) { + std::ifstream file(testFile); + if (!file.is_open()) { + Log::get().error("Cannot open test file: " + testFile, true); + } + std::string line, seq_id, expected_eval_code; + Parser parser; + FormulaGenerator generator; + size_t num_tests = 0; + while (std::getline(file, line)) { + if (line.empty()) { + if (!seq_id.empty()) { + testPariEvalCode(seq_id, expected_eval_code, asVector); + seq_id.clear(); + expected_eval_code.clear(); + num_tests++; + } + } else if (line.substr(0, 2) == "\\\\") { + seq_id = line.substr(2); + trimString(seq_id); + expected_eval_code.clear(); + } else { + expected_eval_code += line + "\n"; + } + } + if (!seq_id.empty()) { + testPariEvalCode(seq_id, expected_eval_code, asVector); + num_tests++; + } + if (num_tests == 0) { + Log::get().error("No tests found in file: " + testFile, true); + } +} + void Test::stats() { Log::get().info("Testing stats loading and saving"); diff --git a/src/cmd/test.hpp b/src/cmd/test.hpp index 365864d3..6bb01ea6 100644 --- a/src/cmd/test.hpp +++ b/src/cmd/test.hpp @@ -68,6 +68,8 @@ class Test { void formula(); + void pariEval(); + private: std::vector> loadInOutTests( const std::string &prefix); @@ -81,6 +83,8 @@ class Test { void testMatcherPair(Matcher &matcher, size_t id1, size_t id2); + void testPariEval(const std::string &testFile, bool asVector); + OeisManager &getManager(); Settings settings; diff --git a/src/form/formula.cpp b/src/form/formula.cpp index 24bb9d55..954b39a4 100644 --- a/src/form/formula.cpp +++ b/src/form/formula.cpp @@ -33,12 +33,30 @@ bool Formula::contains(const Expression& search) const { }); } -bool Formula::containsFunctionDef(const std::string& fname) const { - return std::any_of(entries.begin(), entries.end(), - [&](const std::pair& e) { - return e.first.type == Expression::Type::FUNCTION && - e.first.name == fname; - }); +std::vector Formula::getDefinitions( + Expression::Type type, bool sortByDependencies) const { + std::vector result; + for (const auto& e : entries) { + if (e.first.type == type && + std::find(result.begin(), result.end(), e.first.name) == result.end()) { + result.push_back(e.first.name); + } + } + if (sortByDependencies) { + const auto deps = getDependencies(type, true, true); + std::sort(result.begin(), result.end(), + [&](const std::string& a, const std::string& b) { + auto depsA = deps.equal_range(a); + return !std::any_of( + depsA.first, depsA.second, + [&](const std::pair& e) { + return e.second == b; + }); + }); + } else { + std::sort(result.begin(), result.end()); + } + return result; } bool containsPair(std::multimap& deps, @@ -53,13 +71,13 @@ bool containsPair(std::multimap& deps, } void collectDeps(const std::string& fname, const Expression& e, + Expression::Type type, std::multimap& deps) { - if (e.type == Expression::Type::FUNCTION && !e.name.empty() && - !containsPair(deps, fname, e.name)) { + if (e.type == type && !e.name.empty() && !containsPair(deps, fname, e.name)) { deps.insert({fname, e.name}); } for (const auto& c : e.children) { - collectDeps(fname, c, deps); + collectDeps(fname, c, type, deps); } } @@ -78,12 +96,12 @@ std::pair findMissingPair( return result; } -std::multimap Formula::getFunctionDeps( - bool transitive, bool ignoreSelf) const { +std::multimap Formula::getDependencies( + Expression::Type type, bool transitive, bool ignoreSelf) const { std::multimap deps; for (auto& e : entries) { - if (e.first.type == Expression::Type::FUNCTION && !e.first.name.empty()) { - collectDeps(e.first.name, e.second, deps); + if (e.first.type == type && !e.first.name.empty()) { + collectDeps(e.first.name, e.second, type, deps); } } if (transitive) { @@ -108,8 +126,9 @@ std::multimap Formula::getFunctionDeps( return deps; } -bool Formula::isRecursive(const std::string& funcName) const { - auto deps = getFunctionDeps(false, false); +bool Formula::isRecursive(const std::string& funcName, + Expression::Type type) const { + auto deps = getDependencies(type, false, false); for (auto it : deps) { if (it.first == funcName && it.second == funcName) { return true; diff --git a/src/form/formula.hpp b/src/form/formula.hpp index 634b2262..0d88ffb8 100644 --- a/src/form/formula.hpp +++ b/src/form/formula.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include "form/expression.hpp" @@ -23,12 +24,16 @@ class Formula { bool contains(const Expression& search) const; - bool containsFunctionDef(const std::string& fname) const; + std::vector getDefinitions( + Expression::Type type = Expression::Type::FUNCTION, + bool sortByDependencies = false) const; - std::multimap getFunctionDeps( - bool transitive, bool ignoreSelf) const; + std::multimap getDependencies( + Expression::Type type = Expression::Type::FUNCTION, + bool transitive = false, bool ignoreSelf = false) const; - bool isRecursive(const std::string& funcName) const; + bool isRecursive(const std::string& funcName, + Expression::Type type = Expression::Type::FUNCTION) const; void replaceAll(const Expression& from, const Expression& to); diff --git a/src/form/formula_util.cpp b/src/form/formula_util.cpp index 11f6fee9..faf12b7a 100644 --- a/src/form/formula_util.cpp +++ b/src/form/formula_util.cpp @@ -45,7 +45,7 @@ void FormulaUtil::resolveSimpleFunctions(Formula& formula) { } } // filter out non-simple functions - auto deps = formula.getFunctionDeps(false, false); + auto deps = formula.getDependencies(Expression::Type::FUNCTION, false, false); for (auto& e : formula.entries) { if (e.first.type != Expression::Type::FUNCTION) { continue; // should not happen @@ -56,7 +56,9 @@ void FormulaUtil::resolveSimpleFunctions(Formula& formula) { is_simple = false; } for (auto it : deps) { - if (it.first == f && formula.containsFunctionDef(it.second)) { + auto functions = formula.getDefinitions(); + if (it.first == f && std::find(functions.begin(), functions.end(), + it.second) != functions.end()) { is_simple = false; break; } @@ -195,17 +197,3 @@ void FormulaUtil::convertInitialTermsToIf(Formula& formula) { } } } - -Formula FormulaUtil::extractInitialTerms(Formula& formula) { - Formula initial_terms; - auto it = formula.entries.begin(); - while (it != formula.entries.end()) { - if (ExpressionUtil::isInitialTerm(it->first)) { - initial_terms.entries[it->first] = it->second; - it = formula.entries.erase(it); - } else { - it++; - } - } - return initial_terms; -} diff --git a/src/form/formula_util.hpp b/src/form/formula_util.hpp index 8fd4b7e5..26049ea4 100644 --- a/src/form/formula_util.hpp +++ b/src/form/formula_util.hpp @@ -11,6 +11,4 @@ class FormulaUtil { static void resolveSimpleRecursions(Formula& formula); static void convertInitialTermsToIf(Formula& formula); - - static Formula extractInitialTerms(Formula& formula); }; diff --git a/src/form/pari.cpp b/src/form/pari.cpp index f88cdbeb..48af8ac3 100644 --- a/src/form/pari.cpp +++ b/src/form/pari.cpp @@ -20,8 +20,10 @@ bool convertExprToPari(Expression& expr, const Formula& f, bool as_vector) { return false; } } + auto functions = f.getDefinitions(Expression::Type::FUNCTION); if (expr.type == Expression::Type::FUNCTION && as_vector && - f.containsFunctionDef(expr.name)) { + std::find(functions.begin(), functions.end(), expr.name) != + functions.end()) { expr.type = Expression::Type::VECTOR; } return true; @@ -29,7 +31,10 @@ bool convertExprToPari(Expression& expr, const Formula& f, bool as_vector) { void countFuncs(const Formula& f, const Expression& e, std::map& count) { - if (e.type == Expression::Type::FUNCTION && f.containsFunctionDef(e.name)) { + auto functions = f.getDefinitions(Expression::Type::FUNCTION); + if (e.type == Expression::Type::FUNCTION && + std::find(functions.begin(), functions.end(), e.name) != + functions.end()) { if (count.find(e) == count.end()) { count[e] = 1; } else { @@ -65,37 +70,85 @@ bool addLocalVars(Formula& f) { return changed; } -bool Pari::convertToPari(Formula& f, bool as_vector) { - Formula tmp; - for (auto& entry : f.entries) { +void PariFormula::extractInitialTerms() { + initial_terms.clear(); + auto it = main_formula.entries.begin(); + while (it != main_formula.entries.end()) { + if (ExpressionUtil::isInitialTerm(it->first)) { + auto key = it->first; + key.children.front().value += Number::ONE; + initial_terms.entries[key] = it->second; + max_initial_terms[key.name] = std::max( + max_initial_terms[key.name], key.children.front().value.asInt()); + it = main_formula.entries.erase(it); + } else { + it++; + } + } +} + +bool PariFormula::convert(const Formula& formula, bool as_vector, + PariFormula& pari_formula) { + pari_formula = PariFormula(); + pari_formula.as_vector = as_vector; + for (const auto& entry : formula.entries) { auto left = entry.first; auto right = entry.second; if (as_vector && left.type == Expression::Type::FUNCTION) { left.type = Expression::Type::VECTOR; } - if (!convertExprToPari(right, f, as_vector)) { + if (!convertExprToPari(right, formula, as_vector)) { return false; } - tmp.entries[left] = right; + pari_formula.main_formula.entries[left] = right; } - f = tmp; - if (!as_vector) { - addLocalVars(f); - FormulaUtil::convertInitialTermsToIf(f); + if (as_vector) { + pari_formula.extractInitialTerms(); + } else { + addLocalVars(pari_formula.main_formula); + FormulaUtil::convertInitialTermsToIf(pari_formula.main_formula); } return true; } -void Pari::printEvalCode(const Formula& f, int64_t start, int64_t end, - std::ostream& out, bool as_vector) { +std::string PariFormula::toString() const { if (as_vector) { - // TODO: print initial terms + return main_formula.toString("; ", false) + "; " + + initial_terms.toString("; ", false); } else { - out << f.toString("; ", true) << std::endl; + return main_formula.toString("; ", true); } - out << "for (n = " << start << ", " << end << ", "; +} + +void PariFormula::printEvalCode(int64_t numTerms, std::ostream& out) const { if (as_vector) { - // TODO: print function definition + // declare vectors + auto functions = main_formula.getDefinitions(Expression::Type::VECTOR); + for (const auto& f : functions) { + out << f << " = vector(" << numTerms << ")" << std::endl; + } + // initial terms only + out << initial_terms.toString("\n", false) << std::endl; + } else { + // main function + out << main_formula.toString("; ", true) << std::endl; + } + const int64_t start = as_vector ? 1 : 0; + const int64_t end = numTerms + start - 1; + out << "for(n=" << start << "," << end << ","; + if (as_vector) { + auto sorted = main_formula.getDefinitions(Expression::Type::VECTOR, true); + for (const auto& f : sorted) { + auto key = ExpressionUtil::newFunction(f); + key.type = Expression::Type::VECTOR; + if (max_initial_terms.find(f) != max_initial_terms.end()) { + out << "if(n>" << max_initial_terms.at(f) << ", "; + out << f << "[n] = " << main_formula.entries.at(key).toString() + << "); "; + } else { + out << f << "[n] = " << main_formula.entries.at(key).toString() << "; "; + } + } out << "print(a[n])"; } else { out << "print(a(n))"; @@ -103,8 +156,7 @@ void Pari::printEvalCode(const Formula& f, int64_t start, int64_t end, out << ")" << std::endl << "quit" << std::endl; } -Sequence Pari::eval(const Formula& f, int64_t start, int64_t end, - bool as_vector) { +Sequence PariFormula::eval(int64_t numTerms) const { const std::string gpPath("pari-loda.gp"); const std::string gpResult("pari-result.txt"); const int64_t maxparisize = 256; // in MB @@ -112,7 +164,7 @@ Sequence Pari::eval(const Formula& f, int64_t start, int64_t end, if (!gp) { throw std::runtime_error("error generating gp file"); } - printEvalCode(f, start, end, gp, as_vector); + printEvalCode(numTerms, gp); gp.close(); std::string cmd = "gp -s " + std::to_string(maxparisize) + "M -q " + gpPath + " > " + gpResult; diff --git a/src/form/pari.hpp b/src/form/pari.hpp index bcba6fa7..0b354643 100644 --- a/src/form/pari.hpp +++ b/src/form/pari.hpp @@ -14,13 +14,24 @@ * Generated PARI/GP code: * a(n) = if(n==0,1,n*a(n-1)) */ -class Pari { +class PariFormula { public: - static bool convertToPari(Formula& f, bool as_vector = false); + PariFormula() : as_vector(false){}; - static void printEvalCode(const Formula& f, int64_t start, int64_t end, - std::ostream& out, bool as_vector = false); + static bool convert(const Formula& formula, bool as_vector, + PariFormula& pari_formula); - static Sequence eval(const Formula& f, int64_t start, int64_t end, - bool as_vector = false); + void printEvalCode(int64_t numTerms, std::ostream& out) const; + + std::string toString() const; + + Sequence eval(int64_t numTerms) const; + + private: + Formula main_formula; + Formula initial_terms; + bool as_vector; + std::map max_initial_terms; + + void extractInitialTerms(); }; diff --git a/src/form/variant.cpp b/src/form/variant.cpp index 5103744a..95fb53a3 100644 --- a/src/form/variant.cpp +++ b/src/form/variant.cpp @@ -216,8 +216,10 @@ bool simplifyFormulaUsingVariants( } Formula copy = formula; copy.entries[entry.first] = variant.definition; - auto deps_old = formula.getFunctionDeps(true, true); - auto deps_new = copy.getFunctionDeps(true, true); + auto deps_old = + formula.getDependencies(Expression::Type::FUNCTION, true, true); + auto deps_new = + copy.getDependencies(Expression::Type::FUNCTION, true, true); if (deps_new.size() < deps_old.size()) { entry.second = variant.definition; num_initial_terms[entry.first.name] = variant.num_initial_terms; diff --git a/tests/formula/program-pari-recursive.txt b/tests/formula/program-pari-recursive.txt new file mode 100644 index 00000000..d321ad89 --- /dev/null +++ b/tests/formula/program-pari-recursive.txt @@ -0,0 +1,4 @@ +\\ A000058 +(a(n) = b(n)+1); (b(n) = if(n==0,1,local(l1=b(n-1)); l1*(l1+1))) +for(n=0,9,print(a(n))) +quit diff --git a/tests/formula/program-pari-vector.txt b/tests/formula/program-pari-vector.txt new file mode 100644 index 00000000..52686b0d --- /dev/null +++ b/tests/formula/program-pari-vector.txt @@ -0,0 +1,6 @@ +\\ A000058 +a = vector(10) +b = vector(10) +b[1] = 1 +for(n=1,10,if(n>1, b[n] = b[n-1]*(b[n-1]+1)); a[n] = b[n]+1; print(a[n])) +quit diff --git a/tests/programs/oeis/000/A000058.asm b/tests/programs/oeis/000/A000058.asm new file mode 100644 index 00000000..0c7a18de --- /dev/null +++ b/tests/programs/oeis/000/A000058.asm @@ -0,0 +1,12 @@ +; A000058: Sylvester's sequence: a(n+1) = a(n)^2 - a(n) + 1, with a(0) = 2. +; 2,3,7,43,1807,3263443,10650056950807,113423713055421844361000443,12864938683278671740537145998360961546653259485195807,165506647324519964198468195444439180017513152706377497841851388766535868639572406808911988131737645185443,27392450308603031423410234291674686281194364367580914627947367941608692026226993634332118404582438634929548737283992369758487974306317730580753883429460344956410077034761330476016739454649828385541500213920807 + +mov $1,1 +lpb $0 + sub $0,1 + mov $2,$1 + add $2,1 + mul $1,$2 +lpe +mov $0,$1 +add $0,1 From 70c039775cdecfda6c4c8ada5dd5fcf1792f545c Mon Sep 17 00:00:00 2001 From: Christian Krause Date: Fri, 21 Jun 2024 19:25:56 +0200 Subject: [PATCH 2/2] minor stuff --- src/cmd/commands.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cmd/commands.cpp b/src/cmd/commands.cpp index a960a85b..79b21fc5 100644 --- a/src/cmd/commands.cpp +++ b/src/cmd/commands.cpp @@ -224,14 +224,12 @@ void Commands::export_(const std::string& path) { throwConversionError(format); } std::cout << pari_formula.toString() << std::endl; - // pari_formula.printEvalCode(10, std::cout); } else if (format == "pari-vector") { if (!generator.generate(program, -1, formula, settings.with_deps) || !PariFormula::convert(formula, true, pari_formula)) { throwConversionError(format); } std::cout << pari_formula.toString() << std::endl; - // pari_formula.printEvalCode(10, std::cout); } else if (format == "loda") { ProgramUtil::print(program, std::cout); } else { @@ -466,7 +464,7 @@ void Commands::testPari(const std::string& test_id) { FormulaGenerator generator; Formula formula; PariFormula pari_formula; - const bool as_vector = true; + const bool as_vector = false; // TODO: switch to true Sequence expSeq; try { if (!generator.generate(program, id, formula, true) ||