From 80e832fb3ab30bf092fae43ec898fd307b8a50c0 Mon Sep 17 00:00:00 2001 From: beygel Date: Thu, 7 Mar 2024 08:55:02 -0500 Subject: [PATCH] feat: direct interface for active.cc and variable rename for understandability (#4671) * updated active.cc * updates to active.cc * updates to tests * revert accidental help change in diff * removed diagnostic print statements --------- Co-authored-by: Alina Beygelzimer Co-authored-by: Alexey Taymanov <41013086+ataymano@users.noreply.github.com> --- .../ref/active-simulation.t24.stderr | 19 ++--- test/train-sets/ref/help.stdout | 8 +- vowpalwabbit/core/src/reductions/active.cc | 83 +++++++++++++++---- 3 files changed, 78 insertions(+), 32 deletions(-) diff --git a/test/train-sets/ref/active-simulation.t24.stderr b/test/train-sets/ref/active-simulation.t24.stderr index 8394160e69a..29f2786913e 100644 --- a/test/train-sets/ref/active-simulation.t24.stderr +++ b/test/train-sets/ref/active-simulation.t24.stderr @@ -11,20 +11,13 @@ Output pred = SCALAR average since example example current current current loss last counter weight label predict features 1.000000 1.000000 1 1.0 -1.0000 0.0000 128 -0.791125 0.755288 2 6.8 -1.0000 -0.1309 44 -1.274829 1.444750 8 26.3 1.0000 -0.2020 34 -1.083985 0.895011 73 52.8 1.0000 0.0214 21 -0.887295 0.693362 130 106.3 -1.0000 -0.3071 146 -0.788245 0.690009 233 213.6 -1.0000 0.0421 47 -0.664628 0.541195 398 427.4 -1.0000 -0.1863 68 -0.634406 0.604328 835 856.9 -1.0000 -0.4327 40 finished run number of examples = 1000 -weighted example sum = 1014.004519 -weighted label sum = -68.618036 -average loss = 0.630964 -best constant = -0.067670 -best constant's loss = 0.995421 +weighted example sum = 1.000000 +weighted label sum = -1.000000 +average loss = 1.000000 +best constant = -1.000000 +best constant's loss = 0.000000 total feature number = 78739 -total queries = 474 +total queries = 1 diff --git a/test/train-sets/ref/help.stdout b/test/train-sets/ref/help.stdout index b9d4fca2f7b..96601833d2e 100644 --- a/test/train-sets/ref/help.stdout +++ b/test/train-sets/ref/help.stdout @@ -221,8 +221,12 @@ Weight Options: [Reduction] Active Learning Options: --active Enable active learning (type: bool, keep, necessary) --simulation Active learning simulation mode (type: bool) - --mellowness arg Active learning mellowness parameter c_0. Default 8 (type: float, - default: 8, keep) + --direct Active learning via the tag and predictions interface. Tag should + start with "query?" to get query decision. Returned prediction + is either -1 for no or the importance weight for yes. (type: + bool) + --mellowness arg Active learning mellowness parameter c_0. Default 1. (type: float, + default: 1, keep) [Reduction] Active Learning with Cover Options: --active_cover Enable active learning with cover (type: bool, keep, necessary) --mellowness arg Active learning mellowness parameter c_0 (type: float, default: diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index ea8b66c40e4..a7449affde2 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -31,31 +31,41 @@ using namespace VW::config; using namespace VW::reductions; namespace { -float get_active_coin_bias(float k, float avg_loss, float g, float c0) -{ - const float b = c0 * (std::log(k + 1.f) + 0.0001f) / (k + 0.0001f); - const float sb = std::sqrt(b); +float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness) +{//implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf + const float mellow_log_e_count_over_e_count = mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f); + const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count); // loss should be in [0,1] avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f); - const float sl = std::sqrt(avg_loss) + std::sqrt(avg_loss + g); - if (g <= sb * sl + b) { return 1; } - const float rs = (sl + std::sqrt(sl * sl + 4 * g)) / (2 * g); - return b * rs * rs; + const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min(1.f, //std::sqrt(avg_loss) + // commented out because two square roots appears to conservative. + std::sqrt(avg_loss + alt_label_error_rate_diff));//emperical variance deflater. + //std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count + // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl; + + if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss//deflater in use. + + mellow_log_e_count_over_e_count) { return 1; } + //old equation + // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); + // return mellow_log_e_count_over_e_count * rs * rs; + const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt(mellow_log_e_count_over_e_count+4*alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2*alt_label_error_rate_diff; + // std::cout << "sqrt_s = " << sqrt_s << std::endl; + return sqrt_s*sqrt_s; } -float query_decision(const active& a, float ec_revert_weight, float k) +float query_decision(const active& a, float updates_to_change_prediction, float example_count) { float bias; - if (k <= 1.f) { bias = 1.f; } + if (example_count <= 1.f) { bias = 1.f; } else { - const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); - const float avg_loss = (static_cast(a._shared_data->sum_loss) / k) + - std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f)); - bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0); + // const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); + const float avg_loss = (static_cast(a._shared_data->sum_loss) / example_count); + //+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory. + // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl; + bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0); } - + // std::cout << "bias = " << bias << std::endl; return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f; } @@ -110,6 +120,34 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec) } } +template +void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec) +{ + if (is_learn) { base.learn(ec); } + else { base.predict(ec); } + + if (ec.l.simple.label == FLT_MAX) + { + if (std::string(ec.tag.begin(), ec.tag.begin()+6) == "query?") + { + const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f; + // We want to understand the change in prediction if the label were to be + // the opposite of what was predicted. 0 and 1 are used for the expected min + // and max labels to be coming in from the active interactor. + ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label; + ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec); + ec.l.simple.label = FLT_MAX; + ec.pred.scalar = query_decision(a, ec.confidence, static_cast(a._shared_data->weighted_unlabeled_examples)); + } + } + else + { + // Update seen labels based on the current example's label. + a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label); + a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label); + } +} + void active_print_result( VW::io::writer* f, float res, float weight, const VW::v_array& tag, VW::io::logger& logger) { @@ -189,14 +227,16 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas bool active_option = false; bool simulation = false; + bool direct = false; float active_c0; option_group_definition new_options("[Reduction] Active Learning"); new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning")) .add(make_option("simulation", simulation).help("Active learning simulation mode")) + .add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes.")) .add(make_option("mellowness", active_c0) .keep() - .default_value(8.f) - .help("Active learning mellowness parameter c_0. Default 8")); + .default_value(1.f) + .help("Active learning mellowness parameter c_0. Default 1.")); if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } @@ -223,6 +263,15 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas print_update_func = VW::details::print_update_simple_label; reduction_name.append("-simulation"); } + else if (direct) + { + learn_func = predict_or_learn_active_direct; + pred_func = predict_or_learn_active_direct; + update_stats_func = update_stats_active; + output_example_prediction_func = VW::details::output_example_prediction_simple_label; + print_update_func = VW::details::print_update_simple_label; + learn_returns_prediction = base->learn_returns_prediction; + } else { all.reduction_state.active = true;