Skip to content

Commit

Permalink
NTILE window function (#14850)
Browse files Browse the repository at this point in the history
  • Loading branch information
yashmayya authored Jan 22, 2025
1 parent 7cea247 commit 17f94a9
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
// The supported OTHER_FUNCTION functions are: BOOL_AND, BOOL_OR
private static final EnumSet<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND =
EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK,
SqlKind.DENSE_RANK, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
SqlKind.DENSE_RANK, SqlKind.NTILE, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
SqlKind.OTHER_FUNCTION);

public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.fun.SqlLeadLagAggFunction;
import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator;
import org.apache.calcite.sql.fun.SqlNtileAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
Expand Down Expand Up @@ -175,6 +176,10 @@ public static PinotOperatorTable instance() {
SqlStdOperatorTable.DENSE_RANK,
SqlStdOperatorTable.RANK,
SqlStdOperatorTable.ROW_NUMBER,
// The Calcite standard NTILE operator doesn't override the allowsFraming method, so we need to define our own.
// Other SQL databases don't allow custom window frames for the NTILE operator either, so this is probably a
// mistake in Calcite.
PinotNtileWindowFunction.INSTANCE,

// WINDOW Functions (non-aggregate)
SqlStdOperatorTable.LAST_VALUE,
Expand Down Expand Up @@ -434,4 +439,13 @@ public boolean allowsNullTreatment() {
return false;
}
}

/**
 * Pinot-specific NTILE operator. It inherits everything from Calcite's {@link SqlNtileAggFunction} and only
 * overrides {@link #allowsFraming()} so that the operator reports that it does not allow framing.
 */
private static final class PinotNtileWindowFunction extends SqlNtileAggFunction {
  /** Shared instance of this operator. */
  static final SqlOperator INSTANCE = new PinotNtileWindowFunction();

  private PinotNtileWindowFunction() {
    // Only the shared INSTANCE above is created.
  }

  /**
   * @return always {@code false}
   */
  @Override
  public boolean allowsFraming() {
    return false;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ public void testDuplicateWithAlias() {
}

@Test
public void testWindowFunctionsWithCustomWindowFrame() {
public void testWindowFunctions() {
String queryWithDefaultWindow = "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2) FROM a";
_queryEnvironment.planQuery(queryWithDefaultWindow);

Expand Down Expand Up @@ -441,7 +441,7 @@ public void testWindowFunctionsWithCustomWindowFrame() {
assertTrue(e.getCause().getCause().getMessage()
.contains("RANGE window frame with offset PRECEDING / FOLLOWING is not supported"));

// RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD with custom window frame are invalid
// RANK, DENSE_RANK, ROW_NUMBER, NTILE, LAG, LEAD with custom window frame are invalid
String rankQuery =
"SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) "
+ "FROM a";
Expand All @@ -460,6 +460,12 @@ public void testWindowFunctionsWithCustomWindowFrame() {
e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(rowNumberQuery));
assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));

String ntileQuery =
"SELECT col1, col2, NTILE(10) OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND "
+ "CURRENT ROW) FROM a";
e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQuery));
assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));

String lagQuery =
"SELECT col1, col2, LAG(col2, 1) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND "
+ "UNBOUNDED FOLLOWING) FROM a";
Expand All @@ -471,6 +477,12 @@ public void testWindowFunctionsWithCustomWindowFrame() {
+ "UNBOUNDED FOLLOWING) FROM a";
e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(leadQuery));
assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));

String ntileQueryWithNoArg =
"SELECT col1, col2, NTILE() OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND "
+ "CURRENT ROW) FROM a";
e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQueryWithNoArg));
assertTrue(e.getCause().getMessage().contains("expecting 1 argument"));
}

// --------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.query.runtime.operator.window.range;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.operator.window.WindowFrame;


/**
 * Window function implementing {@code NTILE(n)}: the rows handed to {@link #processRows(List)} are split, in their
 * given order, into {@code n} buckets numbered from 1 whose sizes differ by at most one row. When the number of rows
 * is not divisible by {@code n}, the first {@code numRows % n} buckets each receive one extra row. When there are
 * fewer rows than buckets, every row gets its own bucket and the remaining bucket numbers are unused.
 */
public class NtileWindowFunction extends RankBasedWindowFunction {

  private final int _numBuckets;

  /**
   * Creates the function from the NTILE call, reading the number of buckets from its single literal operand.
   *
   * @param aggCall the NTILE function call; must have exactly one numeric literal operand with a positive value
   * @param inputSchema passed through to the parent constructor
   * @param collations passed through to the parent constructor
   * @param windowFrame passed through to the parent constructor
   * @throws IllegalArgumentException if the operand is missing, is not a literal, is not a number, or is not positive
   */
  public NtileWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
      List<RelFieldCollation> collations, WindowFrame windowFrame) {
    super(aggCall, inputSchema, collations, windowFrame);

    List<RexExpression> operands = aggCall.getFunctionOperands();
    Preconditions.checkArgument(operands.size() == 1, "NTILE function must have exactly 1 operand");
    RexExpression firstOperand = operands.get(0);
    Preconditions.checkArgument(firstOperand instanceof RexExpression.Literal,
        "The operand for the NTILE function must be a literal");
    Object operandValue = ((RexExpression.Literal) firstOperand).getValue();
    Preconditions.checkArgument(operandValue instanceof Number, "The operand for the NTILE function must be a number");
    long numBuckets = ((Number) operandValue).longValue();
    // A non-positive bucket count would otherwise fail later with a division by zero (or produce meaningless bucket
    // numbers), so reject it up front with a clear message.
    Preconditions.checkArgument(numBuckets > 0, "The operand for the NTILE function must be a positive number, got: %s",
        operandValue);
    // A partition can never hold more than Integer.MAX_VALUE rows, so any larger bucket count behaves exactly like
    // Integer.MAX_VALUE (every row gets its own bucket). Clamping avoids a silent wrap-around from int narrowing.
    _numBuckets = (int) Math.min(numBuckets, Integer.MAX_VALUE);
  }

  /**
   * Computes the 1-based bucket number for each row.
   *
   * @param rows the rows to distribute into buckets; only the size of the list is used
   * @return one bucket number (as an {@code Integer}) per input row, in input order
   */
  @Override
  public List<Object> processRows(List<Object[]> rows) {
    int numRows = rows.size();
    List<Object> result = new ArrayList<>(numRows);

    int smallBucketSize = numRows / _numBuckets;
    // The first (numRows % _numBuckets) buckets hold one extra row each; this is 0 when the rows divide evenly.
    int numLargeBuckets = numRows % _numBuckets;
    int largeBucketSize = smallBucketSize + 1;
    // Never exceeds numRows, so this multiplication cannot overflow.
    int numRowsInLargeBuckets = numLargeBuckets * largeBucketSize;

    for (int i = 0; i < numRowsInLargeBuckets; i++) {
      result.add(i / largeBucketSize + 1);
    }
    // smallBucketSize is 0 only when numRows < _numBuckets; in that case every row is covered by the loop above and
    // this loop does not run, so there is no division by zero.
    for (int i = numRowsInLargeBuckets; i < numRows; i++) {
      result.add(numLargeBuckets + (i - numRowsInLargeBuckets) / smallBucketSize + 1);
    }

    return result;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public abstract class RankBasedWindowFunction extends WindowFunction {
.put("ROW_NUMBER", RowNumberWindowFunction.class)
.put("RANK", RankWindowFunction.class)
.put("DENSE_RANK", DenseRankWindowFunction.class)
.put("NTILE", NtileWindowFunction.class)
.build();
//@formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2756,6 +2756,68 @@ public void testLastValueIgnoreNullsWithOffsetFollowingLowerAndOffsetFollowingUp
assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
}

@Test
public void testNtile() {
  // Given: NTILE(3) over three partitions holding 11 ("A"), 6 ("B") and 2 ("C") rows.
  RexExpression.FunctionCall ntileCall =
      new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.NTILE.name(),
          List.of(new RexExpression.Literal(INT, 3)), false, false);
  Object[][] inputRows = {
      new Object[]{"A", 1},
      new Object[]{"A", 2},
      new Object[]{"A", 3},
      new Object[]{"A", 4},
      new Object[]{"A", 5},
      new Object[]{"A", 6},
      new Object[]{"A", 7},
      new Object[]{"A", 8},
      new Object[]{"A", 9},
      new Object[]{"A", 10},
      new Object[]{"A", 11},
      new Object[]{"B", 1},
      new Object[]{"B", 2},
      new Object[]{"B", 3},
      new Object[]{"B", 4},
      new Object[]{"B", 5},
      new Object[]{"B", 6},
      new Object[]{"C", 1},
      new Object[]{"C", 2}
  };
  WindowAggregateOperator ntileOperator = prepareDataForWindowFunction(new String[]{"name", "value"},
      new ColumnDataType[]{STRING, INT}, INT, List.of(0), 1, ROWS, 0, 0, ntileCall, inputRows);

  // When:
  List<Object[]> actualRows = ntileOperator.nextBlock().getContainer();

  // Then: "A" is split 4/4/3, "B" is split 2/2/2, and "C" has fewer rows than buckets so only buckets 1 and 2 appear.
  verifyResultRows(actualRows, List.of(0), Map.of(
      "A", List.of(
          new Object[]{"A", 1, 1},
          new Object[]{"A", 2, 1},
          new Object[]{"A", 3, 1},
          new Object[]{"A", 4, 1},
          new Object[]{"A", 5, 2},
          new Object[]{"A", 6, 2},
          new Object[]{"A", 7, 2},
          new Object[]{"A", 8, 2},
          new Object[]{"A", 9, 3},
          new Object[]{"A", 10, 3},
          new Object[]{"A", 11, 3}
      ),
      "B", List.of(
          new Object[]{"B", 1, 1},
          new Object[]{"B", 2, 1},
          new Object[]{"B", 3, 2},
          new Object[]{"B", 4, 2},
          new Object[]{"B", 5, 3},
          new Object[]{"B", 6, 3}
      ),
      "C", List.of(
          new Object[]{"C", 1, 1},
          new Object[]{"C", 2, 2}
      )));
  assertTrue(ntileOperator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
}

private WindowAggregateOperator prepareDataForWindowFunction(String[] inputSchemaCols,
ColumnDataType[] inputSchemaColTypes, ColumnDataType outputType, List<Integer> partitionKeys,
int collationFieldIndex, WindowNode.WindowFrameType frameType, int windowFrameLowerBound,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5460,6 +5460,28 @@
["g", 100.0, 10, null],
["h", -1.53, null, null]
]
},
{
"description": "NTILE with 2 buckets",
"sql": "SELECT string_col, int_col, NTILE(2) OVER(PARTITION BY string_col ORDER BY int_col) FROM {tbl} ORDER BY string_col, int_col",
"outputs": [
["a", 2, 1],
["a", 2, 1],
["a", 42, 1],
["a", 42, 2],
["a", 42, 2],
["b", 3, 1],
["b", 100, 2],
["c", -101, 1],
["c", 2, 1],
["c", 3, 2],
["c", 150, 2],
["d", 42, 1],
["e", 42, 1],
["e", 42, 2],
["g", 3, 1],
["h", 150, 1]
]
}
]
}
Expand Down

0 comments on commit 17f94a9

Please sign in to comment.