diff --git a/src/form/expression_util.cpp b/src/form/expression_util.cpp index adfda76b..78bf8f04 100644 --- a/src/form/expression_util.cpp +++ b/src/form/expression_util.cpp @@ -1,5 +1,7 @@ #include "form/expression_util.hpp" +#include "lang/semantics.hpp" + Expression ExpressionUtil::newConstant(int64_t value) { return Expression(Expression::Type::CONSTANT, "", Number(value)); } @@ -314,3 +316,55 @@ void ExpressionUtil::collectNames(const Expression& e, Expression::Type type, collectNames(*c, type, target); } } + +void assertNumChildren(const Expression& e, size_t num) { + if (e.children.size() != num) { + throw std::runtime_error("unexpected number of terms in " + e.toString()); + } +} + +Number ExpressionUtil::eval(const Expression& e, + const std::map params) { + switch (e.type) { + case Expression::Type::CONSTANT: { + return e.value; + } + case Expression::Type::PARAMETER: { + return params.at(e.name); + } + case Expression::Type::SUM: { + auto result = Number::ZERO; + for (auto c : e.children) { + result = Semantics::add(result, eval(*c, params)); + } + return result; + } + case Expression::Type::PRODUCT: { + auto result = Number::ONE; + for (auto c : e.children) { + result = Semantics::mul(result, eval(*c, params)); + } + return result; + } + case Expression::Type::FRACTION: { + assertNumChildren(e, 2); + auto a = eval(*e.children[0], params); + auto b = eval(*e.children[1], params); + return Semantics::div(a, b); + } + case Expression::Type::POWER: { + assertNumChildren(e, 2); + auto a = eval(*e.children[0], params); + auto b = eval(*e.children[1], params); + return Semantics::pow(a, b); + } + case Expression::Type::MODULUS: { + assertNumChildren(e, 2); + auto a = eval(*e.children[0], params); + auto b = eval(*e.children[1], params); + return Semantics::mod(a, b); + } + default: + throw std::runtime_error("cannot evaluate " + e.toString()); + } +} diff --git a/src/form/expression_util.hpp b/src/form/expression_util.hpp index f7018c3e..0a1b98a0 100644 --- a/src/form/expression_util.hpp +++ b/src/form/expression_util.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include "form/expression.hpp" @@ -20,4 +21,7 @@ class ExpressionUtil { static void collectNames(const Expression& e, Expression::Type type, std::set& target); + + static Number eval(const Expression& e, + const std::map params); };