diff --git a/include/sleipnir/autodiff/Variable.hpp b/include/sleipnir/autodiff/Variable.hpp index 0e3980915..32b423b0f 100644 --- a/include/sleipnir/autodiff/Variable.hpp +++ b/include/sleipnir/autodiff/Variable.hpp @@ -12,14 +12,17 @@ namespace sleipnir { +// Forward declarations for friend declarations in Variable +class SLEIPNIR_DLLEXPORT Jacobian; +namespace detail { +class SLEIPNIR_DLLEXPORT ExpressionGraph; +} // namespace detail + /** * An autodiff variable pointing to an expression node. */ class SLEIPNIR_DLLEXPORT Variable { public: - /// The expression node. - detail::ExpressionPtr expr = detail::Zero(); - /** * Constructs a Variable initialized to zero. */ @@ -168,6 +171,38 @@ class SLEIPNIR_DLLEXPORT Variable { * variables. */ void Update(); + + private: + /// The expression node. + detail::ExpressionPtr expr = detail::Zero(); + + friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y, + const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, + const Variable& y); + friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y, + const Variable& z); + friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base, + const Variable& power); + friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x); + + friend class SLEIPNIR_DLLEXPORT Jacobian; + friend class SLEIPNIR_DLLEXPORT detail::ExpressionGraph; }; using VectorXvar = Eigen::Vector; diff --git a/src/autodiff/Jacobian.cpp b/src/autodiff/Jacobian.cpp index 424a06573..eb40b87e4 100644 --- a/src/autodiff/Jacobian.cpp +++ b/src/autodiff/Jacobian.cpp @@ -20,14 +20,14 @@ Jacobian::Jacobian(VectorXvar variables, VectorXvar wrt) noexcept m_cachedTriplets.reserve(m_variables.rows() * m_wrt.rows() * 0.01); for (int row = 0; row < m_variables.rows(); ++row) { - if (m_variables(row).expr->type == ExpressionType::kLinear) { + if (m_variables(row).Type() == ExpressionType::kLinear) { // If the row is linear, compute its gradient once here and cache its // triplets. Constant rows are ignored because their gradients have no // nonzero triplets. m_graphs[row].ComputeAdjoints([&](int col, double adjoint) { m_cachedTriplets.emplace_back(row, col, adjoint); }); - } else if (m_variables(row).expr->type > ExpressionType::kLinear) { + } else if (m_variables(row).Type() > ExpressionType::kLinear) { // If the row is quadratic or nonlinear, add it to the list of nonlinear // rows to be recomputed in Calculate(). m_nonlinearRows.emplace_back(row); diff --git a/test/src/autodiff/VariableTest.cpp b/test/src/autodiff/VariableTest.cpp index e6f9b4353..8cc240f73 100644 --- a/test/src/autodiff/VariableTest.cpp +++ b/test/src/autodiff/VariableTest.cpp @@ -7,5 +7,5 @@ TEST(VariableTest, DefaultConstructor) { sleipnir::Variable a; EXPECT_EQ(0.0, a.Value()); - EXPECT_EQ(sleipnir::ExpressionType::kConstant, a.expr->type); + EXPECT_EQ(sleipnir::ExpressionType::kConstant, a.Type()); }