Skip to content

Commit

Permalink
Support BROADCAST join strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang committed Jan 11, 2025
1 parent 863e05f commit 996e7a5
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 24 deletions.
1 change: 1 addition & 0 deletions pinot-common/src/main/proto/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ enum JoinType {
enum JoinStrategy {
HASH = 0;
LOOKUP = 1;
BROADCAST = 2;
}

message JoinNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.pinot.calcite.rel.hint;

import javax.annotation.Nullable;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.hint.RelHint;
Expand Down Expand Up @@ -70,6 +71,8 @@ public static class JoinHintOptions {
public static final String DYNAMIC_BROADCAST_JOIN_STRATEGY = "dynamic_broadcast";
// "lookup" can be used when the right table is a dimension table replicated to all workers
public static final String LOOKUP_JOIN_STRATEGY = "lookup";
// "broadcast" can be used when the right table is small enough to be broadcasted to all workers
public static final String BROADCAST_JOIN_STRATEGY = "broadcast";

/**
* Max rows allowed to build the right table hash collection.
Expand All @@ -93,11 +96,27 @@ public static class JoinHintOptions {
*/
public static final String APPEND_DISTINCT_TO_SEMI_JOIN_PROJECT = "append_distinct_to_semi_join_project";

@Nullable
public static String getJoinStrategyHint(Join join) {
return PinotHintStrategyTable.getHintOption(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS,
PinotHintOptions.JoinHintOptions.JOIN_STRATEGY);
}

public static boolean useLookupJoinStrategy(@Nullable String joinStrategyHint) {
return LOOKUP_JOIN_STRATEGY.equalsIgnoreCase(joinStrategyHint);
}

// TODO: Consider adding a Join implementation with join strategy.
public static boolean useLookupJoinStrategy(Join join) {
return LOOKUP_JOIN_STRATEGY.equalsIgnoreCase(
PinotHintStrategyTable.getHintOption(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS,
PinotHintOptions.JoinHintOptions.JOIN_STRATEGY));
return useLookupJoinStrategy(getJoinStrategyHint(join));
}

public static boolean useBroadcastJoinStrategy(@Nullable String joinStrategyHint) {
return BROADCAST_JOIN_STRATEGY.equalsIgnoreCase(joinStrategyHint);
}

public static boolean useBroadcastJoinStrategy(Join join) {
return useBroadcastJoinStrategy(getJoinStrategyHint(join));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ public void onMatch(RelOptRuleCall call) {
JoinInfo joinInfo = join.analyzeCondition();
RelNode newLeft;
RelNode newRight;
if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(join)) {
String joinStrategyHint = PinotHintOptions.JoinHintOptions.getJoinStrategyHint(join);
if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(joinStrategyHint)) {
// Lookup join - add local exchange on the left side
newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON);
newRight = right;
} else if (PinotHintOptions.JoinHintOptions.useBroadcastJoinStrategy(joinStrategyHint)) {
// Broadcast join - add local exchange on the left side, broadcast exchange on the right side
newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON);
newRight = PinotLogicalExchange.create(right, RelDistributions.BROADCAST_DISTRIBUTED);
} else {
// Regular join - add exchange on both sides
if (joinInfo.leftKeys.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ private JoinNode convertLogicalJoin(LogicalJoin join) {

// Check if the join hint specifies the join strategy
JoinNode.JoinStrategy joinStrategy;
if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(join)) {
String joinStrategyHint = PinotHintOptions.JoinHintOptions.getJoinStrategyHint(join);
if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(joinStrategyHint)) {
joinStrategy = JoinNode.JoinStrategy.LOOKUP;

// Run some validations for lookup join
Expand All @@ -333,6 +334,8 @@ private JoinNode convertLogicalJoin(LogicalJoin join) {
Preconditions.checkState(projectInput instanceof TableScan,
"Right input for lookup join must be a Project over TableScan, got Project over: %s",
projectInput.getClass().getSimpleName());
} else if (PinotHintOptions.JoinHintOptions.useBroadcastJoinStrategy(joinStrategyHint)) {
joinStrategy = JoinNode.JoinStrategy.BROADCAST;
} else {
// TODO: Consider adding DYNAMIC_BROADCAST as a separate join strategy
joinStrategy = JoinNode.JoinStrategy.HASH;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ public int hashCode() {
}

public enum JoinStrategy {
HASH, LOOKUP
// HASH is the default equi-join strategy, where both left and right tables are hash partitioned on join keys, then
// shuffled to the same worker to perform the join.
HASH,

// LOOKUP join strategy can be used for equi-join when the right table is a dimension table replicated to all
// workers. It looks up the in-memory pre-materialized right table to perform the join.
LOOKUP,

// BROADCAST join strategy can be used when the right table is small enough to be broadcasted to all workers of the
// left table.
BROADCAST
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ private static JoinNode.JoinStrategy convertJoinStrategy(Plan.JoinStrategy joinS
return JoinNode.JoinStrategy.HASH;
case LOOKUP:
return JoinNode.JoinStrategy.LOOKUP;
case BROADCAST:
return JoinNode.JoinStrategy.BROADCAST;
default:
throw new IllegalStateException("Unsupported JoinStrategy: " + joinStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ private static Plan.JoinStrategy convertJoinStrategy(JoinNode.JoinStrategy joinS
return Plan.JoinStrategy.HASH;
case LOOKUP:
return Plan.JoinStrategy.LOOKUP;
case BROADCAST:
return Plan.JoinStrategy.BROADCAST;
default:
throw new IllegalStateException("Unsupported JoinStrategy: " + joinStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ private void assignWorkersToNonRootFragment(PlanFragment fragment, DispatchableP
Map<Integer, DispatchablePlanMetadata> metadataMap = context.getDispatchablePlanMetadataMap();
DispatchablePlanMetadata metadata = metadataMap.get(fragment.getFragmentId());
boolean leafPlan = isLeafPlan(metadata);
if (isLocalExchange(children)) {
// If it is a local exchange (single child with SINGLETON distribution), use the same worker assignment to avoid
Integer childIdWithLocalExchange = findLocalExchange(children);
if (childIdWithLocalExchange != null) {
// If there is a local exchange (child with SINGLETON distribution), use the same worker assignment to avoid
// shuffling data.
// TODO: Support partition parallelism
// TODO:
// 1. Support partition parallelism
// 2. Check if there are conflicts (multiple children with different local exchange)
DispatchablePlanMetadata childMetadata = metadataMap.get(children.get(0).getFragmentId());
metadata.setWorkerIdToServerInstanceMap(childMetadata.getWorkerIdToServerInstanceMap());
metadata.setPartitionFunction(childMetadata.getPartitionFunction());
Expand All @@ -121,13 +124,21 @@ private void assignWorkersToNonRootFragment(PlanFragment fragment, DispatchableP
}
}

private boolean isLocalExchange(List<PlanFragment> children) {
if (children.size() != 1) {
return false;
/**
* Returns the index of the child fragment that has a local exchange (SINGLETON distribution), or {@code null} if none
* exists.
*/
@Nullable
private Integer findLocalExchange(List<PlanFragment> children) {
int numChildren = children.size();
for (int i = 0; i < numChildren; i++) {
PlanNode childPlanNode = children.get(i).getFragmentRoot();
if (childPlanNode instanceof MailboxSendNode
&& ((MailboxSendNode) childPlanNode).getDistributionType() == RelDistribution.Type.SINGLETON) {
return i;
}
}
PlanNode childPlanNode = children.get(0).getFragmentRoot();
return childPlanNode instanceof MailboxSendNode
&& ((MailboxSendNode) childPlanNode).getDistributionType() == RelDistribution.Type.SINGLETON;
return null;
}

private static boolean isLeafPlan(DispatchablePlanMetadata metadata) {
Expand Down
55 changes: 55 additions & 0 deletions pinot-query-planner/src/test/resources/queries/JoinPlans.json
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,61 @@
}
]
},
"broadcast_join_planning_tests": {
"queries": [
{
"description": "Simple broadcast join",
"sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1",
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], col2=[$2])",
"\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
"\n PinotLogicalExchange(distribution=[single])",
"\n LogicalProject(col1=[$0])",
"\n LogicalTableScan(table=[[default, a]])",
"\n PinotLogicalExchange(distribution=[broadcast])",
"\n LogicalProject(col1=[$0], col2=[$1])",
"\n LogicalTableScan(table=[[default, b]])",
"\n"
]
},
{
"description": "Broadcast join with filter on both left and right table",
"sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 WHERE a.col2 = 'foo' AND b.col2 = 'bar'",
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], col2=[$2])",
"\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
"\n PinotLogicalExchange(distribution=[single])",
"\n LogicalProject(col1=[$0])",
"\n LogicalFilter(condition=[=($1, _UTF-8'foo')])",
"\n LogicalTableScan(table=[[default, a]])",
"\n PinotLogicalExchange(distribution=[broadcast])",
"\n LogicalProject(col1=[$0], col2=[$1])",
"\n LogicalFilter(condition=[=($1, _UTF-8'bar')])",
"\n LogicalTableScan(table=[[default, b]])",
"\n"
]
},
{
"description": "Broadcast join with transformation on both left and right table joined key",
"sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON upper(a.col1) = upper(b.col1)",
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], col2=[$2])",
"\n LogicalJoin(condition=[=($1, $3)], joinType=[inner])",
"\n PinotLogicalExchange(distribution=[single])",
"\n LogicalProject(col1=[$0], $f8=[UPPER($0)])",
"\n LogicalTableScan(table=[[default, a]])",
"\n PinotLogicalExchange(distribution=[broadcast])",
"\n LogicalProject(col2=[$1], $f8=[UPPER($0)])",
"\n LogicalTableScan(table=[[default, b]])",
"\n"
]
}
]

},
"exception_throwing_join_planning_tests": {
"queries": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ public ObjectNode visitFilter(FilterNode node, Void context) {

@Override
public ObjectNode visitJoin(JoinNode node, Void context) {
if (node.getJoinStrategy() == JoinNode.JoinStrategy.HASH) {
return recursiveCase(node, MultiStageOperator.Type.HASH_JOIN);
} else {
assert node.getJoinStrategy() == JoinNode.JoinStrategy.LOOKUP;
if (node.getJoinStrategy() == JoinNode.JoinStrategy.LOOKUP) {
return recursiveCase(node, MultiStageOperator.Type.LOOKUP_JOIN);
} else {
// TODO: Consider renaming this operator type. It handles multiple join strategies.
return recursiveCase(node, MultiStageOperator.Type.HASH_JOIN);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,11 @@ public MultiStageOperator visitJoin(JoinNode node, OpChainExecutionContext conte
MultiStageOperator leftOperator = visit(left, context);
PlanNode right = inputs.get(1);
MultiStageOperator rightOperator = visit(right, context);
JoinNode.JoinStrategy joinStrategy = node.getJoinStrategy();
if (joinStrategy == JoinNode.JoinStrategy.HASH) {
return new HashJoinOperator(context, leftOperator, left.getDataSchema(), rightOperator, node);
} else {
assert joinStrategy == JoinNode.JoinStrategy.LOOKUP;
if (node.getJoinStrategy() == JoinNode.JoinStrategy.LOOKUP) {
return new LookupJoinOperator(context, leftOperator, rightOperator, node);
} else {
// TODO: Consider renaming this operator. It handles multiple join strategies.
return new HashJoinOperator(context, leftOperator, left.getDataSchema(), rightOperator, node);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@
"description": "Colocated JOIN with partition column and group by non-partitioned column with stage parallelism",
"sql": "SET stageParallelism=2; SELECT {tbl1}.name, SUM({tbl2}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num GROUP BY {tbl1}.name"
},
{
"description": "Broadcast JOIN without partition hint",
"sql": "SELECT /*+ joinOptions(join_strategy='broadcast') */ {tbl1}.num, {tbl1}.name, {tbl2}.num, {tbl2}.val FROM {tbl1} JOIN {tbl2} ON {tbl1}.num = {tbl2}.num"
},
{
"description": "Broadcast JOIN with partition hint",
"sql": "SELECT /*+ joinOptions(join_strategy='broadcast') */ {tbl1}.num, {tbl1}.name, {tbl2}.num, {tbl2}.val FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} ON {tbl1}.num = {tbl2}.num"
},
{
"description": "Colocated, Dynamic broadcast SEMI-JOIN with partition column",
"sql": "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ {tbl1}.num, {tbl1}.name FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ WHERE {tbl1}.num IN (SELECT {tbl2}.num FROM {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ WHERE {tbl2}.val IN ('xxx', 'yyy'))"
Expand Down

0 comments on commit 996e7a5

Please sign in to comment.