Skip to content

Commit

Permalink
Make Variable::expr private
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Dec 14, 2023
1 parent 737b730 commit 33075cb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
41 changes: 38 additions & 3 deletions include/sleipnir/autodiff/Variable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@

namespace sleipnir {

// Forward declarations for friend declarations in Variable
class SLEIPNIR_DLLEXPORT Jacobian;
namespace detail {
class SLEIPNIR_DLLEXPORT ExpressionGraph;
} // namespace detail

/**
* An autodiff variable pointing to an expression node.
*/
class SLEIPNIR_DLLEXPORT Variable {
public:
/// The expression node.
detail::ExpressionPtr expr = detail::Zero();

/**
* Constructs a Variable initialized to zero.
*/
Expand Down Expand Up @@ -168,6 +171,38 @@ class SLEIPNIR_DLLEXPORT Variable {
* variables.
*/
void Update();

private:
/// The expression node.
detail::ExpressionPtr expr = detail::Zero();

friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y,
const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x,
const Variable& y);
friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y,
const Variable& z);
friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base,
const Variable& power);
friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x);
friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x);

friend class SLEIPNIR_DLLEXPORT Jacobian;
friend class SLEIPNIR_DLLEXPORT detail::ExpressionGraph;
};

using VectorXvar = Eigen::Vector<Variable, Eigen::Dynamic>;
Expand Down
4 changes: 2 additions & 2 deletions src/autodiff/Jacobian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ Jacobian::Jacobian(VectorXvar variables, VectorXvar wrt) noexcept
m_cachedTriplets.reserve(m_variables.rows() * m_wrt.rows() * 0.01);

for (int row = 0; row < m_variables.rows(); ++row) {
if (m_variables(row).expr->type == ExpressionType::kLinear) {
if (m_variables(row).Type() == ExpressionType::kLinear) {
// If the row is linear, compute its gradient once here and cache its
// triplets. Constant rows are ignored because their gradients have no
// nonzero triplets.
m_graphs[row].ComputeAdjoints([&](int col, double adjoint) {
m_cachedTriplets.emplace_back(row, col, adjoint);
});
} else if (m_variables(row).expr->type > ExpressionType::kLinear) {
} else if (m_variables(row).Type() > ExpressionType::kLinear) {
// If the row is quadratic or nonlinear, add it to the list of nonlinear
// rows to be recomputed in Calculate().
m_nonlinearRows.emplace_back(row);
Expand Down
2 changes: 1 addition & 1 deletion test/src/autodiff/VariableTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ TEST(VariableTest, DefaultConstructor) {
sleipnir::Variable a;

EXPECT_EQ(0.0, a.Value());
EXPECT_EQ(sleipnir::ExpressionType::kConstant, a.expr->type);
EXPECT_EQ(sleipnir::ExpressionType::kConstant, a.Type());
}

0 comments on commit 33075cb

Please sign in to comment.