Skip to content

Commit

Permalink
Merge pull request #95 from paulsengroup/perf/balance
Browse files Browse the repository at this point in the history
The regression was caused by calling MargsVector::operator()() inside a loop.
This operator now involves a non-negligible amount of computation when processing high-resolution matrices, so its result is now computed once before the loop.
Also remove an unnecessary local copy of MargsVector (marg_local) when processing chunked matrices.
  • Loading branch information
robomics authored Dec 23, 2023
2 parents bbc175f + c8f2e06 commit 456404b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,9 @@ inline void ICE::min_nnz_filtering(MargsVector& marg, const MatrixT& matrix,
nonstd::span<double> biases, std::size_t min_nnz,
BS::thread_pool* tpool) {
matrix.marginalize_nnz(marg, tpool);
const auto marg_ = marg();
for (std::size_t i = 0; i < biases.size(); ++i) {
if (marg()[i] < static_cast<double>(min_nnz)) {
if (marg_[i] < static_cast<double>(min_nnz)) {
biases[i] = 0;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,17 +370,10 @@ inline void SparseMatrixChunked::marginalize(MargsVector& marg, BS::thread_pool*
fs.exceptions(_fs.exceptions());
fs.open(_path, std::ios::in | std::ios::binary);
auto matrix = _matrix;
MargsVector marg_local(marg.size());
for (const auto offset : nonstd::span(_index).subspan(istart, iend - istart)) {
fs.seekg(offset);
matrix.deserialize(fs, *zstd_dctx);
matrix.marginalize(marg_local, nullptr, false);
}

for (std::size_t i = 0; i < marg_local.size(); ++i) {
if (marg_local[i] != 0) {
marg.add(i, marg_local[i]);
}
matrix.marginalize(marg, nullptr, false);
}
};

Expand Down Expand Up @@ -413,17 +406,11 @@ inline void SparseMatrixChunked::marginalize_nnz(MargsVector& marg, BS::thread_p
fs.exceptions(_fs.exceptions());
fs.open(_path, std::ios::in | std::ios::binary);
auto matrix = _matrix;
MargsVector marg_local(marg.size());
for (const auto offset : nonstd::span(_index).subspan(istart, iend - istart)) {
fs.seekg(offset);
matrix.deserialize(fs, *zstd_dctx);
matrix.marginalize_nnz(marg, nullptr, false);
}
for (std::size_t i = 0; i < marg_local.size(); ++i) {
if (marg_local[i] != 0) {
marg.add(i, marg_local[i]);
}
}
};

assert(!marg.empty());
Expand Down Expand Up @@ -457,16 +444,10 @@ inline void SparseMatrixChunked::times_outer_product_marg(MargsVector& marg,
fs.exceptions(_fs.exceptions());
fs.open(_path, std::ios::in | std::ios::binary);
auto matrix = _matrix;
MargsVector marg_local(marg.size());
for (const auto offset : nonstd::span(_index).subspan(istart, iend - istart)) {
fs.seekg(offset);
matrix.deserialize(fs, *zstd_dctx);
matrix.times_outer_product_marg(marg_local, biases, weights, nullptr, false);
}
for (std::size_t i = 0; i < marg.size(); ++i) {
if (marg_local[i] != 0) {
marg.add(i, marg_local[i]);
}
matrix.times_outer_product_marg(marg, biases, weights, nullptr, false);
}
};

Expand Down

0 comments on commit 456404b

Please sign in to comment.