Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace iterator loops with ranges #682

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions include/sleipnir/autodiff/ExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include <ranges>
#include <span>

#include "sleipnir/autodiff/Expression.hpp"
Expand Down Expand Up @@ -42,10 +43,10 @@ class ExpressionGraph {
// Initialize the number of instances of each node in the tree
// (Expression::duplications)
while (!stack.empty()) {
auto currentNode = stack.back();
auto node = stack.back();
stack.pop_back();

for (auto&& arg : currentNode->args) {
for (auto&& arg : node->args) {
// Only continue if the node is not a constant and hasn't already been
// explored.
if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
Expand All @@ -62,19 +63,19 @@ class ExpressionGraph {
stack.emplace_back(root.Get());

while (!stack.empty()) {
auto currentNode = stack.back();
auto node = stack.back();
stack.pop_back();

// BFS lists sorted from parent to child.
m_rowList.emplace_back(currentNode->row);
m_adjointList.emplace_back(currentNode);
if (currentNode->args[0] != nullptr) {
m_rowList.emplace_back(node->row);
m_adjointList.emplace_back(node);
if (node->args[0] != nullptr) {
// Constants (expressions with no arguments) are skipped because they
// don't need to be updated
m_valueList.emplace_back(currentNode);
m_valueList.emplace_back(node);
}

for (auto&& arg : currentNode->args) {
for (auto&& arg : node->args) {
// Only add node if it's not a constant and doesn't already exist in the
// tape.
if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
Expand All @@ -97,9 +98,7 @@ class ExpressionGraph {
void Update() {
// Traverse the BFS list backward from child to parent and update the value
// of each node.
for (auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) {
auto& node = *it;

for (auto& node : m_valueList | std::views::reverse) {
auto& lhs = node->args[0];
auto& rhs = node->args[1];

Expand Down Expand Up @@ -136,9 +135,7 @@ class ExpressionGraph {
// Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
if (m_adjointList.size() > 0) {
m_adjointList[0]->adjointExpr = MakeExpressionPtr<ConstExpression>(1.0);
for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end();
++it) {
auto& node = *it;
for (auto& node : m_adjointList | std::views::drop(1)) {
node->adjointExpr = MakeExpressionPtr<ConstExpression>();
}
}
Expand Down Expand Up @@ -194,8 +191,7 @@ class ExpressionGraph {
void ComputeAdjoints(function_ref<void(int row, double adjoint)> func) {
// Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
m_adjointList[0]->adjoint = 1.0;
for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) {
auto& node = *it;
for (auto& node : m_adjointList | std::views::drop(1)) {
node->adjoint = 0.0;
}

Expand Down
2 changes: 1 addition & 1 deletion src/util/PrintIterationDiagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void PrintIterationDiagnostics(int iterations, bool feasibilityRestoration,
}
constexpr const char* strs[] = {"⁰", "¹", "²", "³", "⁴",
"⁵", "⁶", "⁷", "⁸", "⁹"};
for (const auto& digit : std::ranges::views::reverse(digits)) {
for (const auto& digit : digits | std::views::reverse) {
reg += strs[digit];
}

Expand Down