Skip to content

Commit

Permalink
[CALCITE-6555] RelBuilder.aggregateRex wrongly thinks aggregate funct…
Browse files Browse the repository at this point in the history
…ions of "GROUP BY ()" queries are NOT NULL

In RelBuilder, the aggregateRex method (added in CALCITE-5802)
wrongly thinks that aggregate functions in a `GROUP BY ()`
query are NOT NULL. Consider the query

  SELECT SUM(empno) AS s, COUNT(empno) AS c
  FROM emp
  GROUP BY ()

`SUM(empno)` should be nullable, even though `empno` has type
`SMALLINT NOT NULL`, because `GROUP BY ()` will return one row
even if `emp` has no rows, and therefore `SUM` will be
evaluated over the empty set. A RelBuilder test that attempts
to build an equivalent query gets the following error stack:

  java.lang.AssertionError: type mismatch:
  ref:
  SMALLINT NOT NULL
  input:
  SMALLINT

We add a test case for measure queries, because measures are
the only code path that uses `aggregateRex` at present.
  • Loading branch information
julianhyde committed Aug 31, 2024
1 parent 30304bb commit 8771e3f
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 30 deletions.
105 changes: 75 additions & 30 deletions core/src/main/java/org/apache/calcite/tools/RelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
Expand Down Expand Up @@ -2632,14 +2633,22 @@ public RelBuilder aggregateRex(GroupKey groupKey, boolean projectKey,
Iterable<? extends RexNode> nodes) {
final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
for (RexNode node : nodes) {
aggBuilder.add(node);

// First pass. Call convert on each expression to ensure that aggCalls
// gets populated.
aggBuilder.registerExpressions(nodes);

// Create the Aggregate on the stack.
aggregate(groupKey, aggBuilder.aggCalls);

// Second pass. Call convert on each expression so that it references the
// actual aggCalls in the Aggregate that was just pushed onto the stack.
final List<RexNode> projects = new ArrayList<>();
if (projectKey) {
projects.addAll(fields(Util.range(groupKey.groupKeyCount())));
}
return aggregate(groupKey, aggBuilder.aggCalls)
.project(
Iterables.concat(
fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)),
aggBuilder.postProjects));
aggBuilder.convertExpressions(projects::add, nodes);
return project(projects);
}

/** Finishes the implementation of {@link #aggregate} by creating an
Expand Down Expand Up @@ -5040,46 +5049,82 @@ default boolean removeRedundantDistinct() {
/** Working state for {@link #aggregateRex}. */
private class AggBuilder {
final ImmutableList<RexNode> groupKeys;
final List<RexNode> postProjects = new ArrayList<>();
final List<AggCall> aggCalls = new ArrayList<>();

private AggBuilder(ImmutableList<RexNode> groupKeys) {
this.groupKeys = groupKeys;
}

/** Adds a node that may or may not contain an aggregate function. */
void add(RexNode node) {
postProjects.add(convert(node));
}

/** Adds a node that we know to contain an aggregate function, and returns
* an expression whose input row type is the output row type of the
* aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
private RexNode convert(RexNode node) {
final RexBuilder rexBuilder = cluster.getRexBuilder();
if (node instanceof RexCall) {
final RexCall call = (RexCall) node;
if (call.getOperator().isAggregator()) {
final AggCall aggCall =
aggregateCall((SqlAggFunction) call.op, call.operands);
final int i = groupKeys.size() + aggCalls.size();
aggCalls.add(aggCall);
return rexBuilder.makeInputRef(call.getType(), i);
private RexNode convert(RegisterAgg registrar, RexNode node,
@Nullable String name) {
switch (node.getKind()) {
case AS:
final ImmutableList<RexNode> asOperands = ((RexCall) node).operands;
final String name2;
if (name != null) {
name2 = name;
} else {
final List<RexNode> operands = new ArrayList<>();
call.operands.forEach(operand ->
operands.add(convert(operand)));
return call.clone(call.type, operands);
final RexLiteral literal = (RexLiteral) asOperands.get(1);
name2 = requireNonNull(literal.getValueAs(String.class));
}
} else if (node instanceof RexInputRef) {
final RexNode node2 = convert(registrar, asOperands.get(0), name2);
return alias(node2, name2);

case INPUT_REF:
final int j = groupKeys.indexOf(node);
if (j < 0) {
throw new IllegalArgumentException("not a group key: " + node);
}
return rexBuilder.makeInputRef(node.getType(), j);
} else {
return field(j);

default:
if (node instanceof RexCall) {
final RexCall call = (RexCall) node;
if (call.getOperator().isAggregator()) {
// return a reference to the i'th agg call
return registrar.registerAgg((SqlAggFunction) call.op,
call.operands, call.type, name);
} else {
return call.clone(call.type,
Util.transform(call.operands, operand ->
convert(registrar, operand, null)));
}
}
return node;
}
}

void registerExpressions(Iterable<? extends RexNode> nodes) {
for (RexNode node : nodes) {
convert(this::registerAgg, node, null);
}
}

RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
RelDataType type, @Nullable String name) {
final int i = groupKeys.size() + aggCalls.size();
aggCalls.add(aggregateCall(op, operands).as(name));
return getRexBuilder().makeInputRef(type, i);
}

void convertExpressions(Consumer<RexNode> projects,
Iterable<? extends RexNode> nodes) {
final AtomicInteger j = new AtomicInteger(groupKeys.size());
for (RexNode node : nodes) {
projects.accept(
convert((op, operands, type, name) -> field(j.getAndIncrement()),
node, null));
}
}
}

/** Callback to handle creation of an aggregate call in
* {@link AggBuilder#convert}. */
private interface RegisterAgg {
RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
RelDataType type, @Nullable String name);
}
}
23 changes: 23 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3398,6 +3398,29 @@ private static RelBuilder assertSize(RelBuilder b,
assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
}

/** Tests {@link RelBuilder#aggregateRex} with an aggregate call that needs to
* become nullable because of "GROUP BY ()". */
@Test void testAggregateRex4() {
// SELECT SUM(sal) AS s, COUNT(sal) AS c
// FROM emp
// GROUP BY ()
Function<RelBuilder, RelNode> f = b ->
b.scan("EMP")
.aggregateRex(b.groupKey(),
b.alias(b.call(SqlStdOperatorTable.SUM, b.field("EMPNO")), "s"),
b.alias(b.call(SqlStdOperatorTable.COUNT, b.field("SAL")), "c"))
.build();
final String expected =
"LogicalAggregate(group=[{}], s=[SUM($0)], c=[COUNT($5)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
// s is nullable because "GROUP BY ()" may have a group that contains 0 rows
final String expectedRowType =
"RecordType(SMALLINT s, BIGINT NOT NULL c) NOT NULL";
final RelNode r = f.apply(createBuilder());
assertThat(r, hasTree(expected));
assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
}

/** Tests that a projection retains field names after a join. */
@Test void testProjectJoin() {
final RelBuilder builder = RelBuilder.create(config().build());
Expand Down
21 changes: 21 additions & 0 deletions core/src/test/resources/sql/measure.iq
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ group by job;

!ok

# Measure on primary key gives type error (casting away NOT NULL); cause was
# [CALCITE-6555] RelBuilder.aggregateRex thinks aggregate functions of
# "GROUP BY ()" queries are NOT NULL
with empm as (
select *, min(empno) as measure avg_sal
from emp
)
select deptno, avg_sal as a
from empm
group by deptno;
+--------+------+
| DEPTNO | A |
+--------+------+
| 10 | 7782 |
| 20 | 7369 |
| 30 | 7499 |
+--------+------+
(3 rows)

!ok

# Equivalent using AGGREGATE
select job, aggregate(avg_sal) as a
from empm
Expand Down

0 comments on commit 8771e3f

Please sign in to comment.