Skip to content

Commit

Permalink
feat: Support cancellation in HashProbe::getOutput [3/n] (facebookinc…
Browse files Browse the repository at this point in the history
…ubator#12237)

Summary:

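HashProbe::getOutput can take a long time to run for degenerate query shapes. Check for task cancellation at the top of its output loop so that a cancelled task stops producing output early.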

Reviewed By: Yuhta

Differential Revision: D68997164
  • Loading branch information
yuandagits authored and facebook-github-bot committed Feb 4, 2025
1 parent 48d1bbf commit d560915
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
7 changes: 7 additions & 0 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,13 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
initBuffer<char*>(outputTableRows_, outputTableRowsCapacity_, pool());

for (;;) {
// If the task owning this operator has been cancelled, there is no point in
// continuing this procedure, which may run for a long time in degenerate
// cases. Exit the working loop and let the Driver handle exiting gracefully
// in its own loop.
if (operatorCtx_->task()->isCancelled()) {
return nullptr;
}
int numOut = 0;

if (emptyBuildSide) {
Expand Down
27 changes: 27 additions & 0 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,11 @@ class HashJoinBuilder {
return *this;
}

// Sets whether the test run should request cancellation of the task. When the
// flag is true, the run calls Task::requestCancel() on the task returned by
// the result assertion, i.e. after the query output has been checked against
// the reference query. Defaults to false.
//
// Returns a reference to this builder so that calls can be chained.
HashJoinBuilder& injectTaskCancellation(bool injectTaskCancellation) {
injectTaskCancellation_ = injectTaskCancellation;
return *this;
}

HashJoinBuilder& maxSpillLevel(int32_t maxSpillLevel) {
maxSpillLevel_ = maxSpillLevel;
return *this;
Expand Down Expand Up @@ -665,6 +670,11 @@ class HashJoinBuilder {
memory::spillMemoryPool()->stats().peakBytes;
TestScopedSpillInjection scopedSpillInjection(spillPct);
auto task = builder.assertResults(referenceQuery_);

if (injectTaskCancellation_) {
task->requestCancel();
}

// Wait up to 5 seconds for all the task background activities to complete.
// Then we can collect the stats from all the operators.
//
Expand Down Expand Up @@ -758,6 +768,7 @@ class HashJoinBuilder {
std::optional<bool> runParallelProbe_;
std::optional<bool> runParallelBuild_;

bool injectTaskCancellation_{false};
bool injectSpill_{true};
// If not set, then the test will run the test with different settings:
// 0, 2.
Expand Down Expand Up @@ -1020,6 +1031,22 @@ TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) {
.run();
}

// Runs an inner equi-join on a single BIGINT key with task cancellation
// injection enabled, then checks that the task reports a non-zero termination
// time in its stats.
TEST_P(MultiThreadedHashJoinTest, joinWithCancellation) {
  // Invoked by the builder with the finished task; the second argument is not
  // used by this check.
  const auto expectTerminated = [](const std::shared_ptr<Task>& task,
                                   bool /*unused*/) {
    const auto taskStats = task->taskStats();
    EXPECT_GT(taskStats.terminationTimeMs, 0);
  };

  HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
      .numDrivers(numDrivers_)
      .keyTypes({BIGINT()})
      .probeVectors(1600, 5)
      .buildVectors(1500, 5)
      .injectTaskCancellation(true)
      .referenceQuery(
          "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0")
      .verifier(expectTerminated)
      .run();
}

TEST_P(MultiThreadedHashJoinTest, emptyBuild) {
const std::vector<bool> finishOnEmptys = {false, true};
for (const auto finishOnEmpty : finishOnEmptys) {
Expand Down

0 comments on commit d560915

Please sign in to comment.