Skip to content

Commit

Permalink
Document ExpressionGraph algorithm as topological sort (SleipnirGroup…
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Jan 21, 2025
1 parent 14076dc commit 1af6098
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 51 deletions.
37 changes: 17 additions & 20 deletions include/sleipnir/autodiff/AdjointExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ namespace sleipnir::detail {

/**
* This class is an adaptor type that performs value updates of an expression's
* adjoint graph in a way that skips duplicates.
* adjoint graph.
*/
class AdjointExpressionGraph {
public:
/**
* Generates the deduplicated adjoint graph for the given expression.
* Generates the adjoint graph for the given expression.
*
* @param root The root node of the expression.
*/
Expand All @@ -31,34 +31,33 @@ class AdjointExpressionGraph {
return;
}

// Breadth-first search (BFS) is used as opposed to a depth-first search
// (DFS) to avoid counting duplicate nodes multiple times. A list of nodes
// ordered from parent to child with no duplicates is generated.
//
// https://en.wikipedia.org/wiki/Breadth-first_search

// Stack of nodes to explore
small_vector<Expression*> stack;

// Assign each node's number of instances in the tree to
// Expression::duplications
// Enumerate incoming edges for each node via depth-first search
stack.emplace_back(root.expr.Get());
while (!stack.empty()) {
auto node = stack.back();
stack.pop_back();

for (auto& arg : node->args) {
if (arg != nullptr) {
// If this is the first instance of the node encountered (it hasn't
// been explored yet), add it to stack so it's recursed upon
if (arg->duplications == 0) {
// If the node hasn't been explored yet, add it to the stack
if (arg->incomingEdges == 0) {
stack.push_back(arg.Get());
}
++arg->duplications;
++arg->incomingEdges;
}
}
}

// Generate BFS lists sorted from parent to child
// Generate topological sort of graph from parent to child.
//
// A node is only added to the stack after all its incoming edges have been
// traversed. Expression::incomingEdges is a decrementing counter for
// tracking this.
//
// https://en.wikipedia.org/wiki/Topological_sorting
stack.emplace_back(root.expr.Get());
while (!stack.empty()) {
auto node = stack.back();
Expand All @@ -74,11 +73,9 @@ class AdjointExpressionGraph {

for (auto& arg : node->args) {
if (arg != nullptr) {
// Once the number of node visitations equals the number of
// duplications (the counter hits zero), add it to the stack. Note
// that this means the node is only enqueued once.
--arg->duplications;
if (arg->duplications == 0) {
// If we traversed all this node's incoming edges, add it to the stack
--arg->incomingEdges;
if (arg->incomingEdges == 0) {
stack.push_back(arg.Get());
}
}
Expand Down
5 changes: 2 additions & 3 deletions include/sleipnir/autodiff/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ struct Expression {
/// The adjoint of the expression node used during autodiff.
double adjoint = 0.0;

/// Tracks the number of instances of this expression yet to be encountered in
/// an expression tree.
uint32_t duplications = 0;
/// Counts incoming edges for this node.
uint32_t incomingEdges = 0;

/// This expression's column in a Jacobian, or -1 otherwise.
int32_t col = -1;
Expand Down
37 changes: 17 additions & 20 deletions include/sleipnir/autodiff/ValueExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace sleipnir::detail {

/**
* This class is an adaptor type that performs value updates of an expression's
* value graph in a way that skips duplicates.
* value graph.
*/
class ValueExpressionGraph {
public:
/**
* Generates the deduplicated value graph for the given expression.
* Generates the value graph for the given expression.
*
* @param root The root node of the expression.
*/
Expand All @@ -27,34 +27,33 @@ class ValueExpressionGraph {
return;
}

// Breadth-first search (BFS) is used as opposed to a depth-first search
// (DFS) to avoid counting duplicate nodes multiple times. A list of nodes
// ordered from parent to child with no duplicates is generated.
//
// https://en.wikipedia.org/wiki/Breadth-first_search

// Stack of nodes to explore
small_vector<Expression*> stack;

// Assign each node's number of instances in the tree to
// Expression::duplications
// Enumerate incoming edges for each node via depth-first search
stack.emplace_back(root.Get());
while (!stack.empty()) {
auto node = stack.back();
stack.pop_back();

for (auto& arg : node->args) {
if (arg != nullptr) {
// If this is the first instance of the node encountered (it hasn't
// been explored yet), add it to stack so it's recursed upon
if (arg->duplications == 0) {
// If the node hasn't been explored yet, add it to the stack
if (arg->incomingEdges == 0) {
stack.push_back(arg.Get());
}
++arg->duplications;
++arg->incomingEdges;
}
}
}

// Generate BFS lists sorted from parent to child
// Generate topological sort of graph from parent to child.
//
// A node is only added to the stack after all its incoming edges have been
// traversed. Expression::incomingEdges is a decrementing counter for
// tracking this.
//
// https://en.wikipedia.org/wiki/Topological_sorting
stack.emplace_back(root.Get());
while (!stack.empty()) {
auto node = stack.back();
Expand All @@ -68,11 +67,9 @@ class ValueExpressionGraph {

for (auto& arg : node->args) {
if (arg != nullptr) {
// Once the number of node visitations equals the number of
// duplications (the counter hits zero), add it to the stack. Note
// that this means the node is only enqueued once.
--arg->duplications;
if (arg->duplications == 0) {
// If we traversed all this node's incoming edges, add it to the stack
--arg->incomingEdges;
if (arg->incomingEdges == 0) {
stack.push_back(arg.Get());
}
}
Expand Down
14 changes: 6 additions & 8 deletions jormungandr/cpp/Docstrings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1952,12 +1952,12 @@ static const char *__doc_sleipnir_detail_AcosExpression_Value = R"doc()doc";

static const char *__doc_sleipnir_detail_AdjointExpressionGraph =
R"doc(This class is an adaptor type that performs value updates of an
expression's adjoint graph in a way that skips duplicates.)doc";
expression's adjoint graph.)doc";

static const char *__doc_sleipnir_detail_AdjointExpressionGraph_2 = R"doc()doc";

static const char *__doc_sleipnir_detail_AdjointExpressionGraph_AdjointExpressionGraph =
R"doc(Generates the deduplicated adjoint graph for the given expression.
R"doc(Generates the adjoint graph for the given expression.
Parameter ``root``:
The root node of the expression.)doc";
Expand Down Expand Up @@ -2318,9 +2318,7 @@ static const char *__doc_sleipnir_detail_Expression_args = R"doc(Expression argu

static const char *__doc_sleipnir_detail_Expression_col = R"doc(This expression's column in a Jacobian, or -1 otherwise.)doc";

static const char *__doc_sleipnir_detail_Expression_duplications =
R"doc(Tracks the number of instances of this expression yet to be
encountered in an expression tree.)doc";
static const char *__doc_sleipnir_detail_Expression_incomingEdges = R"doc(Counts incoming edges for this node.)doc";

static const char *__doc_sleipnir_detail_Expression_refCount = R"doc(Reference count for intrusive shared pointer.)doc";

Expand Down Expand Up @@ -2579,18 +2577,18 @@ static const char *__doc_sleipnir_detail_UnaryMinusExpression_Value = R"doc()doc

static const char *__doc_sleipnir_detail_ValueExpressionGraph =
R"doc(This class is an adaptor type that performs value updates of an
expression's value graph in a way that skips duplicates.)doc";
expression's value graph.)doc";

static const char *__doc_sleipnir_detail_ValueExpressionGraph_2 =
R"doc(This class is an adaptor type that performs value updates of an
expression's value graph in a way that skips duplicates.)doc";
expression's value graph.)doc";

static const char *__doc_sleipnir_detail_ValueExpressionGraph_Update =
R"doc(Update the values of all nodes in this value graph based on the values
of their dependent nodes.)doc";

static const char *__doc_sleipnir_detail_ValueExpressionGraph_ValueExpressionGraph =
R"doc(Generates the deduplicated value graph for the given expression.
R"doc(Generates the value graph for the given expression.
Parameter ``root``:
The root node of the expression.)doc";
Expand Down

0 comments on commit 1af6098

Please sign in to comment.