Skip to content

Commit

Permalink
Merge pull request #93 from paulsengroup/improve/balancing
Browse files Browse the repository at this point in the history
Improve ICE balancing [ci full]
  • Loading branch information
robomics authored Dec 22, 2023
2 parents f9e3e50 + 9164a59 commit b3aaf1e
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 105 deletions.
2 changes: 0 additions & 2 deletions conanfile.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ parallel-hashmap/1.3.11#1e67f4855a3f7cdeb977cc472113baf7
readerwriterqueue/1.0.6#aaa5ff6fac60c2aee591e9e51b063b83
span-lite/0.10.3#1967d71abb32b314387c2ab9c558dd22
spdlog/1.12.0#0e390a2f5c3e96671d0857bc734e4731
xxhash/0.8.2#03fd1c9a839b3f9cdf5ea9742c312187
zstd/1.5.5#b87dc3b185caa4b122979ac4ae8ef7e8

[generators]
Expand Down Expand Up @@ -73,4 +72,3 @@ highfive*:with_eigen=False
highfive*:with_opencv=False
highfive*:with_xtensor=False
spdlog*:header_only=True
xxhash*:utility=False
2 changes: 0 additions & 2 deletions src/libhictk/balancing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
find_package(bshoshany-thread-pool REQUIRED)
find_package(phmap REQUIRED)
find_package(span-lite REQUIRED)
find_package(xxHash REQUIRED)
find_package(zstd REQUIRED)

add_library(balancing INTERFACE)
Expand All @@ -28,7 +27,6 @@ target_link_system_libraries(
bshoshany-thread-pool::bshoshany-thread-pool
nonstd::span-lite
phmap
xxHash::xxhash
"zstd::libzstd_$<IF:$<BOOL:${BUILD_SHARED_LIBS}>,shared,static>")

target_compile_definitions(balancing INTERFACE span_FEATURE_MAKE_SPAN=1)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#pragma once

#include <xxhash.h>
#include <zstd.h>

#include <BS_thread_pool.hpp>
Expand All @@ -27,99 +26,79 @@

namespace hictk::balancing {

inline MargsVector::MargsVector(std::size_t size_)
: _margs(size_, 0), _mtxes(compute_number_of_mutexes(size_)) {}
// Construct a vector of size_ marginals, all initialized to zero.
// Values are stored as fixed-point integers scaled by _cfx = 10^(decimals - 1).
// The exponent is clamped to [0, 19]:
//  - decimals == 0 would otherwise wrap around in the unsigned subtraction and
//    produce an infinite scaling factor (whose conversion to an integer is undefined);
//  - 10^19 is the largest power of ten that fits in a std::uint64_t.
inline MargsVector::MargsVector(std::size_t size_, std::size_t decimals)
    : _margsi(size_), _margsd(size_), _cfx(1) {
  constexpr std::size_t max_exponent = 19;

  std::size_t exponent = 0;
  if (decimals > 0) {
    exponent = decimals - 1;
  }
  if (exponent > max_exponent) {
    exponent = max_exponent;
  }

  // Exact integer exponentiation: no floating-point round trip is involved.
  for (std::size_t k = 0; k < exponent; ++k) {
    _cfx *= 10;
  }

  fill(0);
}

inline MargsVector::MargsVector(const MargsVector& other)
: _margs(other._margs.begin(), other._margs.end()), _mtxes(other.size()) {}
: _margsi(other.size()), _margsd(other.size()), _cfx(other._cfx) {
for (std::size_t i = 0; i < size(); ++i) {
_margsi[i] = other._margsi[i].load();
}
}

inline MargsVector& MargsVector::operator=(const MargsVector& other) {
if (this == &other) {
return *this;
}

_margs = other._margs;
_mtxes = std::vector<std::mutex>{other.size()};
_margsi = std::vector<N>(other.size());
for (std::size_t i = 0; i < size(); ++i) {
_margsi[i] = other._margsi[i].load();
}
_margsd = other._margsd;
_cfx = other._cfx;

return *this;
}

inline double MargsVector::operator[](std::size_t i) const noexcept {
assert(i < size());
return _margs[i];
return decode(_margsi[i].load());
}

inline double& MargsVector::operator[](std::size_t i) noexcept {
inline void MargsVector::add(std::size_t i, double n) noexcept {
assert(i < size());
return _margs[i];
_margsi[i] += encode(n);
}

inline void MargsVector::add(std::size_t i, double n) noexcept {
assert(i < size());
[[maybe_unused]] const std::scoped_lock lck(_mtxes[get_mutex_idx(i)]);
_margs[i] += n;
// Decode every fixed-point marginal into the double-precision cache (_margsd)
// and return a reference to that cache.
inline const std::vector<double>& MargsVector::operator()() const noexcept {
  // resize() only reallocates _margsi, so the cache can be left with a stale size.
  // Bring it back in sync here instead of asserting on the mismatch (and, with
  // assertions disabled, writing past the end of _margsd).
  // NOTE: an allocation failure inside this noexcept function terminates the program.
  if (_margsd.size() != _margsi.size()) {
    _margsd.resize(_margsi.size());
  }

  for (std::size_t i = 0; i < size(); ++i) {
    _margsd[i] = (*this)[i];
  }
  return _margsd;
}

inline const std::vector<double>& MargsVector::operator()() const noexcept { return _margs; }
inline std::vector<double>& MargsVector::operator()() noexcept { return _margs; }
// Decode every fixed-point marginal into the double-precision cache (_margsd)
// and return a mutable reference to that cache.
// Writes made through the returned reference only affect the cache: the loop
// below overwrites it from _margsi on every call.
inline std::vector<double>& MargsVector::operator()() noexcept {
  // resize() only reallocates _margsi, so the cache can be left with a stale size.
  // Bring it back in sync here instead of asserting on the mismatch (and, with
  // assertions disabled, writing past the end of _margsd).
  // NOTE: an allocation failure inside this noexcept function terminates the program.
  if (_margsd.size() != _margsi.size()) {
    _margsd.resize(_margsi.size());
  }

  for (std::size_t i = 0; i < size(); ++i) {
    _margsd[i] = (*this)[i];
  }
  return _margsd;
}

inline void MargsVector::fill(double value) noexcept {
for (auto& n : _margsi) {
n = encode(value);
}
}

inline void MargsVector::fill(double n) noexcept { std::fill(_margs.begin(), _margs.end(), n); }
inline void MargsVector::resize(std::size_t size_) {
if (size_ != size()) {
_margs.resize(size_);
std::vector<std::mutex> v(size_);
std::swap(v, _mtxes);
_margsi = std::vector<N>(size_);
}
}

inline std::size_t MargsVector::size() const noexcept { return _margs.size(); }
// Number of marginals stored in the vector (i.e. the length of _margsi).
inline std::size_t MargsVector::size() const noexcept {
  const std::size_t num_margs = _margsi.size();
  return num_margs;
}
inline bool MargsVector::empty() const noexcept { return size() == 0; }

inline std::size_t MargsVector::compute_number_of_mutexes(std::size_t size) noexcept {
if (size == 0) {
return 0;
}
const auto nthreads = static_cast<std::size_t>(std::thread::hardware_concurrency());
// Clamping to 2-n is needed because get_pixel_mutex_idx expects the number of
// mutexes to be a multiple of 2
return next_pow2(std::clamp(size, std::size_t(2), 5000 * nthreads));
}

template <typename I, typename>
inline I MargsVector::next_pow2(I n) noexcept {
using ull = unsigned long long;
if constexpr (std::is_signed_v<I>) {
assert(n >= 0);
return conditional_static_cast<I>(next_pow2(static_cast<ull>(n)));
} else {
auto m = conditional_static_cast<ull>(n);
#ifndef __GNUC__
// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
--m;
m |= m >> 1;
m |= m >> 2;
m |= m >> 4;
m |= m >> 8;
m |= m >> 16;
m |= m >> 32;
return conditional_static_cast<I>(m + 1);
#else
// https://jameshfisher.com/2018/03/30/round-up-power-2/
// https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html

return conditional_static_cast<I>(
m <= 1 ? m
: std::uint64_t(1) << (std::uint64_t(64) - std::uint64_t(__builtin_clzll(m - 1))));
#endif
}
// Convert a floating-point value to its fixed-point representation:
// n is multiplied by the scaling factor _cfx and the product is converted to I,
// which discards the fractional part.
// NOTE(review): I is an unsigned type, so n is presumably expected to be
// non-negative - confirm against callers.
inline auto MargsVector::encode(double n) const noexcept -> I {
  const auto scale = static_cast<double>(_cfx);
  const double scaled = n * scale;
  return static_cast<I>(scaled);
}

inline std::size_t MargsVector::get_mutex_idx(std::size_t i) const noexcept {
assert(!_mtxes.empty());
assert(_mtxes.size() % 2 == 0);
i = XXH3_64bits(&i, sizeof(std::size_t));
// equivalent to i % _mtxes.size() when _mtxes.size() % 2 == 0
return i & (_mtxes.size() - 1);
// Convert a fixed-point value back to floating point by dividing it by the
// scaling factor _cfx (the inverse of the scaling applied by encode()).
inline double MargsVector::decode(I n) const noexcept {
  const auto value = static_cast<double>(n);
  const auto scale = static_cast<double>(_cfx);
  return value / scale;
}

// True when size() reports zero.
inline bool SparseMatrix::empty() const noexcept {
  const auto num_entries = size();
  return num_entries == 0;
}
Expand Down Expand Up @@ -256,14 +235,9 @@ inline void SparseMatrix::marginalize(MargsVector& marg, BS::thread_pool* tpool,
const auto i1 = _bin1_ids[i];
const auto i2 = _bin2_ids[i];

if (tpool) {
if (_counts[i] != 0) {
marg.add(i1, _counts[i]);
marg.add(i2, _counts[i]);
}
} else {
marg[i1] += _counts[i];
marg[i2] += _counts[i];
if (_counts[i] != 0) {
marg.add(i1, _counts[i]);
marg.add(i2, _counts[i]);
}
}
};
Expand All @@ -288,14 +262,9 @@ inline void SparseMatrix::marginalize_nnz(MargsVector& marg, BS::thread_pool* tp
const auto i1 = _bin1_ids[i];
const auto i2 = _bin2_ids[i];

if (tpool) {
if (_counts[i] != 0) {
marg.add(i1, _counts[i] != 0);
marg.add(i2, _counts[i] != 0);
}
} else {
marg[i1] += _counts[i] != 0;
marg[i2] += _counts[i] != 0;
if (_counts[i] != 0) {
marg.add(i1, _counts[i] != 0);
marg.add(i2, _counts[i] != 0);
}
}
};
Expand Down Expand Up @@ -328,14 +297,9 @@ inline void SparseMatrix::times_outer_product_marg(MargsVector& marg,
const auto w2 = weights.empty() ? 1 : weights[i2];
const auto count = _counts[i] * (w1 * biases[i1]) * (w2 * biases[i2]);

if (tpool) {
if (count != 0) {
marg.add(i1, count);
marg.add(i2, count);
}
} else {
marg[i1] += count;
marg[i2] += count;
if (count != 0) {
marg.add(i1, count);
marg.add(i2, count);
}
}
};
Expand Down
23 changes: 12 additions & 11 deletions src/libhictk/balancing/include/hictk/balancing/sparse_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
#include <zstd.h>

#include <BS_thread_pool.hpp>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <ios>
#include <memory>
#include <mutex>
#include <nonstd/span.hpp>
#include <string>
#include <type_traits>
Expand All @@ -36,12 +36,16 @@ struct default_delete<ZSTD_DCtx_s> {
namespace hictk::balancing {

class MargsVector {
std::vector<double> _margs{};
mutable std::vector<std::mutex> _mtxes;
using I = std::uint64_t;
using N = std::atomic<I>;
std::vector<N> _margsi{};
mutable std::vector<double> _margsd{};
std::uint64_t _cfx{};
const static auto DEFAULT_DECIMAL_DIGITS = 9ULL;

public:
MargsVector() = default;
explicit MargsVector(std::size_t size_);
MargsVector() = delete;
explicit MargsVector(std::size_t size_ = 0, std::size_t decimals = DEFAULT_DECIMAL_DIGITS);

MargsVector(const MargsVector& other);
MargsVector(MargsVector&& other) noexcept = default;
Expand All @@ -52,23 +56,20 @@ class MargsVector {
MargsVector& operator=(MargsVector&& other) noexcept = default;

[[nodiscard]] double operator[](std::size_t i) const noexcept;
[[nodiscard]] double& operator[](std::size_t i) noexcept;
void add(std::size_t i, double n) noexcept;

[[nodiscard]] const std::vector<double>& operator()() const noexcept;
[[nodiscard]] std::vector<double>& operator()() noexcept;

void fill(double n = 0) noexcept;
void fill(double value = 0) noexcept;
void resize(std::size_t size_);

[[nodiscard]] std::size_t size() const noexcept;
[[nodiscard]] bool empty() const noexcept;

private:
static std::size_t compute_number_of_mutexes(std::size_t size) noexcept;
template <typename I, typename = std::enable_if_t<std::is_integral_v<I>>>
[[nodiscard]] static I next_pow2(I n) noexcept;
[[nodiscard]] std::size_t get_mutex_idx(std::size_t i) const noexcept;
auto encode(double n) const noexcept -> I;
double decode(I n) const noexcept;
};

class SparseMatrix {
Expand Down
4 changes: 2 additions & 2 deletions test/units/balancing/balancing_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ namespace hictk::test::balancing {
}

static void compare_weights(const std::vector<double>& weights, const std::vector<double>& expected,
double tol = 1.0e-6) {
double tol = 1.0e-5) {
REQUIRE(weights.size() == expected.size());

for (std::size_t i = 0; i < weights.size(); ++i) {
if (std::isnan(weights[i])) {
CHECK(std::isnan(expected[i]));
} else {
CHECK_THAT(weights[i], Catch::Matchers::WithinAbs(expected[i], tol));
CHECK_THAT(weights[i], Catch::Matchers::WithinRel(expected[i], tol));
}
}
}
Expand Down

0 comments on commit b3aaf1e

Please sign in to comment.