diff --git a/include/sleipnir/autodiff/VariableMatrix.hpp b/include/sleipnir/autodiff/VariableMatrix.hpp index 329ce60d..6f8c30d4 100644 --- a/include/sleipnir/autodiff/VariableMatrix.hpp +++ b/include/sleipnir/autodiff/VariableMatrix.hpp @@ -38,7 +38,8 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { * @param rows The number of matrix rows. */ explicit VariableMatrix(int rows) : m_rows{rows}, m_cols{1} { - for (int row = 0; row < rows; ++row) { + m_storage.reserve(Rows()); + for (int row = 0; row < Rows(); ++row) { m_storage.emplace_back(); } } @@ -50,8 +51,9 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { * @param cols The number of matrix columns. */ VariableMatrix(int rows, int cols) : m_rows{rows}, m_cols{cols} { - for (int row = 0; row < rows; ++row) { - for (int col = 0; col < cols; ++col) { + m_storage.reserve(Rows() * Cols()); + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { m_storage.emplace_back(); } } @@ -238,6 +240,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { */ VariableMatrix(const VariableBlock& values) // NOLINT : m_rows{values.Rows()}, m_cols{values.Cols()} { + m_storage.reserve(Rows() * Cols()); for (int row = 0; row < Rows(); ++row) { for (int col = 0; col < Cols(); ++col) { m_storage.emplace_back(values(row, col)); @@ -252,6 +255,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { */ VariableMatrix(const VariableBlock& values) // NOLINT : m_rows{values.Rows()}, m_cols{values.Cols()} { + m_storage.reserve(Rows() * Cols()); for (int row = 0; row < Rows(); ++row) { for (int col = 0; col < Cols(); ++col) { m_storage.emplace_back(values(row, col)); @@ -266,6 +270,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { */ explicit VariableMatrix(std::span values) : m_rows{static_cast(values.size())}, m_cols{1} { + m_storage.reserve(Rows() * Cols()); for (int row = 0; row < Rows(); ++row) { for (int col = 0; col < Cols(); ++col) { m_storage.emplace_back(values[row * Cols() + col]); @@ -283,6 +288,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { VariableMatrix(std::span values, int rows, int cols) : m_rows{rows}, m_cols{cols} { Assert(static_cast(values.size()) == Rows() * Cols()); + m_storage.reserve(Rows() * Cols()); for (int row = 0; row < Rows(); ++row) { for (int col = 0; col < Cols(); ++col) { m_storage.emplace_back(values[row * Cols() + col]);