Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pari vector evaluation #317

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/cmd/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ void Commands::export_(const std::string& path) {
Program program = OeisProgram::getProgramAndSeqId(path).first;
const auto& format = settings.export_format;
Formula formula;
PariFormula pari_formula;
FormulaGenerator generator;
if (format.empty() || format == "formula") {
if (!generator.generate(program, -1, formula, settings.with_deps)) {
Expand All @@ -219,16 +220,16 @@ void Commands::export_(const std::string& path) {
std::cout << formula.toString() << std::endl;
} else if (format == "pari") {
if (!generator.generate(program, -1, formula, settings.with_deps) ||
!Pari::convertToPari(formula, false)) {
!PariFormula::convert(formula, false, pari_formula)) {
throwConversionError(format);
}
std::cout << formula.toString("; ", true) << std::endl;
std::cout << pari_formula.toString() << std::endl;
} else if (format == "pari-vector") {
if (!generator.generate(program, -1, formula, settings.with_deps) ||
!Pari::convertToPari(formula, true)) {
!PariFormula::convert(formula, true, pari_formula)) {
throwConversionError(format);
}
std::cout << formula.toString("; ", true) << std::endl;
std::cout << pari_formula.toString() << std::endl;
} else if (format == "loda") {
ProgramUtil::print(program, std::cout);
} else {
Expand Down Expand Up @@ -462,10 +463,12 @@ void Commands::testPari(const std::string& test_id) {
// generate PARI code
FormulaGenerator generator;
Formula formula;
PariFormula pari_formula;
const bool as_vector = false; // TODO: switch to true
Sequence expSeq;
try {
if (!generator.generate(program, id, formula, true) ||
!Pari::convertToPari(formula)) {
!PariFormula::convert(formula, as_vector, pari_formula)) {
continue;
}
} catch (const std::exception& e) {
Expand All @@ -483,7 +486,6 @@ void Commands::testPari(const std::string& test_id) {
true);
}
}
auto pariCode = formula.toString("; ", true);

// determine number of terms for testing
size_t numTerms = seq.existingNumTerms();
Expand Down Expand Up @@ -513,7 +515,7 @@ void Commands::testPari(const std::string& test_id) {
}
}
Log::get().info("Checking " + std::to_string(numTerms) + " terms of " +
seq.id_str() + ": " + pariCode);
seq.id_str() + ": " + pari_formula.toString());

if (numTerms == 0) {
Log::get().warn("Skipping " + seq.id_str());
Expand All @@ -532,7 +534,7 @@ void Commands::testPari(const std::string& test_id) {
}

// evaluate PARI program
auto genSeq = Pari::eval(formula, 0, numTerms - 1);
auto genSeq = pari_formula.eval(numTerms);

// compare results
if (genSeq != expSeq) {
Expand Down
67 changes: 67 additions & 0 deletions src/cmd/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stdexcept>

#include "form/formula_gen.hpp"
#include "form/pari.hpp"
#include "lang/big_number.hpp"
#include "lang/comments.hpp"
#include "lang/evaluator.hpp"
Expand Down Expand Up @@ -60,6 +61,7 @@ void Test::all() {
checkpoint();
knownPrograms();
formula();
pariEval();

// slow tests
number();
Expand Down Expand Up @@ -907,6 +909,71 @@ void Test::formula() {
}
}

void Test::pariEval() {
auto base_path =
std::string("tests") + FILE_SEP + std::string("formula") + FILE_SEP;
testPariEval(base_path + "program-pari-recursive.txt", false);
testPariEval(base_path + "program-pari-vector.txt", true);
}

void testPariEvalCode(const std::string& seq_id,
const std::string& expected_eval_code, bool asVector) {
Log::get().info("Testing PARI/GP " +
std::string((asVector ? "vector" : "recursive")) +
" code for " + seq_id);
Parser parser;
FormulaGenerator generator;
OeisSequence seq(seq_id);
auto program = parser.parse(seq.getProgramPath());
Formula f;
PariFormula pari;
if (!generator.generate(program, seq.id, f, true)) {
Log::get().error("Cannot generate formula from program", true);
}
if (!PariFormula::convert(f, asVector, pari)) {
Log::get().error("Cannot convert formula to PARI/GP", true);
}
std::stringstream eval_code;
pari.printEvalCode(10, eval_code);
if (eval_code.str() != expected_eval_code) {
Log::get().error("Unexpected PARI/GP code: " + eval_code.str(), true);
}
}

void Test::testPariEval(const std::string& testFile, bool asVector) {
std::ifstream file(testFile);
if (!file.is_open()) {
Log::get().error("Cannot open test file: " + testFile, true);
}
std::string line, seq_id, expected_eval_code;
Parser parser;
FormulaGenerator generator;
size_t num_tests = 0;
while (std::getline(file, line)) {
if (line.empty()) {
if (!seq_id.empty()) {
testPariEvalCode(seq_id, expected_eval_code, asVector);
seq_id.clear();
expected_eval_code.clear();
num_tests++;
}
} else if (line.substr(0, 2) == "\\\\") {
seq_id = line.substr(2);
trimString(seq_id);
expected_eval_code.clear();
} else {
expected_eval_code += line + "\n";
}
}
if (!seq_id.empty()) {
testPariEvalCode(seq_id, expected_eval_code, asVector);
num_tests++;
}
if (num_tests == 0) {
Log::get().error("No tests found in file: " + testFile, true);
}
}

void Test::stats() {
Log::get().info("Testing stats loading and saving");

Expand Down
4 changes: 4 additions & 0 deletions src/cmd/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class Test {

void formula();

void pariEval();

private:
std::vector<std::pair<Program, Program>> loadInOutTests(
const std::string &prefix);
Expand All @@ -81,6 +83,8 @@ class Test {

void testMatcherPair(Matcher &matcher, size_t id1, size_t id2);

void testPariEval(const std::string &testFile, bool asVector);

OeisManager &getManager();

Settings settings;
Expand Down
14 changes: 0 additions & 14 deletions src/form/formula_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,3 @@ void FormulaUtil::convertInitialTermsToIf(Formula& formula) {
}
}
}

Formula FormulaUtil::extractInitialTerms(Formula& formula) {
Formula initial_terms;
auto it = formula.entries.begin();
while (it != formula.entries.end()) {
if (ExpressionUtil::isInitialTerm(it->first)) {
initial_terms.entries[it->first] = it->second;
it = formula.entries.erase(it);
} else {
it++;
}
}
return initial_terms;
}
2 changes: 0 additions & 2 deletions src/form/formula_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,4 @@ class FormulaUtil {
static void resolveSimpleRecursions(Formula& formula);

static void convertInitialTermsToIf(Formula& formula);

static Formula extractInitialTerms(Formula& formula);
};
83 changes: 65 additions & 18 deletions src/form/pari.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,54 +70,101 @@ bool addLocalVars(Formula& f) {
return changed;
}

bool Pari::convertToPari(Formula& f, bool as_vector) {
Formula tmp;
for (auto& entry : f.entries) {
void PariFormula::extractInitialTerms() {
initial_terms.clear();
auto it = main_formula.entries.begin();
while (it != main_formula.entries.end()) {
if (ExpressionUtil::isInitialTerm(it->first)) {
auto key = it->first;
key.children.front().value += Number::ONE;
initial_terms.entries[key] = it->second;
max_initial_terms[key.name] = std::max(
max_initial_terms[key.name], key.children.front().value.asInt());
it = main_formula.entries.erase(it);
} else {
it++;
}
}
}

bool PariFormula::convert(const Formula& formula, bool as_vector,
PariFormula& pari_formula) {
pari_formula = PariFormula();
pari_formula.as_vector = as_vector;
for (const auto& entry : formula.entries) {
auto left = entry.first;
auto right = entry.second;
if (as_vector && left.type == Expression::Type::FUNCTION) {
left.type = Expression::Type::VECTOR;
}
if (!convertExprToPari(right, f, as_vector)) {
if (!convertExprToPari(right, formula, as_vector)) {
return false;
}
tmp.entries[left] = right;
pari_formula.main_formula.entries[left] = right;
}
f = tmp;
if (!as_vector) {
addLocalVars(f);
FormulaUtil::convertInitialTermsToIf(f);
if (as_vector) {
pari_formula.extractInitialTerms();
} else {
addLocalVars(pari_formula.main_formula);
FormulaUtil::convertInitialTermsToIf(pari_formula.main_formula);
}
return true;
}

void Pari::printEvalCode(const Formula& f, int64_t start, int64_t end,
std::ostream& out, bool as_vector) {
std::string PariFormula::toString() const {
if (as_vector) {
// TODO: print initial terms
return main_formula.toString("; ", false) + "; " +
initial_terms.toString("; ", false);
} else {
out << f.toString("; ", true) << std::endl;
return main_formula.toString("; ", true);
}
out << "for (n = " << start << ", " << end << ", ";
}

void PariFormula::printEvalCode(int64_t numTerms, std::ostream& out) const {
if (as_vector) {
// TODO: print function definition
// declare vectors
auto functions = main_formula.getDefinitions(Expression::Type::VECTOR);
for (const auto& f : functions) {
out << f << " = vector(" << numTerms << ")" << std::endl;
}
// initial terms only
out << initial_terms.toString("\n", false) << std::endl;
} else {
// main function
out << main_formula.toString("; ", true) << std::endl;
}
const int64_t start = as_vector ? 1 : 0;
const int64_t end = numTerms + start - 1;
out << "for(n=" << start << "," << end << ",";
if (as_vector) {
auto sorted = main_formula.getDefinitions(Expression::Type::VECTOR, true);
for (const auto& f : sorted) {
auto key = ExpressionUtil::newFunction(f);
key.type = Expression::Type::VECTOR;
if (max_initial_terms.find(f) != max_initial_terms.end()) {
out << "if(n>" << max_initial_terms.at(f) << ", ";
out << f << "[n] = " << main_formula.entries.at(key).toString()
<< "); ";
} else {
out << f << "[n] = " << main_formula.entries.at(key).toString() << "; ";
}
}
out << "print(a[n])";
} else {
out << "print(a(n))";
}
out << ")" << std::endl << "quit" << std::endl;
}

Sequence Pari::eval(const Formula& f, int64_t start, int64_t end,
bool as_vector) {
Sequence PariFormula::eval(int64_t numTerms) const {
const std::string gpPath("pari-loda.gp");
const std::string gpResult("pari-result.txt");
const int64_t maxparisize = 256; // in MB
std::ofstream gp(gpPath);
if (!gp) {
throw std::runtime_error("error generating gp file");
}
printEvalCode(f, start, end, gp, as_vector);
printEvalCode(numTerms, gp);
gp.close();
std::string cmd = "gp -s " + std::to_string(maxparisize) + "M -q " + gpPath +
" > " + gpResult;
Expand Down
23 changes: 17 additions & 6 deletions src/form/pari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@
* Generated PARI/GP code:
* a(n) = if(n==0,1,n*a(n-1))
*/
class Pari {
class PariFormula {
public:
static bool convertToPari(Formula& f, bool as_vector = false);
PariFormula() : as_vector(false){};

static void printEvalCode(const Formula& f, int64_t start, int64_t end,
std::ostream& out, bool as_vector = false);
static bool convert(const Formula& formula, bool as_vector,
PariFormula& pari_formula);

static Sequence eval(const Formula& f, int64_t start, int64_t end,
bool as_vector = false);
void printEvalCode(int64_t numTerms, std::ostream& out) const;

std::string toString() const;

Sequence eval(int64_t numTerms) const;

private:
Formula main_formula;
Formula initial_terms;
bool as_vector;
std::map<std::string, int64_t> max_initial_terms;

void extractInitialTerms();
};
4 changes: 4 additions & 0 deletions tests/formula/program-pari-recursive.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
\\ A000058
(a(n) = b(n)+1); (b(n) = if(n==0,1,local(l1=b(n-1)); l1*(l1+1)))
for(n=0,9,print(a(n)))
quit
6 changes: 6 additions & 0 deletions tests/formula/program-pari-vector.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
\\ A000058
a = vector(10)
b = vector(10)
b[1] = 1
for(n=1,10,if(n>1, b[n] = b[n-1]*(b[n-1]+1)); a[n] = b[n]+1; print(a[n]))
quit
Loading