Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into group_by_trim_v2
Browse files Browse the repository at this point in the history
# Conflicts:
#	pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java
#	pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
#	pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
#	pinot-query-planner/src/test/resources/queries/GroupByPlans.json
#	pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
#	pinot-query-runtime/src/test/resources/queries/QueryHints.json
  • Loading branch information
bziobrowski committed Dec 31, 2024
2 parents b603ad5 + 9b96068 commit a1e3675
Show file tree
Hide file tree
Showing 13 changed files with 379 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ private void replicaGroupBasedMinimumMovement(Map<Integer, List<InstanceConfig>>
for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
List<String> instancesInReplicaGroup = replicaGroupIdToInstancesMap.get(replicaGroupId);
if (replicaGroupId < existingNumReplicaGroups) {
int maxNumPartitionsPerInstance = (numInstancesPerReplicaGroup + numPartitions - 1) / numPartitions;
int maxNumPartitionsPerInstance =
(numPartitions + numInstancesPerReplicaGroup - 1) / numInstancesPerReplicaGroup;
Map<String, Integer> instanceToNumPartitionsMap =
Maps.newHashMapWithExpectedSize(numInstancesPerReplicaGroup);
for (String instance : instancesInReplicaGroup) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
*/
package org.apache.pinot.core.query.optimizer.filter;

import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.context.RequestContextUtils;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.FilterKind;
Expand Down Expand Up @@ -61,9 +62,10 @@ private Expression optimize(Expression filterExpression) {
String operator = function.getOperator();
if (operator.equals(FilterKind.OR.name())) {
List<Expression> children = function.getOperands();
Map<Expression, Set<Expression>> valuesMap = new HashMap<>();
List<Expression> newChildren = new ArrayList<>();
boolean recreateFilter = false;
// Key is the lhs of the EQ/IN predicate, value is the map from string representation of the value to the value
Map<Expression, Map<String, Expression>> valuesMap = new HashMap<>();
List<Expression> newChildren = new ArrayList<>(children.size());
boolean[] recreateFilter = new boolean[1];

// Iterate over all the child filters to merge EQ and IN predicates
for (Expression child : children) {
Expand All @@ -80,52 +82,62 @@ private Expression optimize(Expression filterExpression) {
List<Expression> operands = childFunction.getOperands();
Expression lhs = operands.get(0);
Expression value = operands.get(1);
Set<Expression> values = valuesMap.get(lhs);
if (values == null) {
values = new HashSet<>();
values.add(value);
valuesMap.put(lhs, values);
} else {
values.add(value);
// Recreate filter when multiple predicates can be merged
recreateFilter = true;
}
// Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This is
// consistent with how server handles predicates.
String stringValue = RequestContextUtils.getStringValue(value);
valuesMap.compute(lhs, (k, v) -> {
if (v == null) {
Map<String, Expression> values = new HashMap<>();
values.put(stringValue, value);
return values;
} else {
v.put(stringValue, value);
// Recreate filter when multiple predicates can be merged
recreateFilter[0] = true;
return v;
}
});
} else if (childOperator.equals(FilterKind.IN.name())) {
List<Expression> operands = childFunction.getOperands();
Expression lhs = operands.get(0);
Set<Expression> inPredicateValuesSet = new HashSet<>();
int numOperands = operands.size();
for (int i = 1; i < numOperands; i++) {
inPredicateValuesSet.add(operands.get(i));
}
int numUniqueValues = inPredicateValuesSet.size();
if (numUniqueValues == 1 || numUniqueValues != numOperands - 1) {
// Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate),
// or values can be de-duplicated
recreateFilter = true;
}
Set<Expression> values = valuesMap.get(lhs);
if (values == null) {
valuesMap.put(lhs, inPredicateValuesSet);
} else {
values.addAll(inPredicateValuesSet);
// Recreate filter when multiple predicates can be merged
recreateFilter = true;
}
valuesMap.compute(lhs, (k, v) -> {
if (v == null) {
Map<String, Expression> values = getInValues(operands);
int numUniqueValues = values.size();
if (numUniqueValues == 1 || numUniqueValues != operands.size() - 1) {
// Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or
// values can be de-duplicated
recreateFilter[0] = true;
}
return values;
} else {
int numOperands = operands.size();
for (int i = 1; i < numOperands; i++) {
Expression value = operands.get(i);
// Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This
// is consistent with how server handles predicates.
String stringValue = RequestContextUtils.getStringValue(value);
v.put(stringValue, value);
}
// Recreate filter when multiple predicates can be merged
recreateFilter[0] = true;
return v;
}
});
} else {
newChildren.add(child);
}
}
}

if (recreateFilter) {
if (recreateFilter[0]) {
if (newChildren.isEmpty() && valuesMap.size() == 1) {
// Single range without other filters
Map.Entry<Expression, Set<Expression>> entry = valuesMap.entrySet().iterator().next();
return getFilterExpression(entry.getKey(), entry.getValue());
Map.Entry<Expression, Map<String, Expression>> entry = valuesMap.entrySet().iterator().next();
return getFilterExpression(entry.getKey(), entry.getValue().values());
} else {
for (Map.Entry<Expression, Set<Expression>> entry : valuesMap.entrySet()) {
newChildren.add(getFilterExpression(entry.getKey(), entry.getValue()));
for (Map.Entry<Expression, Map<String, Expression>> entry : valuesMap.entrySet()) {
newChildren.add(getFilterExpression(entry.getKey(), entry.getValue().values()));
}
function.setOperands(newChildren);
return filterExpression;
Expand All @@ -138,17 +150,12 @@ private Expression optimize(Expression filterExpression) {
return filterExpression;
} else if (operator.equals(FilterKind.IN.name())) {
List<Expression> operands = function.getOperands();
Expression lhs = operands.get(0);
Set<Expression> values = new HashSet<>();
int numOperands = operands.size();
for (int i = 1; i < numOperands; i++) {
values.add(operands.get(i));
}
Map<String, Expression> values = getInValues(operands);
int numUniqueValues = values.size();
if (numUniqueValues == 1 || numUniqueValues != numOperands - 1) {
// Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or values
// can be de-duplicated
return getFilterExpression(lhs, values);
if (numUniqueValues == 1 || numUniqueValues != operands.size() - 1) {
// Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or values can
// be de-duplicated
return getFilterExpression(operands.get(0), values.values());
} else {
return filterExpression;
}
Expand All @@ -157,10 +164,27 @@ private Expression optimize(Expression filterExpression) {
}
}

/**
* Helper method to get the values from the IN predicate. Returns a map from string representation of the value to the
* value.
*/
private Map<String, Expression> getInValues(List<Expression> operands) {
int numOperands = operands.size();
Map<String, Expression> values = Maps.newHashMapWithExpectedSize(numOperands - 1);
for (int i = 1; i < numOperands; i++) {
Expression value = operands.get(i);
// Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This is
// consistent with how server handles predicates.
String stringValue = RequestContextUtils.getStringValue(value);
values.put(stringValue, value);
}
return values;
}

/**
* Helper method to construct a EQ or IN predicate filter Expression from the given lhs and values.
*/
private static Expression getFilterExpression(Expression lhs, Set<Expression> values) {
private static Expression getFilterExpression(Expression lhs, Collection<Expression> values) {
int numValues = values.size();
if (numValues == 1) {
return RequestUtils.getFunctionExpression(FilterKind.EQUALS.name(), lhs, values.iterator().next());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ private PinotHintOptions() {
public static class AggregateOptions {
public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys";
public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result";
public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by";
public static final String IS_SKIP_LEAF_STAGE_GROUP_BY = "is_skip_leaf_stage_group_by";
public static final String IS_ENABLE_GROUP_TRIM = "is_enable_group_trim";

/** Enables trimming of aggregation intermediate results by pushing down order by and limit to leaf stage. */
public static final String ENABLE_GROUP_TRIM = "is_enable_group_trim";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class PinotLogicalAggregate extends Aggregate {
private final AggType _aggType;
private final boolean _leafReturnFinalResult;

// The following fields are for group trimming purpose, and are extracted from the Sort on top of this Aggregate.
// The following fields are set when group trim is enabled, and are extracted from the Sort on top of this Aggregate.
private final List<RelFieldCollation> _collations;
private final int _limit;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@
* - COUNT(*) with a GROUP_BY_KEY transforms into: COUNT(*)__LEAF --> COUNT(*)__FINAL, where
* - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ]
* - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ]
*
* There are 3 sub-rules:
* 1. {@link SortProjectAggregate}:
* Matches the case when there's a Sort on top of Project on top of Aggregate, and enable group trim hint is present.
* E.g.
* SELECT /*+ aggOptions(is_enable_group_trim='true') * /
* COUNT(*) AS cnt, col1 FROM myTable GROUP BY col1 ORDER BY cnt DESC LIMIT 10
* It will extract the collations and limit from the Sort node, and set them into the Aggregate node. It works only
* when the sort key is a direct reference to the input, i.e. no transform on the input columns.
* 2. {@link SortAggregate}:
* Matches the case when there's a Sort on top of Aggregate, and enable group trim hint is present.
* E.g.
* SELECT /*+ aggOptions(is_enable_group_trim='true') * /
* col1, COUNT(*) AS cnt FROM myTable GROUP BY col1 ORDER BY cnt DESC LIMIT 10
* It will extract the collations and limit from the Sort node, and set them into the Aggregate node.
* 3. {@link WithoutSort}:
* Matches Aggregate node if there is no match of {@link SortProjectAggregate} or {@link SortAggregate}.
*
* TODO:
* 1. Always enable group trim when the result is guaranteed to be accurate
* 2. Add intermediate stage group trim
* 3. Allow tuning group trim parameters with query hint
*/
public class PinotAggregateExchangeNodeInsertRule {

Expand All @@ -98,26 +120,21 @@ private SortProjectAggregate(RelBuilderFactory factory) {

@Override
public void onMatch(RelOptRuleCall call) {
// Apply this rule for group-by queries with enable group trim hint.
LogicalAggregate aggRel = call.rel(2);
if (aggRel.getGroupSet().isEmpty()) {
return;
}
Map<String, String> hintOptions =
PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS);
if (hintOptions == null || !Boolean.parseBoolean(
hintOptions.get(PinotHintOptions.AggregateOptions.ENABLE_GROUP_TRIM))) {
hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) {
return;
}

Sort sortRel = call.rel(0);
Project projectRel = call.rel(1);
List<RexNode> projects = projectRel.getProjects();
List<RelFieldCollation> collations = sortRel.getCollation().getFieldCollations();
if (collations.isEmpty()) {
// Cannot enable group trim without sort key.
return;
}
List<RelFieldCollation> newCollations = new ArrayList<>(collations.size());
for (RelFieldCollation fieldCollation : collations) {
RexNode project = projects.get(fieldCollation.getFieldIndex());
Expand All @@ -133,7 +150,7 @@ public void onMatch(RelOptRuleCall call) {
limit = RexLiteral.intValue(sortRel.fetch);
}
if (limit <= 0) {
// Cannot enable group trim without limit.
// Cannot enable group trim when there is no limit.
return;
}

Expand All @@ -154,30 +171,25 @@ private SortAggregate(RelBuilderFactory factory) {

@Override
public void onMatch(RelOptRuleCall call) {
// Apply this rule for group-by queries with enable group trim hint.
LogicalAggregate aggRel = call.rel(1);
if (aggRel.getGroupSet().isEmpty()) {
return;
}
Map<String, String> hintOptions =
PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS);
if (hintOptions == null || !Boolean.parseBoolean(
hintOptions.get(PinotHintOptions.AggregateOptions.ENABLE_GROUP_TRIM))) {
hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) {
return;
}

Sort sortRel = call.rel(0);
List<RelFieldCollation> collations = sortRel.getCollation().getFieldCollations();
if (collations.isEmpty()) {
// Cannot enable group trim without sort key.
return;
}
int limit = 0;
if (sortRel.fetch != null) {
limit = RexLiteral.intValue(sortRel.fetch);
}
if (limit <= 0) {
// Cannot enable group trim without limit.
// Cannot enable group trim when there is no limit.
return;
}

Expand Down Expand Up @@ -211,7 +223,7 @@ private static PinotLogicalAggregate createPlan(RelOptRuleCall call, Aggregate a
// WITHIN GROUP collation is not supported in leaf stage aggregation.
RelCollation withinGroupCollation = extractWithinGroupCollation(aggRel);
if (withinGroupCollation != null || (hasGroupBy && Boolean.parseBoolean(
hintOptions.get(PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)))) {
hintOptions.get(PinotHintOptions.AggregateOptions.IS_SKIP_LEAF_STAGE_GROUP_BY)))) {
return createPlanWithExchangeDirectAggregation(call, aggRel, withinGroupCollation, collations, limit);
} else if (hasGroupBy && Boolean.parseBoolean(
hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public class AggregateNode extends BasePlanNode {
private final AggType _aggType;
private final boolean _leafReturnFinalResult;

// The following fields are for group trimming purpose, and are extracted from the Sort on top of this Aggregate.
// The following fields are set when group trim is enabled, and are extracted from the Sort on top of this Aggregate.
// The group trim behavior at leaf stage is shared with single-stage engine.
private final List<RelFieldCollation> _collations;
private final int _limit;

Expand Down
18 changes: 17 additions & 1 deletion pinot-query-planner/src/test/resources/queries/GroupByPlans.json
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
},
{
"description": "SQL hint based group by optimization with partitioned aggregated values and group trim enabled",
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_leaf_return_final_result='true', is_enable_group_trim='true') */ col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10",
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_leaf_return_final_result='true', is_enable_group_trim='true') */ col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10",
"output": [
"Execution Plan",
"\nLogicalSort(sort0=[$1], dir0=[DESC], offset=[0], fetch=[10])",
Expand Down Expand Up @@ -282,6 +282,22 @@
"\n LogicalTableScan(table=[[default, a]])",
"\n"
]
},
{
"description": "SQL hint based distinct optimization with group trim enabled",
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_enable_group_trim='true') */ DISTINCT col1, col2 FROM a WHERE col3 >= 0 LIMIT 10",
"output": [
"Execution Plan",
"\nLogicalSort(offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
"\n LogicalSort(fetch=[10])",
"\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL], collations=[[]], limit=[10])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n PinotLogicalAggregate(group=[{0, 1}], aggType=[LEAF], collations=[[]], limit=[10])",
"\n LogicalFilter(condition=[>=($2, 0)])",
"\n LogicalTableScan(table=[[default, a]])",
"\n"
]
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ public Void visitAggregate(AggregateNode node, ServerPlanRequestContext context)
pinotQuery.putToQueryOptions(
CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true");
}
List<RelFieldCollation> collations = node.getCollations();
int limit = node.getLimit();
if (!collations.isEmpty() && limit > 0) {
pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery));
if (limit > 0) {
List<RelFieldCollation> collations = node.getCollations();
if (!collations.isEmpty()) {
pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery));
}
pinotQuery.setLimit(limit);
}
// There cannot be any more modification of PinotQuery post agg, thus this is the last one possible.
Expand Down
Loading

0 comments on commit a1e3675

Please sign in to comment.