From a1d81cec6146dadf94a1a4238e4743300dfae6b2 Mon Sep 17 00:00:00 2001
From: Yash Mayya <yash.mayya@gmail.com>
Date: Tue, 21 Jan 2025 17:22:26 +0530
Subject: [PATCH] Add NTILE window function

---
 .../PinotWindowExchangeNodeInsertRule.java    |  2 +-
 .../calcite/sql/fun/PinotOperatorTable.java   | 14 ++++
 .../pinot/query/QueryCompilationTest.java     | 16 ++++-
 .../window/range/NtileWindowFunction.java     | 72 +++++++++++++++++++
 .../window/range/RankBasedWindowFunction.java |  1 +
 .../operator/WindowAggregateOperatorTest.java | 62 ++++++++++++++++
 .../resources/queries/WindowFunctions.json    | 22 ++++++
 7 files changed, 186 insertions(+), 3 deletions(-)
 create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
index 78c8da3d8e7b..67857745900f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
@@ -70,7 +70,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
   // OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR
   private static final EnumSet<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND =
       EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK,
-          SqlKind.DENSE_RANK, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
+          SqlKind.DENSE_RANK, SqlKind.NTILE, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
           SqlKind.OTHER_FUNCTION);
 
   public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
index fc861d5d2e7e..aadeeb6127cf 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
@@ -35,6 +35,7 @@
 import org.apache.calcite.sql.SqlSyntax;
 import org.apache.calcite.sql.fun.SqlLeadLagAggFunction;
 import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator;
+import org.apache.calcite.sql.fun.SqlNtileAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.InferTypes;
 import org.apache.calcite.sql.type.OperandTypes;
@@ -175,6 +176,10 @@ public static PinotOperatorTable instance() {
       SqlStdOperatorTable.DENSE_RANK,
       SqlStdOperatorTable.RANK,
       SqlStdOperatorTable.ROW_NUMBER,
+      // The Calcite standard NTILE operator doesn't override the allowsFraming method, so we need to define our own.
+      // The NTILE operator doesn't allow custom window frames in other SQL databases as well, so this is probably a
+      // mistake in Calcite.
+      PinotNtileWindowFunction.INSTANCE,
 
       // WINDOW Functions (non-aggregate)
       SqlStdOperatorTable.LAST_VALUE,
@@ -434,4 +439,13 @@ public boolean allowsNullTreatment() {
       return false;
     }
   }
+
+  private static final class PinotNtileWindowFunction extends SqlNtileAggFunction {
+    static final SqlOperator INSTANCE = new PinotNtileWindowFunction();
+
+    @Override
+    public boolean allowsFraming() {
+      return false;
+    }
+  }
 }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 0be64ad20f12..e490fb6b12db 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -395,7 +395,7 @@ public void testDuplicateWithAlias() {
   }
 
   @Test
-  public void testWindowFunctionsWithCustomWindowFrame() {
+  public void testWindowFunctions() {
     String queryWithDefaultWindow = "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2) FROM a";
     _queryEnvironment.planQuery(queryWithDefaultWindow);
 
@@ -441,7 +441,7 @@ public void testWindowFunctionsWithCustomWindowFrame() {
     assertTrue(e.getCause().getCause().getMessage()
         .contains("RANGE window frame with offset PRECEDING / FOLLOWING is not supported"));
 
-    // RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD with custom window frame are invalid
+    // RANK, DENSE_RANK, ROW_NUMBER, NTILE, LAG, LEAD with custom window frame are invalid
     String rankQuery =
         "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) "
             + "FROM a";
@@ -460,6 +460,12 @@ public void testWindowFunctionsWithCustomWindowFrame() {
     e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(rowNumberQuery));
     assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));
 
+    String ntileQuery =
+        "SELECT col1, col2, NTILE(10) OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND "
+            + "CURRENT ROW) FROM a";
+    e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQuery));
+    assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));
+
     String lagQuery =
         "SELECT col1, col2, LAG(col2, 1) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND "
             + "UNBOUNDED FOLLOWING) FROM a";
@@ -471,6 +477,12 @@ public void testWindowFunctionsWithCustomWindowFrame() {
             + "UNBOUNDED FOLLOWING) FROM a";
     e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(leadQuery));
     assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed"));
+
+    String ntileQueryWithNoArg =
+        "SELECT col1, col2, NTILE() OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND "
+            + "CURRENT ROW) FROM a";
+    e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQueryWithNoArg));
+    assertTrue(e.getCause().getMessage().contains("expecting 1 argument"));
   }
 
   // --------------------------------------------------------------------------
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java
new file mode 100644
index 000000000000..e411e04abdb1
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/NtileWindowFunction.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window.range;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.calcite.rel.RelFieldCollation;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.window.WindowFrame;
+
+
+public class NtileWindowFunction extends RankBasedWindowFunction {
+
+  private final int _numBuckets;
+
+  public NtileWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
+      List<RelFieldCollation> collations, WindowFrame windowFrame) {
+    super(aggCall, inputSchema, collations, windowFrame);
+
+    List<RexExpression> operands = aggCall.getFunctionOperands();
+    Preconditions.checkArgument(operands.size() == 1, "NTILE function must have exactly 1 operand");
+    RexExpression firstOperand = operands.get(0);
+    Preconditions.checkArgument(firstOperand instanceof RexExpression.Literal,
+        "The operand for the NTILE function must be a literal");
+    Object operandValue = ((RexExpression.Literal) firstOperand).getValue();
+    Preconditions.checkArgument(operandValue instanceof Number, "The operand for the NTILE function must be a number");
+    _numBuckets = ((Number) operandValue).intValue();
+  }
+
+  @Override
+  public List<Object> processRows(List<Object[]> rows) {
+    int numRows = rows.size();
+    List<Object> result = new ArrayList<>(numRows);
+    int bucketSize = numRows / _numBuckets;
+
+    if (numRows % _numBuckets == 0) {
+      for (int i = 0; i < numRows; i++) {
+        result.add(i / bucketSize + 1);
+      }
+    } else {
+      int numLargeBuckets = numRows % _numBuckets;
+      int largeBucketSize = bucketSize + 1;
+      for (int i = 0; i < numRows; i++) {
+        if (i < numLargeBuckets * largeBucketSize) {
+          result.add(i / largeBucketSize + 1);
+        } else {
+          result.add(numLargeBuckets + (i - numLargeBuckets * largeBucketSize) / bucketSize + 1);
+        }
+      }
+    }
+
+    return result;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java
index 455328206274..e45221e49325 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java
@@ -40,6 +40,7 @@ public abstract class RankBasedWindowFunction extends WindowFunction {
           .put("ROW_NUMBER", RowNumberWindowFunction.class)
           .put("RANK", RankWindowFunction.class)
           .put("DENSE_RANK", DenseRankWindowFunction.class)
+          .put("NTILE", NtileWindowFunction.class)
           .build();
   //@formatter:on
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index abb38b8d7d55..9e87a1283ff3 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -2756,6 +2756,68 @@ public void testLastValueIgnoreNullsWithOffsetFollowingLowerAndOffsetFollowingUp
     assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
   }
 
+  @Test
+  public void testNtile() {
+    // Given:
+    WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value"},
+        new ColumnDataType[]{STRING, INT}, INT, List.of(0), 1, ROWS, 0, 0,
+        new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.NTILE.name(),
+            List.of(new RexExpression.Literal(INT, 3)), false, false),
+        new Object[][]{
+            new Object[]{"A", 1},
+            new Object[]{"A", 2},
+            new Object[]{"A", 3},
+            new Object[]{"A", 4},
+            new Object[]{"A", 5},
+            new Object[]{"A", 6},
+            new Object[]{"A", 7},
+            new Object[]{"A", 8},
+            new Object[]{"A", 9},
+            new Object[]{"A", 10},
+            new Object[]{"A", 11},
+            new Object[]{"B", 1},
+            new Object[]{"B", 2},
+            new Object[]{"B", 3},
+            new Object[]{"B", 4},
+            new Object[]{"B", 5},
+            new Object[]{"B", 6},
+            new Object[]{"C", 1},
+            new Object[]{"C", 2}
+        });
+
+    // When:
+    List<Object[]> resultRows = operator.nextBlock().getContainer();
+
+    // Then:
+    verifyResultRows(resultRows, List.of(0), Map.of(
+        "A", List.of(
+            new Object[]{"A", 1, 1},
+            new Object[]{"A", 2, 1},
+            new Object[]{"A", 3, 1},
+            new Object[]{"A", 4, 1},
+            new Object[]{"A", 5, 2},
+            new Object[]{"A", 6, 2},
+            new Object[]{"A", 7, 2},
+            new Object[]{"A", 8, 2},
+            new Object[]{"A", 9, 3},
+            new Object[]{"A", 10, 3},
+            new Object[]{"A", 11, 3}
+        ),
+        "B", List.of(
+            new Object[]{"B", 1, 1},
+            new Object[]{"B", 2, 1},
+            new Object[]{"B", 3, 2},
+            new Object[]{"B", 4, 2},
+            new Object[]{"B", 5, 3},
+            new Object[]{"B", 6, 3}
+        ),
+        "C", List.of(
+            new Object[]{"C", 1, 1},
+            new Object[]{"C", 2, 2}
+        )));
+    assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
+  }
+
   private WindowAggregateOperator prepareDataForWindowFunction(String[] inputSchemaCols,
       ColumnDataType[] inputSchemaColTypes, ColumnDataType outputType, List<Integer> partitionKeys,
       int collationFieldIndex, WindowNode.WindowFrameType frameType, int windowFrameLowerBound,
diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
index c1cbe9cb0da8..c87aeece6530 100644
--- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
+++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
@@ -5460,6 +5460,28 @@
           ["g", 100.0, 10, null],
           ["h", -1.53, null, null]
         ]
+      },
+      {
+        "description": "NTILE with 2 buckets",
+        "sql": "SELECT string_col, int_col, NTILE(2) OVER(PARTITION BY string_col ORDER BY int_col) FROM {tbl} ORDER BY string_col, int_col",
+        "outputs": [
+          ["a", 2, 1],
+          ["a", 2, 1],
+          ["a", 42, 1],
+          ["a", 42, 2],
+          ["a", 42, 2],
+          ["b", 3, 1],
+          ["b", 100, 2],
+          ["c", -101, 1],
+          ["c", 2, 1],
+          ["c", 3, 2],
+          ["c", 150, 2],
+          ["d", 42, 1],
+          ["e", 42, 1],
+          ["e", 42, 2],
+          ["g", 3, 1],
+          ["h", 150, 1]
+        ]
       }
     ]
   }