Skip to content

Commit

Permalink
Improved return type detection
Browse files Browse the repository at this point in the history
* When both arguments are integer types, use the largest one for the return type
* Added some more tests
  • Loading branch information
normanj-bitquill committed Sep 10, 2024
1 parent 292f997 commit eaf996b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1201,23 +1201,23 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
*/
public static final SqlFunction BITAND =
SqlBasicFunction.create("BITAND", SqlKind.BITAND,
ReturnTypes.FIRST_NOT_NULL_DEFAULT_INTEGER,
ReturnTypes.LARGEST_INT_OR_FIRST_NON_NULL_DEFAULT_INTEGER,
OperandTypes.INTEGER_INTEGER.or(OperandTypes.BINARY_BINARY));

/**
* <code>BITOR</code> scalar function.
*/
public static final SqlFunction BITOR =
SqlBasicFunction.create("BITOR", SqlKind.BITOR,
ReturnTypes.FIRST_NOT_NULL_DEFAULT_INTEGER,
ReturnTypes.LARGEST_INT_OR_FIRST_NON_NULL_DEFAULT_INTEGER,
OperandTypes.INTEGER_INTEGER.or(OperandTypes.BINARY_BINARY));

/**
* <code>BITXOR</code> scalar function.
*/
public static final SqlFunction BITXOR =
SqlBasicFunction.create("BITXOR", SqlKind.BITXOR,
ReturnTypes.FIRST_NOT_NULL_DEFAULT_INTEGER,
ReturnTypes.LARGEST_INT_OR_FIRST_NON_NULL_DEFAULT_INTEGER,
OperandTypes.INTEGER_INTEGER.or(OperandTypes.BINARY_BINARY));

/**
Expand Down
56 changes: 47 additions & 9 deletions core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -617,16 +617,54 @@ public static SqlCall stripSeparator(SqlCall call) {
* Returns the type of the first non NULL argument or INTEGER if all arguments
* are NULL.
*/
public static final SqlReturnTypeInference FIRST_NOT_NULL_DEFAULT_INTEGER = opBinding -> {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
for (RelDataType opType : opBinding.collectOperandTypes()) {
if (!opType.isNullable()) {
return typeFactory.createTypeWithNullability(opType, true);
}
public static final SqlReturnTypeInference LARGEST_INT_OR_FIRST_NON_NULL_DEFAULT_INTEGER =
opBinding -> {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
RelDataType integerReturnType = null;
boolean allArgsInteger = true;
for (RelDataType opType : opBinding.collectOperandTypes()) {
if (SqlTypeName.INT_TYPES.contains(opType.getSqlTypeName())) {
if (integerReturnType == null) {
integerReturnType = opType;
} else if (isIntegerTypeLarger(integerReturnType, opType)) {
integerReturnType = opType;
}
} else {
allArgsInteger = false;
break;
}
}
if (integerReturnType != null && allArgsInteger) {
return typeFactory.createTypeWithNullability(integerReturnType, true);
}
for (RelDataType opType : opBinding.collectOperandTypes()) {
if (!opType.isNullable()) {
return typeFactory.createTypeWithNullability(opType, true);
}
}
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.INTEGER), true);
};

/**
* Returns true if type2 is of a larger integer type than type1.
*
* @param type1 first type
* @param type2 second type
* @return true if type2 is of a larger integer type than type1
*/
private static boolean isIntegerTypeLarger(RelDataType type1, RelDataType type2) {
if (SqlTypeName.TINYINT == type1.getSqlTypeName()) {
return SqlTypeName.TINYINT != type2.getSqlTypeName();
} else if (SqlTypeName.SMALLINT == type1.getSqlTypeName()) {
return SqlTypeName.TINYINT != type2.getSqlTypeName()
&& SqlTypeName.SMALLINT != type2.getSqlTypeName();
} else if (SqlTypeName.INTEGER == type1.getSqlTypeName()) {
return SqlTypeName.BIGINT == type2.getSqlTypeName();
} else {
return false;
}
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.INTEGER), true);
};
}

/**
* Returns the same type as the multiset carries. The multiset type returned
Expand Down
15 changes: 15 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15483,6 +15483,11 @@ private static void checkLogicalOrFunc(SqlOperatorFixture f) {
f.setFor(SqlStdOperatorTable.BITAND, VmName.EXPAND);
f.checkFails("bitand(^*^)", "Unknown identifier '\\*'", false);
f.checkScalar("bitand(2, 3)", "2", "INTEGER");
f.checkScalar("bitand(CAST(2 AS INTEGER), CAST(3 AS BIGINT))", "2", "BIGINT");
f.checkScalar("bitand(-5, 7)", "3", "INTEGER");
f.checkScalar("bitand(-5, -31)", "-31", "INTEGER");
f.checkScalar("bitand(CAST(-5 AS TINYINT), CAST(7 AS TINYINT))", "3", "TINYINT");
f.checkScalar("bitand(CAST(-5 AS TINYINT), CAST(-31 AS TINYINT))", "-31", "TINYINT");
f.checkType("bitand(CAST(2 AS TINYINT), CAST(6 AS TINYINT))", "TINYINT");
f.checkType("bitand(CAST(2 AS SMALLINT), CAST(6 AS SMALLINT))", "SMALLINT");
f.checkType("bitand(CAST(2 AS BIGINT), CAST(6 AS BIGINT))", "BIGINT");
Expand Down Expand Up @@ -15518,6 +15523,11 @@ private static void checkLogicalOrFunc(SqlOperatorFixture f) {
f.setFor(SqlStdOperatorTable.BITOR, VmName.EXPAND);
f.checkFails("bitor(^*^)", "Unknown identifier '\\*'", false);
f.checkScalar("bitor(2, 4)", "6", "INTEGER");
f.checkScalar("bitor(CAST(2 AS INTEGER), CAST(4 AS BIGINT))", "6", "BIGINT");
f.checkScalar("bitor(-5, 7)", "-1", "INTEGER");
f.checkScalar("bitor(-5, -31)", "-5", "INTEGER");
f.checkScalar("bitor(CAST(-5 AS TINYINT), CAST(7 AS TINYINT))", "-1", "TINYINT");
f.checkScalar("bitor(CAST(-5 AS TINYINT), CAST(-31 AS TINYINT))", "-5", "TINYINT");
f.checkType("bitor(CAST(2 AS TINYINT), CAST(6 AS TINYINT))", "TINYINT");
f.checkType("bitor(CAST(2 AS SMALLINT), CAST(6 AS SMALLINT))", "SMALLINT");
f.checkType("bitor(CAST(2 AS BIGINT), CAST(6 AS BIGINT))", "BIGINT");
Expand Down Expand Up @@ -15553,6 +15563,11 @@ private static void checkLogicalOrFunc(SqlOperatorFixture f) {
f.setFor(SqlStdOperatorTable.BITXOR, VmName.EXPAND);
f.checkFails("bitxor(^*^)", "Unknown identifier '\\*'", false);
f.checkScalar("bitxor(2, 3)", "1", "INTEGER");
f.checkScalar("bitxor(CAST(2 AS INTEGER), CAST(3 AS BIGINT))", "1", "BIGINT");
f.checkScalar("bitxor(-5, 7)", "-4", "INTEGER");
f.checkScalar("bitxor(-5, -31)", "26", "INTEGER");
f.checkScalar("bitxor(CAST(-5 AS TINYINT), CAST(7 AS TINYINT))", "-4", "TINYINT");
f.checkScalar("bitxor(CAST(-5 AS TINYINT), CAST(-31 AS TINYINT))", "26", "TINYINT");
f.checkType("bitxor(CAST(2 AS TINYINT), CAST(6 AS TINYINT))", "TINYINT");
f.checkType("bitxor(CAST(2 AS SMALLINT), CAST(6 AS SMALLINT))", "SMALLINT");
f.checkType("bitxor(CAST(2 AS BIGINT), CAST(6 AS BIGINT))", "BIGINT");
Expand Down

0 comments on commit eaf996b

Please sign in to comment.