diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizer.java index c434897379cb..a220f053d870 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizer.java @@ -24,7 +24,6 @@ import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.ExpressionType; import org.apache.pinot.common.request.Function; -import org.apache.pinot.common.request.Literal; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.data.Schema; @@ -84,16 +83,25 @@ boolean canBeOptimized(Expression filterExpression, @Nullable Schema schema) { Expression optimizeChild(Expression filterExpression, @Nullable Schema schema) { Function function = filterExpression.getFunctionCall(); FilterKind kind = FilterKind.valueOf(function.getOperator()); + + if (!kind.isRange() && kind != FilterKind.EQUALS && kind != FilterKind.NOT_EQUALS) { + return filterExpression; + } + + List operands = function.getOperands(); + // Verify that LHS is a numeric column and RHS is a literal before rewriting. + Expression lhs = operands.get(0); + Expression rhs = operands.get(1); + + DataType dataType = getDataType(lhs, schema); + if (dataType == null || !dataType.isNumeric() || !rhs.isSetLiteral()) { + // No rewrite here + return filterExpression; + } + switch (kind) { case BETWEEN: { - // Verify that value is a numeric column before rewriting. - List operands = function.getOperands(); - Expression value = operands.get(0); - DataType dataType = getDataType(value, schema); - if (dataType != null && dataType.isNumeric()) { - return rewriteBetweenExpression(filterExpression, dataType); - } - break; + return rewriteBetweenExpression(filterExpression, dataType); } case EQUALS: case NOT_EQUALS: @@ -101,22 +109,11 @@ Expression optimizeChild(Expression filterExpression, @Nullable Schema schema) { case GREATER_THAN_OR_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: { - List operands = function.getOperands(); - // Verify that LHS is a numeric column and RHS is a numeric literal before rewriting. - Expression lhs = operands.get(0); - Expression rhs = operands.get(1); - - if (isNumericLiteral(rhs)) { - DataType dataType = getDataType(lhs, schema); - if (dataType != null && dataType.isNumeric()) { - if (kind.isRange()) { - return rewriteRangeExpression(filterExpression, kind, dataType, rhs); - } else { - return rewriteEqualsExpression(filterExpression, kind, dataType, rhs); - } - } + if (kind.isRange()) { + return rewriteRangeExpression(filterExpression, kind, dataType, rhs); + } else { + return rewriteEqualsExpression(filterExpression, kind, dataType, rhs); } - break; } default: break; @@ -373,6 +370,8 @@ private static Expression rewriteBetweenExpression(Expression between, DataType Expression lower = operands.get(1); Expression upper = operands.get(2); + // The BETWEEN filter predicate currently only supports literals as lower and upper bounds, but we're still checking + // here just in case. if (lower.isSetLiteral()) { switch (lower.getLiteral().getSetField()) { case LONG_VALUE: { @@ -524,9 +523,9 @@ private static void rewriteRangeOperator(Expression range, FilterKind kind, int if (comparison > 0) { // Literal value is greater than the converted value, so rewrite: // "column > literal" to "column > converted" - // "column >= literal" to "column >= converted" + // "column >= literal" to "column > converted" // "column < literal" to "column <= converted" - // "column <= literal" to "column < converted" + // "column <= literal" to "column <= converted" if (kind == FilterKind.GREATER_THAN || kind == FilterKind.GREATER_THAN_OR_EQUAL) { range.getFunctionCall().setOperator(FilterKind.GREATER_THAN.name()); } else if (kind == FilterKind.LESS_THAN || kind == FilterKind.LESS_THAN_OR_EQUAL) { @@ -581,21 +580,4 @@ private static DataType getDataType(Expression expression, Schema schema) { } return null; } - - /** @return true if expression is a numeric literal; otherwise, false. */ - private static boolean isNumericLiteral(Expression expression) { - if (expression.getType() == ExpressionType.LITERAL) { - Literal._Fields type = expression.getLiteral().getSetField(); - switch (type) { - case INT_VALUE: - case LONG_VALUE: - case FLOAT_VALUE: - case DOUBLE_VALUE: - return true; - default: - break; - } - } - return false; - } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizerTest.java index b8f14f041e44..363ef8887db2 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizerTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/optimizer/filter/NumericalFilterOptimizerTest.java @@ -316,6 +316,17 @@ public void testBetweenRewrites() { // Test INT column with DOUBLE upper bound lesser than Integer.MIN_VALUE. Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE intColumn BETWEEN -4000000000.0 AND -3000000000.0"), "Expression(type:LITERAL, literal:)"); + // Test INT column with LONG lower bound lesser than Integer.MIN_VALUE. + Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE intColumn BETWEEN -4000000000 AND 0"), + "Expression(type:FUNCTION, functionCall:Function(operator:BETWEEN, operands:[Expression(type:IDENTIFIER, " + + "identifier:Identifier(name:intColumn)), Expression(type:LITERAL, literal:), Expression(type:LITERAL, literal:)]))"); + // Test INT column with LONG upper bound greater than Integer.MAX_VALUE. + Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE intColumn BETWEEN 0 AND 4000000000"), + "Expression(type:FUNCTION, functionCall:Function(operator:BETWEEN, operands:[Expression(type:IDENTIFIER, " + + "identifier:Identifier(name:intColumn)), Expression(type:LITERAL, literal:), Expression(type:LITERAL, literal:)]))"); + // Test LONG column with DOUBLE lower bound greater than Long.MAX_VALUE. Assert.assertEquals( rewrite("SELECT * FROM testTable WHERE longColumn BETWEEN 9323372036854775808.0 AND 9323372036854775809.0"), @@ -324,6 +335,16 @@ public void testBetweenRewrites() { Assert.assertEquals( rewrite("SELECT * FROM testTable WHERE longColumn BETWEEN -9323372036854775809.0 AND -9323372036854775808.0"), "Expression(type:LITERAL, literal:)"); + // Test LONG column with DOUBLE lower bound lesser than Long.MIN_VALUE. + Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE longColumn BETWEEN -9323372036854775809.0 AND 0"), + "Expression(type:FUNCTION, functionCall:Function(operator:BETWEEN, operands:[Expression(type:IDENTIFIER, " + + "identifier:Identifier(name:longColumn)), Expression(type:LITERAL, literal:), Expression(type:LITERAL, literal:)]))"); + // Test LONG column with DOUBLE upper bound greater than Long.MAX_VALUE. + Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE longColumn BETWEEN 0 AND 9323372036854775808.0"), + "Expression(type:FUNCTION, functionCall:Function(operator:BETWEEN, operands:[Expression(type:IDENTIFIER, " + + "identifier:Identifier(name:longColumn)), Expression(type:LITERAL, literal:), " + + "Expression(type:LITERAL, literal:)]))"); // Test DOUBLE literal rewrite for INT and LONG columns. Assert.assertEquals(rewrite("SELECT * FROM testTable WHERE intColumn BETWEEN 2.5 AND 7.5"), diff --git a/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/QueryGenerator.java b/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/QueryGenerator.java index baf02481ada7..37cbdc8ad9cb 100644 --- a/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/QueryGenerator.java +++ b/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/QueryGenerator.java @@ -1005,16 +1005,6 @@ public QueryFragment generatePredicate(String columnName, boolean useMultistageE List columnValues = _columnToValueList.get(columnName); String leftValue = pickRandom(columnValues); String rightValue = pickRandom(columnValues); - - if (_singleValueNumericalColumnNames.contains(columnName)) { - // For numerical columns, make sure leftValue < rightValue. - if (Double.parseDouble(leftValue) > Double.parseDouble(rightValue)) { - String temp = leftValue; - leftValue = rightValue; - rightValue = temp; - } - } - return new StringQueryFragment( String.format("%s BETWEEN %s AND %s", columnName, leftValue, rightValue), String.format("`%s` BETWEEN %s AND %s", columnName, leftValue, rightValue));