Skip to content

Commit

Permalink
Merge branch 'master' into feat_not_self_consistent_speed_up
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano authored Jan 23, 2024
2 parents b5037d0 + b4612f8 commit 22b8e70
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
11 changes: 5 additions & 6 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,13 +736,14 @@ void node_split(emt_tree& b, emt_node& cn)
cn.examples.clear();
}

void node_insert(emt_node& cn, std::unique_ptr<emt_example> ex)
void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr<emt_example> ex)
{
for (auto& cn_ex : cn.examples)
{
if (cn_ex->full == ex->full) { return; }
}
cn.examples.push_back(std::move(ex));
tree_bound(b, cn.examples.back().get());
}

emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex)
Expand Down Expand Up @@ -774,16 +775,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW:
auto* closest_ex = node_pick(b, base, cn, ex);
ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0;
ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0;
if (closest_ex != nullptr) { tree_bound(b, closest_ex); }
}

void emt_predict(emt_tree& b, learner& base, VW::example& ec)
{
b.all->feature_tweaks_config.ignore_some_linear = false;
emt_example ex(*b.all, &ec);

emt_node& cn = *tree_route(b, ex);
node_predict(b, base, cn, ex, ec);
tree_bound(b, &ex);
}

void emt_learn(emt_tree& b, learner& base, VW::example& ec)
Expand All @@ -792,10 +792,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec)
auto ex = VW::make_unique<emt_example>(*b.all, &ec);

emt_node& cn = *tree_route(b, *ex);
scorer_learn(b, base, cn, *ex, ec.weight);
node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn
tree_bound(b, ex.get());
node_insert(cn, std::move(ex));
scorer_learn(b, base, cn, *ex, ec.weight);
node_insert(b, cn, std::move(ex));
node_split(b, cn);
}

Expand Down
41 changes: 40 additions & 1 deletion vowpalwabbit/core/tests/eigen_memory_tree_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest)
}
}

TEST(EigenMemoryTree, Bounding)
TEST(EigenMemoryTree, BoundingDrop)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5"));
auto* tree = get_emt_tree(*vw);
Expand All @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding)
EXPECT_EQ(tree->root->router_weights.size(), 0);
}

TEST(EigenMemoryTree, BoundingPredict)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
auto* tree = get_emt_tree(*vw);

auto* ex = VW::read_example(*vw, "1 | 1");
vw->predict(*ex);
vw->finish_example(*ex);

EXPECT_EQ(tree->bounder->list.size(), 0);
}

TEST(EigenMemoryTree, BoundingRecency)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
auto* tree = get_emt_tree(*vw);

for (int i = 0; i < 3; i++)
{
auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i));
vw->learn(*ex);
vw->finish_example(*ex);
}

EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2);

auto* ex1 = VW::read_example(*vw, "1 | 1");
vw->predict(*ex1);
vw->finish_example(*ex1);

EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1);

auto* ex2 = VW::read_example(*vw, "1 | 0");
vw->predict(*ex2);
vw->finish_example(*ex2);

EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0);
}

TEST(EigenMemoryTree, Split)
{
auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3");
Expand Down

0 comments on commit 22b8e70

Please sign in to comment.