Skip to content

Commit

Permalink
Reserve space in VariableMatrix constructors (#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Jan 17, 2025
1 parent 1786e64 commit 59d55f1
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions include/sleipnir/autodiff/VariableMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
* @param rows The number of matrix rows.
*/
explicit VariableMatrix(int rows) : m_rows{rows}, m_cols{1} {
for (int row = 0; row < rows; ++row) {
m_storage.reserve(Rows());
for (int row = 0; row < Rows(); ++row) {
m_storage.emplace_back();
}
}
Expand All @@ -50,8 +51,9 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
* @param cols The number of matrix columns.
*/
VariableMatrix(int rows, int cols) : m_rows{rows}, m_cols{cols} {
for (int row = 0; row < rows; ++row) {
for (int col = 0; col < cols; ++col) {
m_storage.reserve(Rows() * Cols());
for (int row = 0; row < Rows(); ++row) {
for (int col = 0; col < Cols(); ++col) {
m_storage.emplace_back();
}
}
Expand Down Expand Up @@ -238,6 +240,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
*/
VariableMatrix(const VariableBlock<VariableMatrix>& values) // NOLINT
: m_rows{values.Rows()}, m_cols{values.Cols()} {
m_storage.reserve(Rows() * Cols());
for (int row = 0; row < Rows(); ++row) {
for (int col = 0; col < Cols(); ++col) {
m_storage.emplace_back(values(row, col));
Expand All @@ -252,6 +255,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
*/
VariableMatrix(const VariableBlock<const VariableMatrix>& values) // NOLINT
: m_rows{values.Rows()}, m_cols{values.Cols()} {
m_storage.reserve(Rows() * Cols());
for (int row = 0; row < Rows(); ++row) {
for (int col = 0; col < Cols(); ++col) {
m_storage.emplace_back(values(row, col));
Expand All @@ -266,6 +270,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
*/
explicit VariableMatrix(std::span<const Variable> values)
: m_rows{static_cast<int>(values.size())}, m_cols{1} {
m_storage.reserve(Rows() * Cols());
for (int row = 0; row < Rows(); ++row) {
for (int col = 0; col < Cols(); ++col) {
m_storage.emplace_back(values[row * Cols() + col]);
Expand All @@ -283,6 +288,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
VariableMatrix(std::span<const Variable> values, int rows, int cols)
: m_rows{rows}, m_cols{cols} {
Assert(static_cast<int>(values.size()) == Rows() * Cols());
m_storage.reserve(Rows() * Cols());
for (int row = 0; row < Rows(); ++row) {
for (int col = 0; col < Cols(); ++col) {
m_storage.emplace_back(values[row * Cols() + col]);
Expand Down

0 comments on commit 59d55f1

Please sign in to comment.