Skip to content

Commit

Permalink
[CALCITE-6617] TypeCoercion is not applied correctly to comparisons
Browse files Browse the repository at this point in the history
Additional tests in cast.iq (Julian Hyde)

Close #3998

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu authored and julianhyde committed Oct 10, 2024
1 parent 78e873d commit 243c3ad
Show file tree
Hide file tree
Showing 16 changed files with 362 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select * from arrowdata\n"
+ " where \"floatField\"=15.0";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowFilter(condition=[=(CAST($2):DOUBLE, 15.0E0)])\n"
+ " ArrowFilter(condition=[=($2, 15.0E0)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATA]], fields=[[0, 1, 2, 3]])\n\n";
String result = "intField=15; stringField=15; floatField=15.0; longField=15\n";

Expand Down Expand Up @@ -666,7 +666,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
@Test void testFilteredAgg() {
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP";
String plan = "PLAN=EnumerableAggregate(group=[{}], SALESSUM=[SUM($0) FILTER $1])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400], expr#9=[>($t6, $t8)], "
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], SAL=[$t5], $f1=[$t10])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
Expand All @@ -684,7 +684,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP group by EMPNO";
String plan = "PLAN=EnumerableCalc(expr#0..1=[{inputs}], SALESSUM=[$t1])\n"
+ " EnumerableAggregate(group=[{0}], SALESSUM=[SUM($1) FILTER $2])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400], expr#9=[>($t6, $t8)], "
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], EMPNO=[$t0], SAL=[$t5], $f2=[$t10])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,8 @@ private RelDataType getTightestCommonTypeOrThrow(
}

/**
* Determines common type for a comparison operator when one operand is String type and the
* other is not. For date + timestamp operands, use timestamp as common type,
* Determines common type for a comparison operator.
* For date and timestamp operands, use timestamp as common type,
* i.e. Timestamp(2017-01-01 00:00 ...) &gt; Date(2018) evaluates to be false.
*/
@Override public @Nullable RelDataType commonTypeForBinaryComparison(
Expand All @@ -509,9 +509,11 @@ private RelDataType getTightestCommonTypeOrThrow(
return null;
}

// DATETIME + CHARACTER -> DATETIME
// REVIEW Danny 2019-09-23: There is some legacy redundant code in SqlToRelConverter
// that coerce Datetime and CHARACTER comparison.
if (SqlTypeUtil.sameNamedType(type1, type2)) {
return factory.leastRestrictive(ImmutableList.of(type1, type2));
}

// DATETIME < CHARACTER -> DATETIME
if (SqlTypeUtil.isCharacter(type1) && SqlTypeUtil.isDatetime(type2)) {
return factory.createTypeWithNullability(type2, type1.isNullable());
}
Expand All @@ -520,7 +522,7 @@ private RelDataType getTightestCommonTypeOrThrow(
return factory.createTypeWithNullability(type1, type2.isNullable());
}

// DATE + TIMESTAMP -> TIMESTAMP
// DATE < TIMESTAMP -> TIMESTAMP
if (SqlTypeUtil.isDate(type1) && SqlTypeUtil.isTimestamp(type2)) {
return factory.createTypeWithNullability(type2, type1.isNullable());
}
Expand Down Expand Up @@ -556,6 +558,19 @@ private RelDataType getTightestCommonTypeOrThrow(
return null;
}

if (SqlTypeUtil.isString(type1) && SqlTypeUtil.isString(type2)) {
// Return the string with the larger precision
if (type1.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) {
return factory.createTypeWithNullability(type1, type2.isNullable());
} else if (type2.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) {
return factory.createTypeWithNullability(type2, type1.isNullable());
} else if (type1.getPrecision() > type2.getPrecision()) {
return factory.createTypeWithNullability(type1, type2.isNullable());
} else {
return factory.createTypeWithNullability(type2, type1.isNullable());
}
}

// 1 > '1' will be coerced to 1 > 1.
if (SqlTypeUtil.isAtomic(type1) && SqlTypeUtil.isCharacter(type2)) {
if (SqlTypeUtil.isTimestamp(type1)) {
Expand All @@ -581,6 +596,43 @@ private RelDataType getTightestCommonTypeOrThrow(
}
}

if (SqlTypeUtil.isApproximateNumeric(type1) && SqlTypeUtil.isApproximateNumeric(type2)) {
if (type1.getPrecision() > type2.getPrecision()) {
return factory.createTypeWithNullability(type1, type2.isNullable());
} else {
return factory.createTypeWithNullability(type2, type1.isNullable());
}
}

if (SqlTypeUtil.isApproximateNumeric(type1) && SqlTypeUtil.isExactNumeric(type2)) {
return factory.createTypeWithNullability(type1, type2.isNullable());
}

if (SqlTypeUtil.isApproximateNumeric(type2) && SqlTypeUtil.isExactNumeric(type1)) {
return factory.createTypeWithNullability(type2, type1.isNullable());
}

if (SqlTypeUtil.isExactNumeric(type1) && SqlTypeUtil.isExactNumeric(type2)) {
if (SqlTypeUtil.isDecimal(type1)) {
// Use max precision
RelDataType result =
factory.createSqlType(type1.getSqlTypeName(),
Math.max(type1.getPrecision(), type2.getPrecision()), type1.getScale());
return factory.createTypeWithNullability(result, type1.isNullable() || type2.isNullable());
} else if (SqlTypeUtil.isDecimal(type2)) {
// Use max precision
RelDataType result =
factory.createSqlType(type2.getSqlTypeName(),
Math.max(type1.getPrecision(), type2.getPrecision()), type2.getScale());
return factory.createTypeWithNullability(result, type1.isNullable() || type2.isNullable());
}
if (type1.getPrecision() > type2.getPrecision()) {
return factory.createTypeWithNullability(type1, type2.isNullable());
} else {
return factory.createTypeWithNullability(type2, type1.isNullable());
}
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ public interface TypeCoercion {
@Nullable RelDataType type1, @Nullable RelDataType type2);

/**
* Determines common type for a comparison operator whose operands are STRING
* type and the other (non STRING) type.
* Determines common type for a comparison operator.
*/
@Nullable RelDataType commonTypeForBinaryComparison(
@Nullable RelDataType type1, @Nullable RelDataType type2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import java.math.BigDecimal;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -291,18 +290,9 @@ protected boolean binaryArithmeticWithStrings(
return null;
}

RelDataType commonType;
if (SqlTypeUtil.sameNamedType(type1, type2)) {
commonType = factory.leastRestrictive(Arrays.asList(type1, type2));
} else {
commonType = commonTypeForBinaryComparison(type1, type2);
}
RelDataType commonType = commonTypeForBinaryComparison(type1, type2);
for (int i = 2; i < dataTypes.size() && commonType != null; i++) {
if (SqlTypeUtil.sameNamedType(commonType, dataTypes.get(i))) {
commonType = factory.leastRestrictive(Arrays.asList(commonType, dataTypes.get(i)));
} else {
commonType = commonTypeForBinaryComparison(commonType, dataTypes.get(i));
}
commonType = commonTypeForBinaryComparison(commonType, dataTypes.get(i));
}
return commonType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,33 +319,33 @@ private static String toSql(RelNode root, SqlDialect dialect,
+ "where \"product_id\" > 0\n"
+ "group by \"product_id\"";
final String expectedDefault = "SELECT"
+ " SUM(\"shelf_width\") FILTER (WHERE \"net_weight\" > 0 IS TRUE),"
+ " SUM(\"shelf_width\") FILTER (WHERE \"net_weight\" > 0E0 IS TRUE),"
+ " SUM(\"shelf_width\")\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "WHERE \"product_id\" > 0\n"
+ "GROUP BY \"product_id\"";
final String expectedBigQuery = "SELECT"
+ " SUM(CASE WHEN net_weight > 0 IS TRUE"
+ " SUM(CASE WHEN net_weight > 0E0 IS TRUE"
+ " THEN shelf_width ELSE NULL END), "
+ "SUM(shelf_width)\n"
+ "FROM foodmart.product\n"
+ "WHERE product_id > 0\n"
+ "GROUP BY product_id";
final String expectedFirebolt = "SELECT"
+ " SUM(CASE WHEN \"net_weight\" > 0 IS TRUE"
+ " SUM(CASE WHEN \"net_weight\" > 0E0 IS TRUE"
+ " THEN \"shelf_width\" ELSE NULL END), "
+ "SUM(\"shelf_width\")\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "WHERE \"product_id\" > 0\n"
+ "GROUP BY \"product_id\"";
final String expectedMysql = "SELECT"
+ " SUM(CASE WHEN `net_weight` > 0 IS TRUE"
+ " SUM(CASE WHEN `net_weight` > 0E0 IS TRUE"
+ " THEN `shelf_width` ELSE NULL END), SUM(`shelf_width`)\n"
+ "FROM `foodmart`.`product`\n"
+ "WHERE `product_id` > 0\n"
+ "GROUP BY `product_id`";
final String expectedStarRocks = "SELECT"
+ " SUM(CASE WHEN `net_weight` > 0 IS TRUE"
+ " SUM(CASE WHEN `net_weight` > 0E0 IS TRUE"
+ " THEN `shelf_width` ELSE NULL END), SUM(`shelf_width`)\n"
+ "FROM `foodmart`.`product`\n"
+ "WHERE `product_id` > 0\n"
Expand Down Expand Up @@ -539,7 +539,7 @@ private static String toSql(RelNode root, SqlDialect dialect,
final String expected = "SELECT *\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "WHERE (\"product_id\" = 10 OR \"product_id\" <= 5) "
+ "AND (80 >= \"shelf_width\" OR \"shelf_width\" > 30)";
+ "AND (CAST(80 AS DOUBLE) >= \"shelf_width\" OR \"shelf_width\" > CAST(30 AS DOUBLE))";
sql(query).ok(expected);
}

Expand Down Expand Up @@ -2097,34 +2097,34 @@ private void checkHavingAliasSameAsColumn(boolean upperAlias) {
+ " sum(\"gross_weight\") as \"" + alias + "\"\n"
+ "from \"product\"\n"
+ "group by \"product_id\"\n"
+ "having sum(\"product\".\"gross_weight\") < 200";
+ "having sum(\"product\".\"gross_weight\") < 2.000E2";
// PostgreSQL has isHavingAlias=false, case-sensitive=true
final String expectedPostgresql = "SELECT \"product_id\" + 1,"
+ " SUM(\"gross_weight\") AS \"" + alias + "\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "GROUP BY \"product_id\"\n"
+ "HAVING SUM(\"gross_weight\") < 200";
+ "HAVING SUM(\"gross_weight\") < 2.000E2";
// MySQL has isHavingAlias=true, case-sensitive=true
final String expectedMysql = "SELECT `product_id` + 1, `" + alias + "`\n"
+ "FROM (SELECT `product_id`, SUM(`gross_weight`) AS `" + alias + "`\n"
+ "FROM `foodmart`.`product`\n"
+ "GROUP BY `product_id`\n"
+ "HAVING `" + alias + "` < 200) AS `t1`";
+ "HAVING `" + alias + "` < 2.000E2) AS `t1`";
// BigQuery has isHavingAlias=true, case-sensitive=false
final String expectedBigQuery = upperAlias
? "SELECT product_id + 1, GROSS_WEIGHT\n"
+ "FROM (SELECT product_id, SUM(gross_weight) AS GROSS_WEIGHT\n"
+ "FROM foodmart.product\n"
+ "GROUP BY product_id\n"
+ "HAVING GROSS_WEIGHT < 200) AS t1"
+ "HAVING GROSS_WEIGHT < 2.000E2) AS t1"
// Before [CALCITE-3896] was fixed, we got
// "HAVING SUM(gross_weight) < 200) AS t1"
// which on BigQuery gives you an error about aggregating aggregates
: "SELECT product_id + 1, gross_weight\n"
+ "FROM (SELECT product_id, SUM(gross_weight) AS gross_weight\n"
+ "FROM foodmart.product\n"
+ "GROUP BY product_id\n"
+ "HAVING gross_weight < 200) AS t1";
+ "HAVING gross_weight < 2.000E2) AS t1";
sql(query)
.withBigQuery().ok(expectedBigQuery)
.withPostgresql().ok(expectedPostgresql)
Expand All @@ -2144,11 +2144,11 @@ private void checkHavingAliasSameAsColumn(boolean upperAlias) {
final String expected = "SELECT \"product_id\"\n"
+ "FROM (SELECT \"product_id\", AVG(\"gross_weight\") AS \"AGW\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "WHERE \"net_weight\" < 100\n"
+ "WHERE \"net_weight\" < CAST(100 AS DOUBLE)\n"
+ "GROUP BY \"product_id\"\n"
+ "HAVING AVG(\"gross_weight\") > 50) AS \"t2\"\n"
+ "HAVING AVG(\"gross_weight\") > CAST(50 AS DOUBLE)) AS \"t2\"\n"
+ "GROUP BY \"product_id\"\n"
+ "HAVING AVG(\"AGW\") > 60";
+ "HAVING AVG(\"AGW\") > 6.00E1";
sql(query).ok(expected);
}

Expand Down Expand Up @@ -5366,21 +5366,21 @@ private void checkLiteral2(String expression, String expected) {
+ "UNION ALL\n"
+ "SELECT NULL) END AS `$f0`\n"
+ "FROM `foodmart`.`product`) AS `t0` ON TRUE\n"
+ "WHERE `product`.`net_weight` > `t0`.`$f0`";
+ "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)";
final String expectedPostgresql = "SELECT \"product\".\"product_class_id\" AS \"C\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(\"product_class_id\") ELSE (SELECT CAST(NULL AS INTEGER)\n"
+ "UNION ALL\n"
+ "SELECT CAST(NULL AS INTEGER)) END AS \"$f0\"\n"
+ "FROM \"foodmart\".\"product\") AS \"t0\" ON TRUE\n"
+ "WHERE \"product\".\"net_weight\" > \"t0\".\"$f0\"";
+ "WHERE \"product\".\"net_weight\" > CAST(\"t0\".\"$f0\" AS DOUBLE PRECISION)";
final String expectedHsqldb = "SELECT product.product_class_id AS C\n"
+ "FROM foodmart.product\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(product_class_id) ELSE ((VALUES 0E0)\n"
+ "UNION ALL\n"
+ "(VALUES 0E0)) END AS $f0\n"
+ "FROM foodmart.product) AS t0 ON TRUE\n"
+ "WHERE product.net_weight > t0.$f0";
+ "WHERE product.net_weight > CAST(t0.$f0 AS DOUBLE)";
sql(query)
.withConfig(c -> c.withExpand(true))
.withMysql().ok(expectedMysql)
Expand Down Expand Up @@ -6905,7 +6905,7 @@ private void checkLiteral2(String expression, String expected) {
+ "within group (order by \"net_weight\" desc) filter (where \"net_weight\" > 0)"
+ "from \"product\" group by \"product_class_id\"";
final String expected = "SELECT \"product_class_id\", COLLECT(\"net_weight\") "
+ "FILTER (WHERE \"net_weight\" > 0 IS TRUE) "
+ "FILTER (WHERE \"net_weight\" > 0E0 IS TRUE) "
+ "WITHIN GROUP (ORDER BY \"net_weight\" DESC)\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "GROUP BY \"product_class_id\"";
Expand Down Expand Up @@ -8241,7 +8241,7 @@ private void checkLiteral2(String expression, String expected) {
final String expected = "SELECT *\n"
+ "FROM TABLE(DEDUP(CURSOR ((SELECT \"product_id\", \"product_name\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "WHERE \"net_weight\" > 100 AND \"product_name\" = 'Hello World')), "
+ "WHERE \"net_weight\" > CAST(100 AS DOUBLE) AND \"product_name\" = 'Hello World')), "
+ "CURSOR ((SELECT \"employee_id\", \"full_name\"\n"
+ "FROM \"foodmart\".\"employee\"\n"
+ "GROUP BY \"employee_id\", \"full_name\")), 'NAME'))";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3112,6 +3112,15 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6617">[CALCITE-6617]
* TypeCoercion is not applied correctly to comparisons</a>. */
@Test void testRand() {
final String sql = "SELECT * FROM (SELECT 1, ROUND(RAND()) AS A)\n"
+ "WHERE A BETWEEN 1 AND 10 OR A IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)";
sql(sql).withRule(CoreRules.PROJECT_REDUCE_EXPRESSIONS).check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6481">[CALCITE-6481]
* Optimize 'VALUES...UNION...VALUES' to a single 'VALUES' the IN-list contains CAST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class TCatalogReader extends MockCatalogReader {
t1.addColumn("t1_smallint", f.smallintType);
t1.addColumn("t1_int", f.intType);
t1.addColumn("t1_bigint", f.bigintType);
t1.addColumn("t1_float", f.floatType);
t1.addColumn("t1_real", f.realType);
t1.addColumn("t1_double", f.doubleType);
t1.addColumn("t1_decimal", f.decimalType);
t1.addColumn("t1_timestamp", f.timestampType);
Expand All @@ -64,7 +64,7 @@ public class TCatalogReader extends MockCatalogReader {
t2.addColumn("t2_smallint", f.smallintType);
t2.addColumn("t2_int", f.intType);
t2.addColumn("t2_bigint", f.bigintType);
t2.addColumn("t2_float", f.floatType);
t2.addColumn("t2_real", f.realType);
t2.addColumn("t2_double", f.doubleType);
t2.addColumn("t2_decimal", f.decimalType);
t2.addColumn("t2_timestamp", f.timestampType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ public static void checkActualAndReferenceFiles() {
// char decimal float double
// char decimal smallint double
final String sql = "select t1_int, t1_decimal, t1_smallint, t1_double from t1 "
+ "union select t2_varchar20, t2_decimal, t2_float, t2_bigint from t2 "
+ "union select t1_varchar20, t1_decimal, t1_float, t1_double from t1 "
+ "union select t2_varchar20, t2_decimal, t2_real, t2_bigint from t2 "
+ "union select t1_varchar20, t1_decimal, t1_real, t1_double from t1 "
+ "union select t2_varchar20, t2_decimal, t2_smallint, t2_double from t2";
sql(sql).ok();
}

@Test void testInsertQuerySourceCoercion() {
final String sql = "insert into t1 select t2_smallint, t2_int, t2_bigint, t2_float,\n"
final String sql = "insert into t1 select t2_smallint, t2_int, t2_bigint, t2_real,\n"
+ "t2_double, t2_decimal, t2_int, t2_date, t2_timestamp, t2_varchar20, t2_int from t2";
sql(sql).ok();
}
Expand Down
Loading

0 comments on commit 243c3ad

Please sign in to comment.