Skip to content

Commit

Permalink
refactor: [sparse_weights] use std::map and get for predict
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang committed Oct 13, 2023
1 parent d0f6bc3 commit 261d47e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 8 deletions.
3 changes: 3 additions & 0 deletions vowpalwabbit/core/include/vw/core/array_parameters_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ class dense_parameters
inline const VW::weight& operator[](size_t i) const { return _begin.get()[i & _weight_mask]; }
inline VW::weight& operator[](size_t i) { return _begin.get()[i & _weight_mask]; }

// Read-only lookup: masks i into table bounds and returns the stored weight (same behavior as const operator[]).
inline const VW::weight& get(size_t i) const { return _begin.get()[i & _weight_mask]; }
// Mutable lookup: masks i into table bounds and returns the stored weight (same behavior as operator[]).
inline VW::weight& get(size_t i) { return _begin.get()[i & _weight_mask]; }

VW_ATTR(nodiscard) static dense_parameters shallow_copy(const dense_parameters& input);
VW_ATTR(nodiscard) static dense_parameters deep_copy(const dense_parameters& input);

Expand Down
11 changes: 7 additions & 4 deletions vowpalwabbit/core/include/vw/core/array_parameters_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>
#include <map>

namespace VW
{
Expand All @@ -20,7 +20,7 @@ class sparse_parameters;
namespace details
{

using weight_map = std::unordered_map<uint64_t, std::shared_ptr<VW::weight>>;
using weight_map = std::map<uint64_t, std::shared_ptr<VW::weight>>;

template <typename T>
class sparse_iterator
Expand Down Expand Up @@ -82,9 +82,11 @@ class sparse_parameters
const_iterator cend() const { return const_iterator(_map.end()); }

inline VW::weight& operator[](size_t i) { return *(get_or_default_and_get(i)); }

inline const VW::weight& operator[](size_t i) const { return *(get_or_default_and_get(i)); }

// Lookup-only accessor for the predict path: resolves through get_impl(), which never inserts into _map.
// (Dropped the stray semicolon after the function body; it is redundant and trips -Wextra-semi.)
inline VW::weight& get(size_t i) { return *(get_impl(i)); }
// Const lookup-only accessor: resolves through get_impl(), which never inserts into _map.
// (Dropped the stray semicolon after the function body; it is redundant and trips -Wextra-semi.)
inline const VW::weight& get(size_t i) const { return *(get_impl(i)); }

inline VW::weight& strided_index(size_t index) { return operator[](index << _stride_shift); }
inline const VW::weight& strided_index(size_t index) const { return operator[](index << _stride_shift); }

Expand All @@ -109,7 +111,7 @@ class sparse_parameters
void share(size_t /* length */);
#endif

private:
public:
// This must be mutable because the const operator[] must be able to initialize default weights to return.
mutable details::weight_map _map;
uint64_t _weight_mask; // (stride*(1 << num_bits) -1)
Expand All @@ -119,6 +121,7 @@ class sparse_parameters
// It is marked const so it can be used from both const and non-const operator[]
// The map itself is mutable to facilitate this
VW::weight* get_or_default_and_get(size_t i) const;
VW::weight* get_impl(size_t i) const;
};
} // namespace VW
using sparse_parameters VW_DEPRECATED("sparse_parameters moved into VW namespace") = VW::sparse_parameters;
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/gd_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ inline void foreach_feature(WeightsT& weights, const VW::features& fs, DataT& da
{
for (const auto& f : fs)
{
VW::weight& w = weights[(f.index() + offset)];
VW::weight& w = weights.get(f.index() + offset);
FuncT(dat, mult * f.value(), w);
}
}
Expand All @@ -46,7 +46,7 @@ template <class DataT, void (*FuncT)(DataT&, float, float), class WeightsT>
inline void foreach_feature(
const WeightsT& weights, const VW::features& fs, DataT& dat, uint64_t offset = 0, float mult = 1.)
{
for (const auto& f : fs) { FuncT(dat, mult * f.value(), weights[static_cast<size_t>(f.index() + offset)]); }
for (const auto& f : fs) { FuncT(dat, mult * f.value(), weights.get(static_cast<size_t>(f.index() + offset))); }
}

template <class DataT, class WeightOrIndexT, void (*FuncT)(DataT&, float, WeightOrIndexT),
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/interactions_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ const static VW::audit_strings EMPTY_AUDIT_STRINGS;
// Invokes FuncT with the mutable weight stored for ft_idx (resolved via the
// container's get(), so sparse weight sets are not grown by this call).
template <class DataT, void (*FuncT)(DataT&, const float, float&), class WeightsT>
inline void call_func_t(DataT& dat, WeightsT& weights, const float ft_value, const uint64_t ft_idx)
{
  auto& weight_ref = weights.get(ft_idx);
  FuncT(dat, ft_value, weight_ref);
}

// Const overload: FuncT receives the weight by value, so the container is
// only read (again via get(), which performs a pure lookup).
template <class DataT, void (*FuncT)(DataT&, const float, float), class WeightsT>
inline void call_func_t(DataT& dat, const WeightsT& weights, const float ft_value, const uint64_t ft_idx)
{
  const auto masked_idx = static_cast<size_t>(ft_idx);
  FuncT(dat, ft_value, weights.get(masked_idx));
}

template <class DataT, void (*FuncT)(DataT&, float, uint64_t), class WeightsT>
Expand Down
15 changes: 15 additions & 0 deletions vowpalwabbit/core/src/array_parameters_sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ VW::weight* VW::sparse_parameters::get_or_default_and_get(size_t i) const
return iter->second.get();
}

// Lookup-only counterpart of get_or_default_and_get(): a missing index does
// NOT insert an entry into _map, so pure reads (predict) never grow the model.
// Qualified as VW::sparse_parameters to match the sibling definitions in this file.
VW::weight* VW::sparse_parameters::get_impl(size_t i) const
{
  uint64_t index = i & _weight_mask;
  auto iter = _map.find(index);
  if (iter != _map.end()) { return iter->second.get(); }

  // Miss: hand back a default weight without touching the map. thread_local
  // (rather than static) avoids a data race between concurrently predicting
  // threads; the explicit reset avoids returning a value left over from a
  // previous _default_func call.
  // NOTE(review): the returned pointer aliases a per-thread buffer, so writes
  // through it are NOT persisted in _map — confirm callers only read it.
  thread_local VW::weight default_value = 0.0f;
  default_value = 0.0f;
  if (_default_func != nullptr) { _default_func(&default_value, index); }
  return &default_value;
}

VW::sparse_parameters::sparse_parameters(size_t length, uint32_t stride_shift)
: _weight_mask((length << stride_shift) - 1), _stride_shift(stride_shift), _default_func(nullptr)
{
Expand Down
1 change: 1 addition & 0 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class lazy_gaussian
{
public:
// Deterministic pseudo-Gaussian sample derived from the index (merand48 + Box-Muller); nothing is stored.
inline float operator[](uint64_t index) const { return VW::details::merand48_boxmuller(index); }
// Same computation as operator[]; added so lazy_gaussian satisfies the weights interface accessed via get().
inline float get(uint64_t index) const { return VW::details::merand48_boxmuller(index); }
};

inline void vec_add_with_norm(std::pair<float, float>& p, float fx, float fw)
Expand Down

0 comments on commit 261d47e

Please sign in to comment.