From 9f0f6c111384f2116613ea95d816fa43dc3fc1f3 Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Sat, 11 Jan 2025 14:48:54 -0800 Subject: [PATCH] Replace iterator loops with ranges --- include/sleipnir/autodiff/ExpressionGraph.hpp | 28 ++++++++----------- src/util/PrintIterationDiagnostics.hpp | 2 +- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/include/sleipnir/autodiff/ExpressionGraph.hpp b/include/sleipnir/autodiff/ExpressionGraph.hpp index 5e9be8de..038cceef 100644 --- a/include/sleipnir/autodiff/ExpressionGraph.hpp +++ b/include/sleipnir/autodiff/ExpressionGraph.hpp @@ -2,6 +2,7 @@ #pragma once +#include #include #include "sleipnir/autodiff/Expression.hpp" @@ -42,10 +43,10 @@ class ExpressionGraph { // Initialize the number of instances of each node in the tree // (Expression::duplications) while (!stack.empty()) { - auto currentNode = stack.back(); + auto node = stack.back(); stack.pop_back(); - for (auto&& arg : currentNode->args) { + for (auto&& arg : node->args) { // Only continue if the node is not a constant and hasn't already been // explored. if (arg != nullptr && arg->Type() != ExpressionType::kConstant) { @@ -62,19 +63,19 @@ class ExpressionGraph { stack.emplace_back(root.Get()); while (!stack.empty()) { - auto currentNode = stack.back(); + auto node = stack.back(); stack.pop_back(); // BFS lists sorted from parent to child. - m_rowList.emplace_back(currentNode->row); - m_adjointList.emplace_back(currentNode); - if (currentNode->args[0] != nullptr) { + m_rowList.emplace_back(node->row); + m_adjointList.emplace_back(node); + if (node->args[0] != nullptr) { // Constants (expressions with no arguments) are skipped because they // don't need to be updated - m_valueList.emplace_back(currentNode); + m_valueList.emplace_back(node); } - for (auto&& arg : currentNode->args) { + for (auto&& arg : node->args) { // Only add node if it's not a constant and doesn't already exist in the // tape. if (arg != nullptr && arg->Type() != ExpressionType::kConstant) { @@ -97,9 +98,7 @@ class ExpressionGraph { void Update() { // Traverse the BFS list backward from child to parent and update the value // of each node. - for (auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) { - auto& node = *it; - + for (auto& node : m_valueList | std::views::reverse) { auto& lhs = node->args[0]; auto& rhs = node->args[1]; @@ -136,9 +135,7 @@ class ExpressionGraph { // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1. if (m_adjointList.size() > 0) { m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0); - for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); - ++it) { - auto& node = *it; + for (auto& node : m_adjointList | std::views::drop(1)) { node->adjointExpr = MakeExpressionPtr(); } } @@ -194,8 +191,7 @@ class ExpressionGraph { void ComputeAdjoints(function_ref func) { // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1. m_adjointList[0]->adjoint = 1.0; - for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) { - auto& node = *it; + for (auto& node : m_adjointList | std::views::drop(1)) { node->adjoint = 0.0; } diff --git a/src/util/PrintIterationDiagnostics.hpp b/src/util/PrintIterationDiagnostics.hpp index 3234915f..d02c9791 100644 --- a/src/util/PrintIterationDiagnostics.hpp +++ b/src/util/PrintIterationDiagnostics.hpp @@ -69,7 +69,7 @@ void PrintIterationDiagnostics(int iterations, bool feasibilityRestoration, } constexpr const char* strs[] = {"⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"}; - for (const auto& digit : std::ranges::views::reverse(digits)) { + for (const auto& digit : digits | std::views::reverse) { reg += strs[digit]; }