diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java index 78c8da3d8e7b..67857745900f 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java @@ -70,7 +70,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule { // OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR private static final EnumSet SUPPORTED_WINDOW_FUNCTION_KIND = EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK, - SqlKind.DENSE_RANK, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE, + SqlKind.DENSE_RANK, SqlKind.NTILE, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE, SqlKind.OTHER_FUNCTION); public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java index fc861d5d2e7e..aadeeb6127cf 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java @@ -35,6 +35,7 @@ import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.fun.SqlLeadLagAggFunction; import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator; +import org.apache.calcite.sql.fun.SqlNtileAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.OperandTypes; @@ -175,6 +176,10 @@ public static PinotOperatorTable instance() { SqlStdOperatorTable.DENSE_RANK, SqlStdOperatorTable.RANK, SqlStdOperatorTable.ROW_NUMBER, + // The Calcite standard NTILE operator doesn't override the allowsFraming method, so we need to define our own. + // The NTILE operator doesn't allow custom window frames in other SQL databases as well, so this is probably a + // mistake in Calcite. + PinotNtileWindowFunction.INSTANCE, // WINDOW Functions (non-aggregate) SqlStdOperatorTable.LAST_VALUE, @@ -434,4 +439,13 @@ public boolean allowsNullTreatment() { return false; } } + + private static final class PinotNtileWindowFunction extends SqlNtileAggFunction { + static final SqlOperator INSTANCE = new PinotNtileWindowFunction(); + + @Override + public boolean allowsFraming() { + return false; + } + } } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java index 0be64ad20f12..e490fb6b12db 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java @@ -395,7 +395,7 @@ public void testDuplicateWithAlias() { } @Test - public void testWindowFunctionsWithCustomWindowFrame() { + public void testWindowFunctions() { String queryWithDefaultWindow = "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2) FROM a"; _queryEnvironment.planQuery(queryWithDefaultWindow); @@ -441,7 +441,7 @@ public void testWindowFunctionsWithCustomWindowFrame() { assertTrue(e.getCause().getCause().getMessage() .contains("RANGE window frame with offset PRECEDING / FOLLOWING is not supported")); - // RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD with custom window frame are invalid + // RANK, DENSE_RANK, ROW_NUMBER, NTILE, LAG, LEAD with custom window frame are invalid String rankQuery = "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + "FROM a"; @@ -460,6 +460,12 @@ public void testWindowFunctionsWithCustomWindowFrame() { e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(rowNumberQuery)); assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + String ntileQuery = + "SELECT col1, col2, NTILE(10) OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND " + + "CURRENT ROW) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + String lagQuery = "SELECT col1, col2, LAG(col2, 1) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND " + "UNBOUNDED FOLLOWING) FROM a"; @@ -471,6 +477,12 @@ public void testWindowFunctionsWithCustomWindowFrame() { + "UNBOUNDED FOLLOWING) FROM a"; e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(leadQuery)); assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + + String ntileQueryWithNoArg = + "SELECT col1, col2, NTILE() OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND " + + "CURRENT ROW) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQueryWithNoArg)); + assertTrue(e.getCause().getMessage().contains("expecting 1 argument")); } // -------------------------------------------------------------------------- diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java new file mode 100644 index 000000000000..e411e04abdb1 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.range; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.WindowFrame; + + +public class NtileWindowFunction extends RankBasedWindowFunction { + + private final int _numBuckets; + + public NtileWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); + + List operands = aggCall.getFunctionOperands(); + Preconditions.checkArgument(operands.size() == 1, "NTILE function must have exactly 1 operand"); + RexExpression firstOperand = operands.get(0); + Preconditions.checkArgument(firstOperand instanceof RexExpression.Literal, + "The operand for the NTILE function must be a literal"); + Object operandValue = ((RexExpression.Literal) firstOperand).getValue(); + Preconditions.checkArgument(operandValue instanceof Number, "The operand for the NTILE function must be a number"); + _numBuckets = ((Number) operandValue).intValue(); + } + + @Override + public List processRows(List rows) { + int numRows = rows.size(); + List result = new ArrayList<>(numRows); + int bucketSize = numRows / _numBuckets; + + if (numRows % _numBuckets == 0) { + for (int i = 0; i < numRows; i++) { + result.add(i / bucketSize + 1); + } + } else { + int numLargeBuckets = numRows % _numBuckets; + int largeBucketSize = bucketSize + 1; + for (int i = 0; i < numRows; i++) { + if (i < numLargeBuckets * largeBucketSize) { + result.add(i / largeBucketSize + 1); + } else { + result.add(numLargeBuckets + (i - numLargeBuckets * largeBucketSize) / bucketSize + 1); + } + } + } + + return result; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java index 455328206274..e45221e49325 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java @@ -40,6 +40,7 @@ public abstract class RankBasedWindowFunction extends WindowFunction { .put("ROW_NUMBER", RowNumberWindowFunction.class) .put("RANK", RankWindowFunction.class) .put("DENSE_RANK", DenseRankWindowFunction.class) + .put("NTILE", NtileWindowFunction.class) .build(); //@formatter:on diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index abb38b8d7d55..9e87a1283ff3 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -2756,6 +2756,68 @@ public void testLastValueIgnoreNullsWithOffsetFollowingLowerAndOffsetFollowingUp assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); } + @Test + public void testNtile() { + // Given: + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value"}, + new ColumnDataType[]{STRING, INT}, INT, List.of(0), 1, ROWS, 0, 0, + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.NTILE.name(), + List.of(new RexExpression.Literal(INT, 3)), false, false), + new Object[][]{ + new Object[]{"A", 1}, + new Object[]{"A", 2}, + new Object[]{"A", 3}, + new Object[]{"A", 4}, + new Object[]{"A", 5}, + new Object[]{"A", 6}, + new Object[]{"A", 7}, + new Object[]{"A", 8}, + new Object[]{"A", 9}, + new Object[]{"A", 10}, + new Object[]{"A", 11}, + new Object[]{"B", 1}, + new Object[]{"B", 2}, + new Object[]{"B", 3}, + new Object[]{"B", 4}, + new Object[]{"B", 5}, + new Object[]{"B", 6}, + new Object[]{"C", 1}, + new Object[]{"C", 2} + }); + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 1, 1}, + new Object[]{"A", 2, 1}, + new Object[]{"A", 3, 1}, + new Object[]{"A", 4, 1}, + new Object[]{"A", 5, 2}, + new Object[]{"A", 6, 2}, + new Object[]{"A", 7, 2}, + new Object[]{"A", 8, 2}, + new Object[]{"A", 9, 3}, + new Object[]{"A", 10, 3}, + new Object[]{"A", 11, 3} + ), + "B", List.of( + new Object[]{"B", 1, 1}, + new Object[]{"B", 2, 1}, + new Object[]{"B", 3, 2}, + new Object[]{"B", 4, 2}, + new Object[]{"B", 5, 3}, + new Object[]{"B", 6, 3} + ), + "C", List.of( + new Object[]{"C", 1, 1}, + new Object[]{"C", 2, 2} + ))); + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + private WindowAggregateOperator prepareDataForWindowFunction(String[] inputSchemaCols, ColumnDataType[] inputSchemaColTypes, ColumnDataType outputType, List partitionKeys, int collationFieldIndex, WindowNode.WindowFrameType frameType, int windowFrameLowerBound, diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json index c1cbe9cb0da8..c87aeece6530 100644 --- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json +++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json @@ -5460,6 +5460,28 @@ ["g", 100.0, 10, null], ["h", -1.53, null, null] ] + }, + { + "description": "NTILE with 2 buckets", + "sql": "SELECT string_col, int_col, NTILE(2) OVER(PARTITION BY string_col ORDER BY int_col) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 1], + ["a", 2, 1], + ["a", 42, 1], + ["a", 42, 2], + ["a", 42, 2], + ["b", 3, 1], + ["b", 100, 2], + ["c", -101, 1], + ["c", 2, 1], + ["c", 3, 2], + ["c", 150, 2], + ["d", 42, 1], + ["e", 42, 1], + ["e", 42, 2], + ["g", 3, 1], + ["h", 150, 1] + ] } ] }