Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: [sparse_weights] get for predict #4651

Merged
merged 25 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vowpalwabbit/core/include/vw/core/array_parameters_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ class dense_parameters
inline const VW::weight& operator[](size_t i) const { return _begin.get()[i & _weight_mask]; }
inline VW::weight& operator[](size_t i) { return _begin.get()[i & _weight_mask]; }

inline const VW::weight& get(size_t i) const { return _begin.get()[i & _weight_mask]; }
inline VW::weight& get(size_t i) { return _begin.get()[i & _weight_mask]; }

VW_ATTR(nodiscard) static dense_parameters shallow_copy(const dense_parameters& input);
VW_ATTR(nodiscard) static dense_parameters deep_copy(const dense_parameters& input);

Expand Down
9 changes: 6 additions & 3 deletions vowpalwabbit/core/include/vw/core/array_parameters_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>

namespace VW
{
Expand All @@ -20,7 +20,7 @@ class sparse_parameters;
namespace details
{

using weight_map = std::unordered_map<uint64_t, std::shared_ptr<VW::weight>>;
using weight_map = std::map<uint64_t, std::shared_ptr<VW::weight>>;
bassmang marked this conversation as resolved.
Show resolved Hide resolved

template <typename T>
class sparse_iterator
Expand Down Expand Up @@ -82,9 +82,11 @@ class sparse_parameters
const_iterator cend() const { return const_iterator(_map.end()); }

inline VW::weight& operator[](size_t i) { return *(get_or_default_and_get(i)); }

inline const VW::weight& operator[](size_t i) const { return *(get_or_default_and_get(i)); }

inline VW::weight& get(size_t i) { return *(get_impl(i)); };
inline const VW::weight& get(size_t i) const { return *(get_impl(i)); };

inline VW::weight& strided_index(size_t index) { return operator[](index << _stride_shift); }
inline const VW::weight& strided_index(size_t index) const { return operator[](index << _stride_shift); }

Expand Down Expand Up @@ -119,6 +121,7 @@ class sparse_parameters
// It is marked const so it can be used from both const and non const operator[]
// The map itself is mutable to facilitate this
VW::weight* get_or_default_and_get(size_t i) const;
VW::weight* get_impl(size_t i) const;
};
} // namespace VW
using sparse_parameters VW_DEPRECATED("sparse_parameters moved into VW namespace") = VW::sparse_parameters;
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/gd_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ inline void foreach_feature(WeightsT& weights, const VW::features& fs, DataT& da
{
for (const auto& f : fs)
{
VW::weight& w = weights[(f.index() + offset)];
VW::weight& w = weights.get(f.index() + offset);
FuncT(dat, mult * f.value(), w);
}
}
Expand All @@ -46,7 +46,7 @@ template <class DataT, void (*FuncT)(DataT&, float, float), class WeightsT>
inline void foreach_feature(
const WeightsT& weights, const VW::features& fs, DataT& dat, uint64_t offset = 0, float mult = 1.)
{
for (const auto& f : fs) { FuncT(dat, mult * f.value(), weights[static_cast<size_t>(f.index() + offset)]); }
for (const auto& f : fs) { FuncT(dat, mult * f.value(), weights.get(static_cast<size_t>(f.index() + offset))); }
}

template <class DataT, class WeightOrIndexT, void (*FuncT)(DataT&, float, WeightOrIndexT),
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/interactions_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ const static VW::audit_strings EMPTY_AUDIT_STRINGS;
template <class DataT, void (*FuncT)(DataT&, const float, float&), class WeightsT>
inline void call_func_t(DataT& dat, WeightsT& weights, const float ft_value, const uint64_t ft_idx)
{
FuncT(dat, ft_value, weights[ft_idx]);
FuncT(dat, ft_value, weights.get(ft_idx));
}

template <class DataT, void (*FuncT)(DataT&, const float, float), class WeightsT>
inline void call_func_t(DataT& dat, const WeightsT& weights, const float ft_value, const uint64_t ft_idx)
{
FuncT(dat, ft_value, weights[static_cast<size_t>(ft_idx)]);
FuncT(dat, ft_value, weights.get(static_cast<size_t>(ft_idx)));
}

template <class DataT, void (*FuncT)(DataT&, float, uint64_t), class WeightsT>
Expand Down
15 changes: 15 additions & 0 deletions vowpalwabbit/core/src/array_parameters_sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ VW::weight* VW::sparse_parameters::get_or_default_and_get(size_t i) const
return iter->second.get();
}

VW::weight* VW::sparse_parameters::get_impl(size_t i) const
{
static VW::weight default_value = 0.0f;

uint64_t index = i & _weight_mask;
auto iter = _map.find(index);
rajan-chari marked this conversation as resolved.
Show resolved Hide resolved
if (iter == _map.end())
{
if (_default_func != nullptr) { _default_func(&default_value, index); }
return &default_value;
}

return iter->second.get();
}

VW::sparse_parameters::sparse_parameters(size_t length, uint32_t stride_shift)
: _weight_mask((length << stride_shift) - 1), _stride_shift(stride_shift), _default_func(nullptr)
{
Expand Down
1 change: 1 addition & 0 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class lazy_gaussian
{
public:
inline float operator[](uint64_t index) const { return VW::details::merand48_boxmuller(index); }
inline float get(uint64_t index) const { return VW::details::merand48_boxmuller(index); }
bassmang marked this conversation as resolved.
Show resolved Hide resolved
};

inline void vec_add_with_norm(std::pair<float, float>& p, float fx, float fw)
Expand Down
Loading