Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: direct interface for active.cc and variable rename for understandability #4671

Merged
merged 11 commits into from
Mar 7, 2024
19 changes: 6 additions & 13 deletions test/train-sets/ref/active-simulation.t24.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,13 @@ Output pred = SCALAR
average since example example current current current
loss last counter weight label predict features
1.000000 1.000000 1 1.0 -1.0000 0.0000 128
0.791125 0.755288 2 6.8 -1.0000 -0.1309 44
1.274829 1.444750 8 26.3 1.0000 -0.2020 34
1.083985 0.895011 73 52.8 1.0000 0.0214 21
0.887295 0.693362 130 106.3 -1.0000 -0.3071 146
0.788245 0.690009 233 213.6 -1.0000 0.0421 47
0.664628 0.541195 398 427.4 -1.0000 -0.1863 68
0.634406 0.604328 835 856.9 -1.0000 -0.4327 40

finished run
number of examples = 1000
weighted example sum = 1014.004519
weighted label sum = -68.618036
average loss = 0.630964
best constant = -0.067670
best constant's loss = 0.995421
weighted example sum = 1.000000
weighted label sum = -1.000000
average loss = 1.000000
best constant = -1.000000
best constant's loss = 0.000000
total feature number = 78739
total queries = 474
total queries = 1
8 changes: 6 additions & 2 deletions test/train-sets/ref/help.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ Weight Options:
[Reduction] Active Learning Options:
--active Enable active learning (type: bool, keep, necessary)
--simulation Active learning simulation mode (type: bool)
--mellowness arg Active learning mellowness parameter c_0. Default 8 (type: float,
default: 8, keep)
--direct Active learning via the tag and predictions interface. Tag should
start with "query?" to get query decision. Returned prediction
is either -1 for no or the importance weight for yes. (type:
bool)
--mellowness arg Active learning mellowness parameter c_0. Default 1. (type: float,
default: 1, keep)
[Reduction] Active Learning with Cover Options:
--active_cover Enable active learning with cover (type: bool, keep, necessary)
--mellowness arg Active learning mellowness parameter c_0 (type: float, default:
Expand Down
86 changes: 69 additions & 17 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,41 @@
using namespace VW::reductions;
namespace
{
/// Probability with which the active learner should request a label.
///
/// Implementation follows the importance weighted active learning rule in
/// https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
///
/// @param example_count             number of examples seen so far.
/// @param avg_loss                  average loss so far; clamped into [0, 1] before use.
/// @param alt_label_error_rate_diff gap term g; query_decision passes
///                                  updates_to_change_prediction / example_count.
/// @param mellowness                mellowness constant c_0 that scales the deviation bound.
/// @return 1 when the gap is within the deviation bound (always query), otherwise the
///         query probability s obtained from the quadratic below.
float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
{
  // b = c_0 * log(k + 1) / k. The 0.0001 offsets keep the ratio finite and positive
  // when example_count is 0.
  const float mellow_log_e_count_over_e_count =
      mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
  const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count);

  // loss should be in [0,1]
  avg_loss = std::max(0.f, std::min(avg_loss, 1.f));

  // Empirical variance deflater, capped at 1. A second std::sqrt(avg_loss) term is
  // deliberately left out because summing two square roots appeared too conservative.
  const float variance_deflater = std::min(1.f, std::sqrt(avg_loss + alt_label_error_rate_diff));

  if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * variance_deflater + mellow_log_e_count_over_e_count)
  {
    return 1.f;
  }

  // Otherwise s is the positive root of g = sqrt(b / s) + b / s, i.e.
  //   sqrt(s) = (sqrt(b) + sqrt(b + 4 * g * b)) / (2 * g).
  // The whole denominator must be parenthesised: "/ 2 * g" would halve the numerator
  // and then multiply by g, which does not evaluate to 1 at g = sqrt(b) + b.
  const float sqrt_s =
      (sqrt_mellow_lecoec +
          std::sqrt(mellow_log_e_count_over_e_count +
              4.f * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) /
      (2.f * alt_label_error_rate_diff);
  return sqrt_s * sqrt_s;
}

float query_decision(const active& a, float ec_revert_weight, float k)
float query_decision(const active& a, float updates_to_change_prediction, float example_count)
{
float bias;
if (k <= 1.f) { bias = 1.f; }
if (example_count <= 1.f) { bias = 1.f; }
else
{
const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / k) +
std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f));
bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / example_count);
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0);
}

// std::cout << "bias = " << bias << std::endl;
return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f;
}

Expand Down Expand Up @@ -110,6 +120,37 @@
}
}

/// Learn/predict entry point for the tag-and-prediction ("direct") active learning interface.
///
/// The base learner is always run first. For an unlabeled example (label == FLT_MAX)
/// whose tag starts with "query?", the scalar prediction is replaced by the query
/// decision: -1 for "do not query", otherwise the importance weight. For a labeled
/// example the smallest and largest labels seen so far are updated.
///
/// @tparam is_learn true to call base.learn, false to call base.predict.
/// @param a    reduction state (seen-label range, shared data, random source).
/// @param base learner this reduction delegates to.
/// @param ec   example to process; pred.scalar and confidence may be overwritten.
template <bool is_learn>
void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec)
{
  if (is_learn) { base.learn(ec); }
  else { base.predict(ec); }

  if (ec.l.simple.label == FLT_MAX)
  {
    // Check the length before comparing so that a tag shorter than the prefix is never
    // read past its end.
    const std::string query_prefix = "query?";
    const bool wants_query_decision = ec.tag.size() >= query_prefix.size() &&
        std::equal(query_prefix.begin(), query_prefix.end(), ec.tag.begin());

    if (wants_query_decision)
    {
      const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
      // We want to understand the change in prediction if the label were to be the
      // opposite of what was predicted, so temporarily assign the extreme seen label on
      // the other side of the threshold while the sensitivity is computed.
      ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label;
      ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec);
      // Restore the "unlabeled" marker before handing the example back.
      ec.l.simple.label = FLT_MAX;
      ec.pred.scalar =
          query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
    }
  }
  else
  {
    // Update seen labels based on the current example's label.
    a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label);
    a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label);
  }
}

Check warning on line 152 in vowpalwabbit/core/src/reductions/active.cc

View check run for this annotation

Codecov / codecov/patch

vowpalwabbit/core/src/reductions/active.cc#L152

Added line #L152 was not covered by tests

void active_print_result(
VW::io::writer* f, float res, float weight, const VW::v_array<char>& tag, VW::io::logger& logger)
{
Expand Down Expand Up @@ -189,14 +230,16 @@

bool active_option = false;
bool simulation = false;
bool direct = false;
float active_c0;
option_group_definition new_options("[Reduction] Active Learning");
new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning"))
.add(make_option("simulation", simulation).help("Active learning simulation mode"))
.add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes."))
.add(make_option("mellowness", active_c0)
.keep()
.default_value(8.f)
.help("Active learning mellowness parameter c_0. Default 8"));
.default_value(1.f)
.help("Active learning mellowness parameter c_0. Default 1."));

if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

Expand All @@ -223,6 +266,15 @@
print_update_func = VW::details::print_update_simple_label<active>;
reduction_name.append("-simulation");
}
else if (direct)
{
learn_func = predict_or_learn_active_direct<true>;
pred_func = predict_or_learn_active_direct<false>;
update_stats_func = update_stats_active;
output_example_prediction_func = VW::details::output_example_prediction_simple_label<active>;
print_update_func = VW::details::print_update_simple_label<active>;
learn_returns_prediction = base->learn_returns_prediction;

Check warning on line 276 in vowpalwabbit/core/src/reductions/active.cc

View check run for this annotation

Codecov / codecov/patch

vowpalwabbit/core/src/reductions/active.cc#L271-L276

Added lines #L271 - L276 were not covered by tests
}
else
{
all.reduction_state.active = true;
Expand Down
Loading