From f7e8efe01d9ce7d8e1d2b0db33b86e84146cbb05 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 12 Oct 2023 12:50:22 -0400 Subject: [PATCH 1/3] Fixing model version check for Active reduction --- .../core/include/vw/core/reductions/active.h | 12 ++++++++++-- vowpalwabbit/core/include/vw/core/vw_versions.h | 2 ++ vowpalwabbit/core/src/reductions/active.cc | 16 ++++++++-------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vowpalwabbit/core/include/vw/core/reductions/active.h b/vowpalwabbit/core/include/vw/core/reductions/active.h index dfca91718c3..0bff31db2d0 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/active.h +++ b/vowpalwabbit/core/include/vw/core/reductions/active.h @@ -16,10 +16,18 @@ namespace reductions class active { public: - active(float active_c0, VW::workspace* all) : active_c0(active_c0), _all(all) {} + active(float active_c0, std::shared_ptr shared_data, std::shared_ptr random_state, + VW::version_struct model_version) + : active_c0(active_c0) + , _shared_data(shared_data) + , _random_state(std::move(random_state)) + , _model_version{std::move(model_version)} + { + } float active_c0; - VW::workspace* _all = nullptr; + std::shared_ptr _shared_data; // statistics, loss + std::shared_ptr _random_state; float _min_seen_label = 0.f; float _max_seen_label = 1.f; diff --git a/vowpalwabbit/core/include/vw/core/vw_versions.h b/vowpalwabbit/core/include/vw/core/vw_versions.h index 05b9a2dc9aa..3ede0811ed9 100644 --- a/vowpalwabbit/core/include/vw/core/vw_versions.h +++ b/vowpalwabbit/core/include/vw/core/vw_versions.h @@ -59,6 +59,8 @@ constexpr VW::version_struct VERSION_PASS_UINT64{8, 3, 3}; /// Added serialized seen min and max labels in the --active reduction constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS{9, 0, 0}; +/// Active seen labels was accidentally reverted out in 9.4.0 +constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED{9, 10, 0}; /// Moved option values from command line to model data constexpr VW::version_struct VERSION_FILE_WITH_L1_AND_L2_STATE_IN_MODEL_DATA{9, 0, 0}; diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 1f55a34ef87..86e52a939b5 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -50,13 +50,13 @@ float query_decision(const active& a, float ec_revert_weight, float k) if (k <= 1.f) { bias = 1.f; } else { - const auto weighted_queries = static_cast(a._all->sd->weighted_labeled_examples); - const float avg_loss = (static_cast(a._all->sd->sum_loss) / k) + + const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); + const float avg_loss = (static_cast(a._shared_data->sum_loss) / k) + std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f)); bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0); } - return (a._all->get_random_state()->get_and_update_random() < bias) ? 1.f / bias : -1.f; + return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f; } template @@ -66,7 +66,7 @@ void predict_or_learn_simulation(active& a, learner& base, VW::example& ec) if (is_learn) { - const auto k = static_cast(a._all->sd->t); + const auto k = static_cast(a._shared_data->t); constexpr float threshold = 0.f; ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec); @@ -74,7 +74,7 @@ void predict_or_learn_simulation(active& a, learner& base, VW::example& ec) if (importance > 0.f) { - a._all->sd->queries += 1; + a._shared_data->queries += 1; ec.weight *= importance; base.learn(ec); } @@ -94,7 +94,7 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec) if (ec.l.simple.label == FLT_MAX) { - const float threshold = (a._all->sd->max_label + a._all->sd->min_label) * 0.5f; + const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f; // We want to understand the change in prediction if the label were to be // the opposite of what was predicted. 0 and 1 are used for the expected min // and max labels to be coming in from the active interactor. @@ -130,7 +130,7 @@ void active_print_result( void save_load(active& a, VW::io_buf& io, bool read, bool text) { if (io.num_files() == 0) { return; } - if (a._model_version >= VW::version_definitions::VERSION_FILE_WITH_ACTIVE_SEEN_LABELS) + if (a._model_version >= VW::version_definitions::VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) { if (read) { @@ -195,7 +195,7 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") } - auto data = VW::make_unique(active_c0, &all); + auto data = VW::make_unique(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver); auto base = require_singleline(stack_builder.setup_base_learner()); using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&); From 527ad32b25a3f46190619247742e71d6440d62ba Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 12 Oct 2023 13:27:57 -0400 Subject: [PATCH 2/3] adding additional check to account for models generated between 9.0 and 9.4 --- vowpalwabbit/core/include/vw/core/vw_versions.h | 1 + vowpalwabbit/core/src/reductions/active.cc | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vowpalwabbit/core/include/vw/core/vw_versions.h b/vowpalwabbit/core/include/vw/core/vw_versions.h index 3ede0811ed9..77b2429970e 100644 --- a/vowpalwabbit/core/include/vw/core/vw_versions.h +++ b/vowpalwabbit/core/include/vw/core/vw_versions.h @@ -60,6 +60,7 @@ constexpr VW::version_struct VERSION_PASS_UINT64{8, 3, 3}; /// Added serialized seen min and max labels in the --active reduction constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS{9, 0, 0}; /// Active seen labels was accidentally reverted out in 9.4.0 +constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED{9, 4, 0}; constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED{9, 10, 0}; /// Moved option values from command line to model data diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 86e52a939b5..64c4e8fb2fa 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -129,8 +129,17 @@ void active_print_result( void save_load(active& a, VW::io_buf& io, bool read, bool text) { + using namespace VW::version_definitions; if (io.num_files() == 0) { return; } - if (a._model_version >= VW::version_definitions::VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) + // This code is valid if version is within + // [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) + // or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED + if ( + ( a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS + && + a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) + || + a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) { if (read) { From e038514f122d751f04bed3c2dbee1ecbb0e3190f Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 12 Oct 2023 13:37:52 -0400 Subject: [PATCH 3/3] lint --- vowpalwabbit/core/src/reductions/active.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 64c4e8fb2fa..ea8b66c40e4 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -134,12 +134,9 @@ void save_load(active& a, VW::io_buf& io, bool read, bool text) // This code is valid if version is within // [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) // or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED - if ( - ( a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS - && - a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) - || - a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) + if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS && + a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) || + a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) { if (read) {