From 5f06a9b37322e549fad938d31a6602b3e4c43fae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 09:28:56 -0500 Subject: [PATCH 1/5] Bump com.puppycrawl.tools:checkstyle from 10.21.0 to 10.21.1 (#14728) Bumps [com.puppycrawl.tools:checkstyle](https://github.com/checkstyle/checkstyle) from 10.21.0 to 10.21.1. - [Release notes](https://github.com/checkstyle/checkstyle/releases) - [Commits](https://github.com/checkstyle/checkstyle/compare/checkstyle-10.21.0...checkstyle-10.21.1) --- updated-dependencies: - dependency-name: com.puppycrawl.tools:checkstyle dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 804d4c5ebb03..c7f8e6560d1f 100644 --- a/pom.xml +++ b/pom.xml @@ -2449,7 +2449,7 @@ com.puppycrawl.tools checkstyle - 10.21.0 + 10.21.1 From 020b14593397fc88140a6808be69fddb19b23c85 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 09:29:29 -0500 Subject: [PATCH 2/5] Bump software.amazon.awssdk:bom from 2.29.41 to 2.29.43 (#14729) Bumps software.amazon.awssdk:bom from 2.29.41 to 2.29.43. --- updated-dependencies: - dependency-name: software.amazon.awssdk:bom dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c7f8e6560d1f..39091574ec86 100644 --- a/pom.xml +++ b/pom.xml @@ -175,7 +175,7 @@ 0.15.0 0.4.7 4.2.2 - 2.29.41 + 2.29.43 1.2.30 1.18.0 2.13.0 From 351e4641cc21a053e54ab67cfce65f57e562ce6d Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" <17555551+Jackie-Jiang@users.noreply.github.com> Date: Mon, 30 Dec 2024 18:29:11 -0800 Subject: [PATCH 3/5] Fix a bug for partition enabled instance assignment with minimize movement (#14726) --- ...InstanceReplicaGroupPartitionSelector.java | 3 +- .../instance/InstanceAssignmentTest.java | 279 ++++++++++++++---- .../SegmentsValidationAndRetentionConfig.java | 17 ++ .../spi/utils/builder/TableConfigBuilder.java | 1 + 4 files changed, 243 insertions(+), 57 deletions(-) diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceReplicaGroupPartitionSelector.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceReplicaGroupPartitionSelector.java index 8da6dbe2f62e..b8c19ede69eb 100644 --- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceReplicaGroupPartitionSelector.java +++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceReplicaGroupPartitionSelector.java @@ -411,7 +411,8 @@ private void replicaGroupBasedMinimumMovement(Map> for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) { List instancesInReplicaGroup = replicaGroupIdToInstancesMap.get(replicaGroupId); if (replicaGroupId < existingNumReplicaGroups) { - int maxNumPartitionsPerInstance = (numInstancesPerReplicaGroup + numPartitions - 1) / numPartitions; + int maxNumPartitionsPerInstance = + (numPartitions + 
numInstancesPerReplicaGroup - 1) / numInstancesPerReplicaGroup; Map instanceToNumPartitionsMap = Maps.newHashMapWithExpectedSize(numInstancesPerReplicaGroup); for (String instance : instancesInReplicaGroup) { diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceAssignmentTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceAssignmentTest.java index 113d4e164965..39aef7f35ad8 100644 --- a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceAssignmentTest.java +++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/instance/InstanceAssignmentTest.java @@ -26,6 +26,7 @@ import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.Set; import org.apache.helix.model.InstanceConfig; @@ -115,15 +116,15 @@ public void testDefaultOfflineReplicaGroup() { // Instance of index 7 is not assigned because of the hash-based rotation // Math.abs("myTable_OFFLINE".hashCode()) % 10 = 8 // [i8, i9, i0, i1, i2, i3, i4, i5, i6, i7] - // r0, r1, r2, r0, r1, r2, r0, r1, r2 + // r0 r1 r2 r0 r1 r2 r0 r1 r2 // r0: [i8, i1, i4] - // p0, p0, p1 + // p0 p0 p1 // p1 // r1: [i9, i2, i5] - // p0, p0, p1 + // p0 p0 p1 // p1 // r2: [i0, i3, i6] - // p0, p0, p1 + // p0 p0 p1 // p1 assertEquals(instancePartitions.getInstances(0, 0), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); @@ -137,31 +138,52 @@ public void testDefaultOfflineReplicaGroup() { Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3)); assertEquals(instancePartitions.getInstances(1, 2), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 6, SERVER_INSTANCE_ID_PREFIX + 0)); + } - // ===== Test against the cases when the existing instancePartitions isn't null, - // and minimizeDataMovement is set to true. 
===== - // Put the existing instancePartitions as the parameter to the InstanceAssignmentDriver. - // The returned instance partition should be the same as the last computed one. - tableConfig.getValidationConfig().setMinimizeDataMovement(true); + @Test + public void testMinimizeDataMovement() { + int numReplicas = 3; + int numPartitions = 2; + int numInstancesPerPartition = 2; + String partitionColumn = "partition"; + InstanceAssignmentConfig instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), false, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME) + .setNumReplicas(numReplicas) + .setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)) + .build(); + + int numInstances = 10; + List instanceConfigs = new ArrayList<>(numInstances); + for (int i = 0; i < numInstances; i++) { + InstanceConfig instanceConfig = new InstanceConfig(SERVER_INSTANCE_ID_PREFIX + i); + instanceConfig.addTag(OFFLINE_TAG); + instanceConfigs.add(instanceConfig); + } + // Start without existing InstancePartitions: // Instances should be assigned to 3 replica-groups with a round-robin fashion, each with 3 instances, then these 3 // instances should be assigned to 2 partitions, each with 2 instances - instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); + InstanceAssignmentDriver driver = new InstanceAssignmentDriver(tableConfig); + InstancePartitions instancePartitions = + driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, null); assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); assertEquals(instancePartitions.getNumPartitions(), numPartitions); // Instance of index 7 is not assigned 
because of the hash-based rotation // Math.abs("myTable_OFFLINE".hashCode()) % 10 = 8 // [i8, i9, i0, i1, i2, i3, i4, i5, i6, i7] - // r0, r1, r2, r0, r1, r2, r0, r1, r2 + // r0 r1 r2 r0 r1 r2 r0 r1 r2 // r0: [i8, i1, i4] - // p0, p0, p1 + // p0 p0 p1 // p1 // r1: [i9, i2, i5] - // p0, p0, p1 + // p0 p0 p1 // p1 // r2: [i0, i3, i6] - // p0, p0, p1 + // p0 p0 p1 // p1 assertEquals(instancePartitions.getInstances(0, 0), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); @@ -196,15 +218,15 @@ public void testDefaultOfflineReplicaGroup() { // Instance of index 7 is not assigned because of the hash-based rotation // Math.abs("myTable_OFFLINE".hashCode()) % 10 = 8 // [i8, i9, i0, i1, i10, i3, i4, i5, i11, i7] - // r0, r1, r2, r0, r1, r2, r0, r1, r2 + // r0 r1 r2 r0 r1 r2 r0 r1 r2 // r0: [i8, i1, i4] - // p0, p0, p1 + // p0 p0 p1 // p1 // r1: [i9, i5, i10] - // p0, p1, p0 + // p0 p1 p0 // p1 // r2: [i0, i3, i11] - // p0, p0, p1 + // p0 p0 p1 // p1 assertEquals(instancePartitions.getInstances(0, 0), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); @@ -226,24 +248,28 @@ public void testDefaultOfflineReplicaGroup() { instanceConfigs.add(instanceConfig); } numInstancesPerPartition = 3; - tableConfig.getValidationConfig() - .setReplicaGroupStrategyConfig(new ReplicaGroupStrategyConfig(partitionColumnName, numInstancesPerPartition)); + instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), false, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + tableConfig.setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)); instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); 
assertEquals(instancePartitions.getNumPartitions(), numPartitions); // Math.abs("myTable_OFFLINE".hashCode()) % 12 = 2 - // [i10, i11, i12, i13, i3, i4, i5, i11, i7, i8, i9, i0, i1] + // [i10, i11, i12, i13, i3, i4, i5, i7, i8, i9, i0, i1] + // r1 r2 r0 r1 r2 r0 r1 r2 r0 r1 r2 r0 // r0: [i8, i1, i4, i12] - // p0, p0, p1, p0 - // p1, p1 + // p0 p0 p1 p0 + // p1 p1 // r1: [i9, i5, i10, i13] - // p0, p1, p0, p0 - // p1, p1 + // p0 p1 p0 p0 + // p1 p1 // r2: [i0, i3, i11, i7] - // p0, p0, p1, p0 - // p1, p1 + // p0 p0 p1 p0 + // p1 p1 assertEquals(instancePartitions.getInstances(0, 0), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1, SERVER_INSTANCE_ID_PREFIX + 12)); assertEquals(instancePartitions.getInstances(1, 0), @@ -251,86 +277,227 @@ public void testDefaultOfflineReplicaGroup() { assertEquals(instancePartitions.getInstances(0, 1), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 10, SERVER_INSTANCE_ID_PREFIX + 13)); assertEquals(instancePartitions.getInstances(1, 1), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 10)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 10, SERVER_INSTANCE_ID_PREFIX + 9)); assertEquals(instancePartitions.getInstances(0, 2), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3, SERVER_INSTANCE_ID_PREFIX + 7)); assertEquals(instancePartitions.getInstances(1, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 3, SERVER_INSTANCE_ID_PREFIX + 0)); // Reduce the number of instances per partition from 3 to 2. 
numInstancesPerPartition = 2; - tableConfig.getValidationConfig() - .setReplicaGroupStrategyConfig(new ReplicaGroupStrategyConfig(partitionColumnName, numInstancesPerPartition)); + instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), false, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + tableConfig.setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)); instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); assertEquals(instancePartitions.getNumPartitions(), numPartitions); - // The instance assignment should be the same as the one without the newly added instances. + // r0: [i8, i1, i4, i12] + // p0 p0 p1 p1 + // r1: [i9, i5, i10, i13] + // p0 p1 p0 p1 + // r2: [i0, i3, i11, i7] + // p0 p0 p1 p1 assertEquals(instancePartitions.getInstances(0, 0), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); assertEquals(instancePartitions.getInstances(1, 0), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 8)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 12)); assertEquals(instancePartitions.getInstances(0, 1), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 10)); assertEquals(instancePartitions.getInstances(1, 1), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 9)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 13)); assertEquals(instancePartitions.getInstances(0, 2), Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3)); assertEquals(instancePartitions.getInstances(1, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 
0)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 7)); // Add one more replica group (from 3 to 4). numReplicas = 4; tableConfig.getValidationConfig().setReplication(Integer.toString(numReplicas)); + instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), false, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + tableConfig.setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)); instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); assertEquals(instancePartitions.getNumPartitions(), numPartitions); // Math.abs("myTable_OFFLINE".hashCode()) % 12 = 2 - // [i10, i11, i12, i13, i3, i4, i5, i11, i7, i8, i9, i0, i1] - // The existing replica groups remain unchanged. - // For the new replica group r3, the candidate instances become [i12, i13, i7]. 
- // r3: [i12, i13, i7] - // p0, p0, p1 - // p1 + // [i10, i11, i12, i13, i3, i4, i5, i7, i8, i9, i0, i1] + // r1 r2 r0 r1 r2 r0 r1 r2 r0 r3 r3 r3 + // r0: [i8, i4, i12] + // p0 p1 p1 + // p0 + // r1: [i5, i10, i13] + // p1 p0 p1 + // p0 + // r2: [i3, i11, i7] + // p0 p1 p1 + // p0 + // r3: [i9, i0, i1] + // p0 p0 p1 + // p1 assertEquals(instancePartitions.getInstances(0, 0), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 12)); assertEquals(instancePartitions.getInstances(1, 0), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 8)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 12)); assertEquals(instancePartitions.getInstances(0, 1), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 10)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 13, SERVER_INSTANCE_ID_PREFIX + 10)); assertEquals(instancePartitions.getInstances(1, 1), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 9)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 13)); assertEquals(instancePartitions.getInstances(0, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 3)); assertEquals(instancePartitions.getInstances(1, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 0)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 7)); assertEquals(instancePartitions.getInstances(0, 3), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 12, SERVER_INSTANCE_ID_PREFIX + 13)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 0)); assertEquals(instancePartitions.getInstances(1, 3), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 7, SERVER_INSTANCE_ID_PREFIX + 12)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 
1, SERVER_INSTANCE_ID_PREFIX + 9)); // Remove one replica group (from 4 to 3). numReplicas = 3; tableConfig.getValidationConfig().setReplication(Integer.toString(numReplicas)); + tableConfig.getValidationConfig().setReplication(Integer.toString(numReplicas)); + instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), false, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + tableConfig.setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)); instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); assertEquals(instancePartitions.getNumPartitions(), numPartitions); - // The output should be the same as the one before adding one replica group. + // Math.abs("myTable_OFFLINE".hashCode()) % 12 = 2 + // [i10, i11, i12, i13, i3, i4, i5, i7, i8, i9, i0, i1] + // r1 r2 r0 r1 r2 r0 r1 r2 r0 r0 r1 r2 + // r0: [i8, i4, i12, i9] + // p0 p1 p0 p1 + // r1: [i5, i10, i13, i0] + // p1 p0 p0 p1 + // r2: [i3, i11, i7, i1] + // p0 p0 p1 p1 assertEquals(instancePartitions.getInstances(0, 0), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 1)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 8, SERVER_INSTANCE_ID_PREFIX + 12)); assertEquals(instancePartitions.getInstances(1, 0), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 8)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 4, SERVER_INSTANCE_ID_PREFIX + 9)); assertEquals(instancePartitions.getInstances(0, 1), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 9, SERVER_INSTANCE_ID_PREFIX + 10)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 13, SERVER_INSTANCE_ID_PREFIX + 10)); assertEquals(instancePartitions.getInstances(1, 1), - 
Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 9)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 5, SERVER_INSTANCE_ID_PREFIX + 0)); assertEquals(instancePartitions.getInstances(0, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 0, SERVER_INSTANCE_ID_PREFIX + 3)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 3)); assertEquals(instancePartitions.getInstances(1, 2), - Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 11, SERVER_INSTANCE_ID_PREFIX + 0)); + Arrays.asList(SERVER_INSTANCE_ID_PREFIX + 1, SERVER_INSTANCE_ID_PREFIX + 7)); + } + + @Test + public void testMinimizeDataMovementPoolBasedSingleInstancePartitions() { + int numReplicas = 2; + int numPartitions = 10; + int numInstancesPerPartition = 1; + String partitionColumn = "partition"; + InstanceAssignmentConfig instanceAssignmentConfig = new InstanceAssignmentConfig( + new InstanceTagPoolConfig(TagNameUtils.getOfflineTagForTenant(TENANT_NAME), true, 0, null), null, + new InstanceReplicaGroupPartitionConfig(true, 0, numReplicas, 0, numPartitions, numInstancesPerPartition, true, + partitionColumn), null, true); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME) + .setNumReplicas(numReplicas) + .setInstanceAssignmentConfigMap(Map.of("OFFLINE", instanceAssignmentConfig)) + .build(); + + int numPools = 2; + int numInstances = 6; + List instanceConfigs = new ArrayList<>(numInstances); + for (int i = 0; i < numInstances; i++) { + InstanceConfig instanceConfig = new InstanceConfig(SERVER_INSTANCE_ID_PREFIX + i); + instanceConfig.addTag(OFFLINE_TAG); + instanceConfig.getRecord() + .setMapField(InstanceUtils.POOL_KEY, Map.of(OFFLINE_TAG, Integer.toString(i % numPools))); + instanceConfigs.add(instanceConfig); + } + + // Start without existing InstancePartitions: + // Instances from each pool should be assigned to 1 replica-group, each with 3 instances, then these 3 instances + // should be assigned to 10 partitions, each 
with 1 instance + InstanceAssignmentDriver driver = new InstanceAssignmentDriver(tableConfig); + InstancePartitions instancePartitions = + driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, null); + assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); + assertEquals(instancePartitions.getNumPartitions(), numPartitions); + + // Math.abs("myTable_OFFLINE".hashCode()) % 2 = 0 + // Math.abs("myTable_OFFLINE".hashCode()) % 3 = 2 + // [i4, i0, i2] + // [i5, i1, i3] + // p0 p1 p2 + // p3 p4 p5 + // p6 p7 p8 + // p9 + assertEquals(instancePartitions.getInstances(0, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(0, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(1, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + assertEquals(instancePartitions.getInstances(1, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(2, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(2, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + assertEquals(instancePartitions.getInstances(3, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(3, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(4, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + assertEquals(instancePartitions.getInstances(4, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(5, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(5, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + assertEquals(instancePartitions.getInstances(6, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(6, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(7, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + 
assertEquals(instancePartitions.getInstances(7, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(8, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(8, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + assertEquals(instancePartitions.getInstances(9, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(9, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + + // Add 2 new instances + // Each existing instance should keep 3 partitions unmoved, and only 1 partition should be moved to the new instance + for (int i = numInstances; i < numInstances + 2; i++) { + InstanceConfig instanceConfig = new InstanceConfig(SERVER_INSTANCE_ID_PREFIX + i); + instanceConfig.addTag(OFFLINE_TAG); + instanceConfig.getRecord() + .setMapField(InstanceUtils.POOL_KEY, Map.of(OFFLINE_TAG, Integer.toString(i % numPools))); + instanceConfigs.add(instanceConfig); + } + instancePartitions = driver.assignInstances(InstancePartitionsType.OFFLINE, instanceConfigs, instancePartitions); + assertEquals(instancePartitions.getNumReplicaGroups(), numReplicas); + assertEquals(instancePartitions.getNumPartitions(), numPartitions); + + // Math.abs("myTable_OFFLINE".hashCode()) % 2 = 0 + // Math.abs("myTable_OFFLINE".hashCode()) % 4 = 2 + // [i4, i6, i0, i2] + // [i5, i7, i1, i3] + // p0 p9 p1 p2 + // p3 p4 p5 + // p6 p7 p8 + assertEquals(instancePartitions.getInstances(0, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(0, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(1, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + assertEquals(instancePartitions.getInstances(1, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(2, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(2, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + 
assertEquals(instancePartitions.getInstances(3, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(3, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(4, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + assertEquals(instancePartitions.getInstances(4, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(5, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(5, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + assertEquals(instancePartitions.getInstances(6, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 4)); + assertEquals(instancePartitions.getInstances(6, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 5)); + assertEquals(instancePartitions.getInstances(7, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 0)); + assertEquals(instancePartitions.getInstances(7, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 1)); + assertEquals(instancePartitions.getInstances(8, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 2)); + assertEquals(instancePartitions.getInstances(8, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 3)); + assertEquals(instancePartitions.getInstances(9, 0), List.of(SERVER_INSTANCE_ID_PREFIX + 6)); + assertEquals(instancePartitions.getInstances(9, 1), List.of(SERVER_INSTANCE_ID_PREFIX + 7)); } public void testMirrorServerSetBasedRandom() throws FileNotFoundException { diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/SegmentsValidationAndRetentionConfig.java b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/SegmentsValidationAndRetentionConfig.java index 0b8a403041ab..592a6c1960f8 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/SegmentsValidationAndRetentionConfig.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/SegmentsValidationAndRetentionConfig.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import java.util.concurrent.TimeUnit; import 
org.apache.pinot.spi.config.BaseJsonConfig; +import org.apache.pinot.spi.config.table.assignment.InstanceAssignmentConfig; import org.apache.pinot.spi.config.table.ingestion.IngestionConfig; import org.apache.pinot.spi.utils.TimeUtils; @@ -43,20 +44,26 @@ public class SegmentsValidationAndRetentionConfig extends BaseJsonConfig { private TimeUnit _timeType; @Deprecated // Use SegmentAssignmentConfig instead private String _segmentAssignmentStrategy; + @Deprecated // Use InstanceAssignmentConfig instead private ReplicaGroupStrategyConfig _replicaGroupStrategyConfig; private CompletionConfig _completionConfig; private String _crypterClassName; + @Deprecated private boolean _minimizeDataMovement; // Possible values can be http or https. If this field is set, a Pinot server can download segments from peer servers // using the specified download scheme. Both realtime tables and offline tables can set this field. // For more usage of this field, please refer to this design doc: https://tinyurl.com/f63ru4sb private String _peerSegmentDownloadScheme; + /** + * @deprecated Use {@code SegmentAssignmentConfig} instead + */ @Deprecated public String getSegmentAssignmentStrategy() { return _segmentAssignmentStrategy; } + @Deprecated public void setSegmentAssignmentStrategy(String segmentAssignmentStrategy) { _segmentAssignmentStrategy = segmentAssignmentStrategy; } @@ -174,10 +181,15 @@ public void setSchemaName(String schemaName) { _schemaName = schemaName; } + /** + * @deprecated Use {@link InstanceAssignmentConfig} instead. 
+ */ + @Deprecated public ReplicaGroupStrategyConfig getReplicaGroupStrategyConfig() { return _replicaGroupStrategyConfig; } + @Deprecated public void setReplicaGroupStrategyConfig(ReplicaGroupStrategyConfig replicaGroupStrategyConfig) { _replicaGroupStrategyConfig = replicaGroupStrategyConfig; } @@ -226,10 +238,15 @@ public void setCrypterClassName(String crypterClassName) { _crypterClassName = crypterClassName; } + /** + * @deprecated Use {@link InstanceAssignmentConfig} instead + */ + @Deprecated public boolean isMinimizeDataMovement() { return _minimizeDataMovement; } + @Deprecated public void setMinimizeDataMovement(boolean minimizeDataMovement) { _minimizeDataMovement = minimizeDataMovement; } diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java index 5e9d915cfc46..3b51a6052f4e 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java @@ -78,6 +78,7 @@ public class TableConfigBuilder { @Deprecated private String _segmentAssignmentStrategy; private String _peerSegmentDownloadScheme; + @Deprecated private ReplicaGroupStrategyConfig _replicaGroupStrategyConfig; private CompletionConfig _completionConfig; private String _crypterClassName; From f98b250bf5b15cac6a6b48b878263fb999517b73 Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" <17555551+Jackie-Jiang@users.noreply.github.com> Date: Mon, 30 Dec 2024 23:02:30 -0800 Subject: [PATCH 4/5] [Multi-stage] Support is_enable_group_trim agg option (#14664) --- pinot-common/src/main/proto/plan.proto | 2 + .../calcite/rel/hint/PinotHintOptions.java | 3 +- .../rel/logical/PinotLogicalAggregate.java | 55 +++-- .../PinotAggregateExchangeNodeInsertRule.java | 216 +++++++++++++----- .../calcite/rel/rules/PinotQueryRuleSets.java | 4 +- 
.../parser/CalciteRexExpressionParser.java | 4 +- .../query/planner/explain/PlanNodeMerger.java | 6 + .../logical/EquivalentStagesFinder.java | 4 +- .../logical/RelToPlanNodeConverter.java | 2 +- .../query/planner/plannode/AggregateNode.java | 30 ++- .../planner/serde/PlanNodeDeserializer.java | 3 +- .../planner/serde/PlanNodeSerializer.java | 2 + .../test/resources/queries/GroupByPlans.json | 49 ++++ .../plan/server/ServerPlanRequestVisitor.java | 39 ++-- .../operator/AggregateOperatorTest.java | 2 +- .../operator/MultiStageAccountingTest.java | 2 +- .../test/resources/queries/QueryHints.json | 8 + 17 files changed, 325 insertions(+), 106 deletions(-) diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto index 49d357307648..e3b2bbf65482 100644 --- a/pinot-common/src/main/proto/plan.proto +++ b/pinot-common/src/main/proto/plan.proto @@ -69,6 +69,8 @@ message AggregateNode { repeated int32 groupKeys = 3; AggType aggType = 4; bool leafReturnFinalResult = 5; + repeated Collation collations = 6; + int32 limit = 7; } message FilterNode { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java index 558b2f898539..3c676edd18e5 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java @@ -42,7 +42,8 @@ private PinotHintOptions() { public static class AggregateOptions { public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys"; public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result"; - public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by"; + public static final String IS_SKIP_LEAF_STAGE_GROUP_BY = "is_skip_leaf_stage_group_by"; + public static final 
String IS_ENABLE_GROUP_TRIM = "is_enable_group_trim"; public static final String NUM_GROUPS_LIMIT = "num_groups_limit"; public static final String MAX_INITIAL_RESULT_HOLDER_CAPACITY = "max_initial_result_holder_capacity"; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java index 241c44703e6b..f9edb412c883 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelWriter; import org.apache.calcite.rel.core.Aggregate; @@ -35,39 +36,36 @@ public class PinotLogicalAggregate extends Aggregate { private final AggType _aggType; private final boolean _leafReturnFinalResult; + // The following fields are set when group trim is enabled, and are extracted from the Sort on top of this Aggregate. 
+ private final List _collations; + private final int _limit; + public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, - AggType aggType, boolean leafReturnFinalResult) { + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls); _aggType = aggType; _leafReturnFinalResult = leafReturnFinalResult; + _collations = collations; + _limit = limit; } - public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, - ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, - AggType aggType) { - this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, aggType, false); - } - - public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType, - boolean leafReturnFinalResult) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getInput(), aggRel.getGroupSet(), - aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult); + public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, + @Nullable List groupSets, List aggCalls, AggType aggType, + boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, groupSets, aggCalls, aggType, + leafReturnFinalResult, collations, limit); } - public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType) { - this(aggRel, aggCalls, aggType, false); - } - - public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, AggType aggType) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, aggRel.getGroupSet(), - aggRel.getGroupSets(), aggCalls, aggType); + public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, AggType aggType, + 
boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel, input, aggRel.getGroupSet(), aggRel.getGroupSets(), aggCalls, aggType, + leafReturnFinalResult, collations, limit); } public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, List aggCalls, - AggType aggType, boolean leafReturnFinalResult) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType, - leafReturnFinalResult); + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel, input, groupSet, null, aggCalls, aggType, leafReturnFinalResult, collations, limit); } public AggType getAggType() { @@ -78,11 +76,20 @@ public boolean isLeafReturnFinalResult() { return _leafReturnFinalResult; } + @Nullable + public List getCollations() { + return _collations; + } + + public int getLimit() { + return _limit; + } + @Override public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls) { return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } @Override @@ -90,12 +97,14 @@ public RelWriter explainTerms(RelWriter pw) { RelWriter relWriter = super.explainTerms(pw); relWriter.item("aggType", _aggType); relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult); + relWriter.itemIf("collations", _collations, _collations != null); + relWriter.itemIf("limit", _limit, _limit > 0); return relWriter; } @Override public RelNode withHints(List hintList) { return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } } diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index df11fdb49a2e..84b2a274aa27 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -28,10 +28,12 @@ import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistributions; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.Union; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; @@ -82,49 +84,161 @@ * - COUNT(*) with a GROUP_BY_KEY transforms into: COUNT(*)__LEAF --> COUNT(*)__FINAL, where * - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ] * - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ] + * + * There are 3 sub-rules: + * 1. {@link SortProjectAggregate}: + * Matches the case when there's a Sort on top of Project on top of Aggregate, and enable group trim hint is present. + * E.g. + * SELECT /*+ aggOptions(is_enable_group_trim='true') * / + * COUNT(*) AS cnt, col1 FROM myTable GROUP BY col1 ORDER BY cnt DESC LIMIT 10 + * It will extract the collations and limit from the Sort node, and set them into the Aggregate node. It works only + * when the sort key is a direct reference to the input, i.e. no transform on the input columns. + * 2. 
{@link SortAggregate}: + * Matches the case when there's a Sort on top of Aggregate, and enable group trim hint is present. + * E.g. + * SELECT /*+ aggOptions(is_enable_group_trim='true') * / + * col1, COUNT(*) AS cnt FROM myTable GROUP BY col1 ORDER BY cnt DESC LIMIT 10 + * It will extract the collations and limit from the Sort node, and set them into the Aggregate node. + * 3. {@link WithoutSort}: + * Matches Aggregate node if there is no match of {@link SortProjectAggregate} or {@link SortAggregate}. + * + * TODO: + * 1. Always enable group trim when the result is guaranteed to be accurate + * 2. Add intermediate stage group trim + * 3. Allow tuning group trim parameters with query hint */ -public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { - public static final PinotAggregateExchangeNodeInsertRule INSTANCE = - new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY); - - public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) { - // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with - // PinotLogicalAggregate, and the rule won't be applied again. - super(operand(LogicalAggregate.class, any()), factory, null); +public class PinotAggregateExchangeNodeInsertRule { + + public static class SortProjectAggregate extends RelOptRule { + public static final SortProjectAggregate INSTANCE = new SortProjectAggregate(PinotRuleUtils.PINOT_REL_FACTORY); + + private SortProjectAggregate(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. 
+ super(operand(Sort.class, operand(Project.class, operand(LogicalAggregate.class, any()))), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + LogicalAggregate aggRel = call.rel(2); + if (aggRel.getGroupSet().isEmpty()) { + return; + } + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + if (hintOptions == null || !Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) { + return; + } + + Sort sortRel = call.rel(0); + Project projectRel = call.rel(1); + List projects = projectRel.getProjects(); + List collations = sortRel.getCollation().getFieldCollations(); + List newCollations = new ArrayList<>(collations.size()); + for (RelFieldCollation fieldCollation : collations) { + RexNode project = projects.get(fieldCollation.getFieldIndex()); + if (project instanceof RexInputRef) { + newCollations.add(fieldCollation.withFieldIndex(((RexInputRef) project).getIndex())); + } else { + // Cannot enable group trim when the sort key is not a direct reference to the input. + return; + } + } + int limit = 0; + if (sortRel.fetch != null) { + limit = RexLiteral.intValue(sortRel.fetch); + } + if (limit <= 0) { + // Cannot enable group trim when there is no limit. 
+ return; + } + + PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, hintOptions, newCollations, limit); + RelNode newProjectRel = projectRel.copy(projectRel.getTraitSet(), List.of(newAggRel)); + call.transformTo(sortRel.copy(sortRel.getTraitSet(), List.of(newProjectRel))); + } } - /** - * Split the AGG into 3 plan fragments, all with the same AGG type (in some cases the final agg name may be different) - * Pinot internal plan fragment optimization can use the info of the input data type to infer whether it should - * generate the "final-stage AGG operator" or "intermediate-stage AGG operator" or "leaf-stage AGG operator" - * - * @param call the {@link RelOptRuleCall} on match. - * @see org.apache.pinot.core.query.aggregation.function.AggregationFunction - */ - @Override - public void onMatch(RelOptRuleCall call) { - Aggregate aggRel = call.rel(0); - boolean hasGroupBy = !aggRel.getGroupSet().isEmpty(); - RelCollation collation = extractWithInGroupCollation(aggRel); - Map hintOptions = - PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); - // Collation is not supported in leaf stage aggregation. - if (collation != null || (hasGroupBy && hintOptions != null && Boolean.parseBoolean( - hintOptions.get(PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)))) { - call.transformTo(createPlanWithExchangeDirectAggregation(call, collation)); - } else if (hasGroupBy && hintOptions != null && Boolean.parseBoolean( + public static class SortAggregate extends RelOptRule { + public static final SortAggregate INSTANCE = new SortAggregate(PinotRuleUtils.PINOT_REL_FACTORY); + + private SortAggregate(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. 
+ super(operand(Sort.class, operand(LogicalAggregate.class, any())), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + LogicalAggregate aggRel = call.rel(1); + if (aggRel.getGroupSet().isEmpty()) { + return; + } + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + if (hintOptions == null || !Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.IS_ENABLE_GROUP_TRIM))) { + return; + } + + Sort sortRel = call.rel(0); + List collations = sortRel.getCollation().getFieldCollations(); + int limit = 0; + if (sortRel.fetch != null) { + limit = RexLiteral.intValue(sortRel.fetch); + } + if (limit <= 0) { + // Cannot enable group trim when there is no limit. + return; + } + + PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, hintOptions, collations, limit); + call.transformTo(sortRel.copy(sortRel.getTraitSet(), List.of(newAggRel))); + } + } + + public static class WithoutSort extends RelOptRule { + public static final WithoutSort INSTANCE = new WithoutSort(PinotRuleUtils.PINOT_REL_FACTORY); + + private WithoutSort(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. + super(operand(LogicalAggregate.class, any()), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + Aggregate aggRel = call.rel(0); + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + call.transformTo( + createPlan(call, aggRel, !aggRel.getGroupSet().isEmpty(), hintOptions != null ? 
hintOptions : Map.of(), null, + 0)); + } + } + + private static PinotLogicalAggregate createPlan(RelOptRuleCall call, Aggregate aggRel, boolean hasGroupBy, + Map hintOptions, @Nullable List collations, int limit) { + // WITHIN GROUP collation is not supported in leaf stage aggregation. + RelCollation withinGroupCollation = extractWithinGroupCollation(aggRel); + if (withinGroupCollation != null || (hasGroupBy && Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.IS_SKIP_LEAF_STAGE_GROUP_BY)))) { + return createPlanWithExchangeDirectAggregation(call, aggRel, withinGroupCollation, collations, limit); + } else if (hasGroupBy && Boolean.parseBoolean( hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS))) { - call.transformTo(new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT)); + return new PinotLogicalAggregate(aggRel, aggRel.getInput(), buildAggCalls(aggRel, AggType.DIRECT, false), + AggType.DIRECT, false, collations, limit); } else { - boolean leafReturnFinalResult = hintOptions != null && Boolean.parseBoolean( - hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT)); - call.transformTo(createPlanWithLeafExchangeFinalAggregate(call, leafReturnFinalResult)); + boolean leafReturnFinalResult = + Boolean.parseBoolean(hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT)); + return createPlanWithLeafExchangeFinalAggregate(aggRel, leafReturnFinalResult, collations, limit); } } // TODO: Currently it only handles one WITHIN GROUP collation across all AggregateCalls. 
@Nullable - private static RelCollation extractWithInGroupCollation(Aggregate aggRel) { + private static RelCollation extractWithinGroupCollation(Aggregate aggRel) { for (AggregateCall aggCall : aggRel.getAggCallList()) { RelCollation collation = aggCall.getCollation(); if (!collation.getFieldCollations().isEmpty()) { @@ -138,55 +252,54 @@ private static RelCollation extractWithInGroupCollation(Aggregate aggRel) { * Use this group by optimization to skip leaf stage aggregation when aggregating at leaf level is not desired. Many * situation could be wasted effort to do group-by on leaf, eg: when cardinality of group by column is very high. */ - private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(RelOptRuleCall call, - @Nullable RelCollation collation) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(RelOptRuleCall call, Aggregate aggRel, + @Nullable RelCollation withinGroupCollation, @Nullable List collations, int limit) { RelNode input = aggRel.getInput(); // Create Project when there's none below the aggregate. if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) { - aggRel = (Aggregate) generateProjectUnderAggregate(call); + aggRel = (Aggregate) generateProjectUnderAggregate(call, aggRel); input = aggRel.getInput(); } ImmutableBitSet groupSet = aggRel.getGroupSet(); RelDistribution distribution = RelDistributions.hash(groupSet.asList()); RelNode exchange; - if (collation != null) { + if (withinGroupCollation != null) { // Insert a LogicalSort node between exchange and aggregate whe collation exists. 
- exchange = PinotLogicalSortExchange.create(input, distribution, collation, false, true); + exchange = PinotLogicalSortExchange.create(input, distribution, withinGroupCollation, false, true); } else { exchange = PinotLogicalExchange.create(input, distribution); } - return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT); + return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT, + false, collations, limit); } /** * Aggregate node will be split into LEAF + EXCHANGE + FINAL. * TODO: Add optional INTERMEDIATE stage to reduce hotspot. */ - private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call, - boolean leafReturnFinalResult) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(Aggregate aggRel, + boolean leafReturnFinalResult, @Nullable List collations, int limit) { // Create a LEAF aggregate. PinotLogicalAggregate leafAggRel = - new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult), AggType.LEAF, - leafReturnFinalResult); + new PinotLogicalAggregate(aggRel, aggRel.getInput(), buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult), + AggType.LEAF, leafReturnFinalResult, collations, limit); // Create an EXCHANGE node over the LEAF aggregate. PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel, RelDistributions.hash(ImmutableIntList.range(0, aggRel.getGroupCount()))); // Create a FINAL aggregate over the EXCHANGE. - return convertAggFromIntermediateInput(call, exchange, AggType.FINAL, leafReturnFinalResult); + return convertAggFromIntermediateInput(aggRel, exchange, AggType.FINAL, leafReturnFinalResult, collations, limit); } /** * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take * aggregate input as input. 
*/ - private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { - final Aggregate aggregate = call.rel(0); + private static RelNode generateProjectUnderAggregate(RelOptRuleCall call, Aggregate aggregate) { // --------------- MODIFIED --------------- final RelNode input = aggregate.getInput(); + // final Aggregate aggregate = call.rel(0); // final RelNode input = call.rel(1); // ------------- END MODIFIED ------------- @@ -230,9 +343,8 @@ private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { return relBuilder.build(); } - private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleCall call, - PinotLogicalExchange exchange, AggType aggType, boolean leafReturnFinalResult) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate convertAggFromIntermediateInput(Aggregate aggRel, PinotLogicalExchange exchange, + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { RelNode input = aggRel.getInput(); List projects = findImmediateProjects(input); @@ -269,7 +381,7 @@ private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleC } return new PinotLogicalAggregate(aggRel, exchange, ImmutableBitSet.range(groupCount), aggCalls, aggType, - leafReturnFinalResult); + leafReturnFinalResult, collations, limit); } private static List buildAggCalls(Aggregate aggRel, AggType aggType, boolean leafReturnFinalResult) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java index e831e7460a52..80e524e11f0e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java @@ -136,7 +136,9 @@ private PinotQueryRuleSets() { PinotSingleValueAggregateRemoveRule.INSTANCE, 
PinotJoinExchangeNodeInsertRule.INSTANCE, - PinotAggregateExchangeNodeInsertRule.INSTANCE, + PinotAggregateExchangeNodeInsertRule.SortProjectAggregate.INSTANCE, + PinotAggregateExchangeNodeInsertRule.SortAggregate.INSTANCE, + PinotAggregateExchangeNodeInsertRule.WithoutSort.INSTANCE, PinotWindowExchangeNodeInsertRule.INSTANCE, PinotSetOpExchangeNodeInsertRule.INSTANCE, diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index a20b2479d4f0..fdd19a9aef23 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -29,7 +29,6 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.query.planner.logical.RexExpression; -import org.apache.pinot.query.planner.plannode.SortNode; import org.apache.pinot.spi.utils.BooleanUtils; import org.apache.pinot.spi.utils.ByteArray; import org.apache.pinot.sql.parsers.ParserUtils; @@ -96,8 +95,7 @@ public static List convertAggregateList(List groupByList return expressions; } - public static List convertOrderByList(SortNode node, PinotQuery pinotQuery) { - List collations = node.getCollations(); + public static List convertOrderByList(List collations, PinotQuery pinotQuery) { List orderByExpressions = new ArrayList<>(collations.size()); for (RelFieldCollation collation : collations) { orderByExpressions.add(convertOrderBy(collation, pinotQuery)); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java index 611d4417259b..6ae02da45fc9 100644 --- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java @@ -147,6 +147,12 @@ public PlanNode visitAggregate(AggregateNode node, PlanNode context) { if (node.isLeafReturnFinalResult() != otherNode.isLeafReturnFinalResult()) { return null; } + if (!node.getCollations().equals(otherNode.getCollations())) { + return null; + } + if (node.getLimit() != otherNode.getLimit()) { + return null; + } List children = mergeChildren(node, context); if (children == null) { return null; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java index 55813264ffb0..61cf5d5be626 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java @@ -195,7 +195,9 @@ public Boolean visitAggregate(AggregateNode node1, PlanNode node2) { && Objects.equals(node1.getFilterArgs(), that.getFilterArgs()) && Objects.equals(node1.getGroupKeys(), that.getGroupKeys()) && node1.getAggType() == that.getAggType() - && node1.isLeafReturnFinalResult() == that.isLeafReturnFinalResult(); + && node1.isLeafReturnFinalResult() == that.isLeafReturnFinalResult() + && Objects.equals(node1.getCollations(), that.getCollations()) + && node1.getLimit() == that.getLimit(); } @Override diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index 38170116126a..3f5ab2261e0c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -264,7 +264,7 @@ private AggregateNode convertLogicalAggregate(PinotLogicalAggregate node) { } return new AggregateNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()), convertInputs(node.getInputs()), functionCalls, filterArgs, node.getGroupSet().asList(), node.getAggType(), - node.isLeafReturnFinalResult()); + node.isLeafReturnFinalResult(), node.getCollations(), node.getLimit()); } private ProjectNode convertLogicalProject(LogicalProject node) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java index be4a6d9fb87d..5e6fda1e1b6e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java @@ -20,6 +20,8 @@ import java.util.List; import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; @@ -31,15 +33,22 @@ public class AggregateNode extends BasePlanNode { private final AggType _aggType; private final boolean _leafReturnFinalResult; + // The following fields are set when group trim is enabled, and are extracted from the Sort on top of this Aggregate. + // The group trim behavior at leaf stage is shared with single-stage engine. 
+ private final List _collations; + private final int _limit; + public AggregateNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, List inputs, List aggCalls, List filterArgs, List groupKeys, AggType aggType, - boolean leafReturnFinalResult) { + boolean leafReturnFinalResult, @Nullable List collations, int limit) { super(stageId, dataSchema, nodeHint, inputs); _aggCalls = aggCalls; _filterArgs = filterArgs; _groupKeys = groupKeys; _aggType = aggType; _leafReturnFinalResult = leafReturnFinalResult; + _collations = collations != null ? collations : List.of(); + _limit = limit; } public List getAggCalls() { @@ -62,6 +71,14 @@ public boolean isLeafReturnFinalResult() { return _leafReturnFinalResult; } + public List getCollations() { + return _collations; + } + + public int getLimit() { + return _limit; + } + @Override public String explain() { return "AGGREGATE_" + _aggType; @@ -75,7 +92,7 @@ public T visit(PlanNodeVisitor visitor, C context) { @Override public PlanNode withInputs(List inputs) { return new AggregateNode(_stageId, _dataSchema, _nodeHint, inputs, _aggCalls, _filterArgs, _groupKeys, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } @Override @@ -90,14 +107,15 @@ public boolean equals(Object o) { return false; } AggregateNode that = (AggregateNode) o; - return Objects.equals(_aggCalls, that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && Objects.equals( - _groupKeys, that._groupKeys) && _aggType == that._aggType - && _leafReturnFinalResult == that._leafReturnFinalResult; + return _leafReturnFinalResult == that._leafReturnFinalResult && _limit == that._limit && Objects.equals(_aggCalls, + that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && Objects.equals(_groupKeys, that._groupKeys) + && _aggType == that._aggType && Objects.equals(_collations, that._collations); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType, 
_leafReturnFinalResult); + return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType, _leafReturnFinalResult, + _collations, _limit); } /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java index abd474ebce3e..0f6851418925 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java @@ -87,7 +87,8 @@ private static AggregateNode deserializeAggregateNode(Plan.PlanNode protoNode) { return new AggregateNode(protoNode.getStageId(), extractDataSchema(protoNode), extractNodeHint(protoNode), extractInputs(protoNode), convertFunctionCalls(protoAggregateNode.getAggCallsList()), protoAggregateNode.getFilterArgsList(), protoAggregateNode.getGroupKeysList(), - convertAggType(protoAggregateNode.getAggType()), protoAggregateNode.getLeafReturnFinalResult()); + convertAggType(protoAggregateNode.getAggType()), protoAggregateNode.getLeafReturnFinalResult(), + convertCollations(protoAggregateNode.getCollationsList()), protoAggregateNode.getLimit()); } private static FilterNode deserializeFilterNode(Plan.PlanNode protoNode) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java index 65ccb13b2cae..e7862173e749 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java @@ -98,6 +98,8 @@ public Void visitAggregate(AggregateNode node, Plan.PlanNode.Builder builder) { .addAllGroupKeys(node.getGroupKeys()) .setAggType(convertAggType(node.getAggType())) 
.setLeafReturnFinalResult(node.isLeafReturnFinalResult()) + .addAllCollations(convertCollations(node.getCollations())) + .setLimit(node.getLimit()) .build(); builder.setAggregateNode(aggregateNode); return null; diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json index 63a69f5e8ecb..8e513066d904 100644 --- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json @@ -249,6 +249,55 @@ "\n LogicalTableScan(table=[[default, a]])", "\n" ] + }, + { + "description": "SQL hint based group by optimization with partitioned aggregated values and group trim enabled", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_leaf_return_final_result='true', is_enable_group_trim='true') */ col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10", + "output": [ + "Execution Plan", + "\nLogicalSort(sort0=[$1], dir0=[DESC], offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", + "\n LogicalSort(sort0=[$1], dir0=[DESC], fetch=[10])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], leafReturnFinalResult=[true], collations=[[1 DESC]], limit=[10])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], leafReturnFinalResult=[true], collations=[[1 DESC]], limit=[10])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] + }, + { + "description": "SQL hint based group by optimization with group trim enabled without returning group key", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_enable_group_trim='true') */ COUNT(DISTINCT col2) AS cnt FROM a WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10", + 
"output": [ + "Execution Plan", + "\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", + "\n LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])", + "\n LogicalProject(cnt=[$1])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], collations=[[1 DESC]], limit=[10])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], collations=[[1 DESC]], limit=[10])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] + }, + { + "description": "SQL hint based distinct optimization with group trim enabled", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_enable_group_trim='true') */ DISTINCT col1, col2 FROM a WHERE col3 >= 0 LIMIT 10", + "output": [ + "Execution Plan", + "\nLogicalSort(offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])", + "\n LogicalSort(fetch=[10])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL], collations=[[]], limit=[10])", + "\n PinotLogicalExchange(distribution=[hash[0, 1]])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[LEAF], collations=[[]], limit=[10])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] } ] } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java index bd58b7f64f04..1ac11809aa26 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java +++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.request.DataSource; @@ -71,22 +72,29 @@ static void walkPlanNode(PlanNode node, ServerPlanRequestContext context) { public Void visitAggregate(AggregateNode node, ServerPlanRequestContext context) { if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); - if (pinotQuery.getGroupByList() == null) { - List groupByList = CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery); + List groupByList = CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery); + if (!groupByList.isEmpty()) { pinotQuery.setGroupByList(groupByList); - pinotQuery.setSelectList( - CalciteRexExpressionParser.convertAggregateList(groupByList, node.getAggCalls(), node.getFilterArgs(), - pinotQuery)); - if (node.getAggType() == AggregateNode.AggType.DIRECT) { - pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, - "true"); - } else if (node.isLeafReturnFinalResult()) { - pinotQuery.putToQueryOptions( - CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true"); + } + pinotQuery.setSelectList( + CalciteRexExpressionParser.convertAggregateList(groupByList, node.getAggCalls(), node.getFilterArgs(), + pinotQuery)); + if (node.getAggType() == AggregateNode.AggType.DIRECT) { + pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, "true"); + } else if (node.isLeafReturnFinalResult()) { + pinotQuery.putToQueryOptions( + 
CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true"); + } + int limit = node.getLimit(); + if (limit > 0) { + List collations = node.getCollations(); + if (!collations.isEmpty()) { + pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery)); } - // there cannot be any more modification of PinotQuery post agg, thus this is the last one possible. - context.setLeafStageBoundaryNode(node); + pinotQuery.setLimit(limit); } + // There cannot be any more modification of PinotQuery post agg, thus this is the last one possible. + context.setLeafStageBoundaryNode(node); } return null; } @@ -193,8 +201,9 @@ public Void visitSort(SortNode node, ServerPlanRequestContext context) { if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); if (pinotQuery.getOrderByList() == null) { - if (!node.getCollations().isEmpty()) { - pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(node, pinotQuery)); + List collations = node.getCollations(); + if (!collations.isEmpty()) { + pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery)); } if (node.getFetch() >= 0) { pinotQuery.setLimit(node.getFetch()); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index f7f56e0ccb6e..b2e73f226a3a 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -273,7 +273,7 @@ private AggregateOperator getOperator(DataSchema resultSchema, List filterArgs, List groupKeys, PlanNode.NodeHint nodeHint) { return new AggregateOperator(OperatorTestUtil.getTracingContext(), _input, new AggregateNode(-1, 
resultSchema, nodeHint, List.of(), aggCalls, filterArgs, groupKeys, AggType.DIRECT, - false)); + false, null, 0)); } private AggregateOperator getOperator(DataSchema resultSchema, List aggCalls, diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java index fc7ebba0b4cb..05ccf5762191 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java @@ -152,7 +152,7 @@ private static MultiStageOperator getAggregateOperator() { new DataSchema(new String[]{"group", "sum"}, new DataSchema.ColumnDataType[]{INT, DOUBLE}); return new AggregateOperator(OperatorTestUtil.getTracingContext(), input, new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), aggCalls, filterArgs, groupKeys, - AggregateNode.AggType.DIRECT, false)); + AggregateNode.AggType.DIRECT, false, null, 0)); } private static MultiStageOperator getHashJoinOperator() { diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json index e7c2ca375700..e8d30ed40905 100644 --- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json +++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json @@ -321,6 +321,14 @@ "description": "aggregate with skip intermediate stage hint (via hint option is_partitioned_by_group_by_keys)", "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.val) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num" }, + { + "description": "aggregate with skip intermediate stage and enable group trim hint", + "sql": "SELECT /*+ 
aggOptions(is_partitioned_by_group_by_keys='true', is_enable_group_trim='true') */ num, COUNT(*), SUM(val), SUM(num), COUNT(DISTINCT val) FROM {tbl1} WHERE val >= 0 AND name != 'a' GROUP BY num ORDER BY COUNT(*) DESC, num LIMIT 1" + }, + { + "description": "distinct with enable group trim hint", + "sql": "SELECT /*+ aggOptions(is_enable_group_trim='true') */ DISTINCT num, val FROM {tbl1} WHERE val >= 0 AND name != 'a' ORDER BY val DESC, num LIMIT 1" + }, { "description": "join with pre-partitioned left and right tables", "sql": "SELECT {tbl1}.num, {tbl1}.val, {tbl2}.data FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num WHERE {tbl2}.data > 0" From 9b9606898872105341b5a218c81a61f8d60d76e4 Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" <17555551+Jackie-Jiang@users.noreply.github.com> Date: Mon, 30 Dec 2024 23:02:46 -0800 Subject: [PATCH 5/5] Optimize MergeEqInFilterOptimizer by reducing the hash computation of Expression (#14732) --- .../filter/MergeEqInFilterOptimizer.java | 124 +++++++++++------- 1 file changed, 74 insertions(+), 50 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/MergeEqInFilterOptimizer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/MergeEqInFilterOptimizer.java index 6836f8022617..5104587322ec 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/MergeEqInFilterOptimizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/filter/MergeEqInFilterOptimizer.java @@ -18,16 +18,17 @@ */ package org.apache.pinot.core.query.optimizer.filter; +import com.google.common.collect.Maps; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; 
-import java.util.Set; import javax.annotation.Nullable; import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.ExpressionType; import org.apache.pinot.common.request.Function; +import org.apache.pinot.common.request.context.RequestContextUtils; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.spi.data.Schema; import org.apache.pinot.sql.FilterKind; @@ -61,9 +62,10 @@ private Expression optimize(Expression filterExpression) { String operator = function.getOperator(); if (operator.equals(FilterKind.OR.name())) { List children = function.getOperands(); - Map> valuesMap = new HashMap<>(); - List newChildren = new ArrayList<>(); - boolean recreateFilter = false; + // Key is the lhs of the EQ/IN predicate, value is the map from string representation of the value to the value + Map> valuesMap = new HashMap<>(); + List newChildren = new ArrayList<>(children.size()); + boolean[] recreateFilter = new boolean[1]; // Iterate over all the child filters to merge EQ and IN predicates for (Expression child : children) { @@ -80,52 +82,62 @@ private Expression optimize(Expression filterExpression) { List operands = childFunction.getOperands(); Expression lhs = operands.get(0); Expression value = operands.get(1); - Set values = valuesMap.get(lhs); - if (values == null) { - values = new HashSet<>(); - values.add(value); - valuesMap.put(lhs, values); - } else { - values.add(value); - // Recreate filter when multiple predicates can be merged - recreateFilter = true; - } + // Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This is + // consistent with how server handles predicates. 
+ String stringValue = RequestContextUtils.getStringValue(value); + valuesMap.compute(lhs, (k, v) -> { + if (v == null) { + Map values = new HashMap<>(); + values.put(stringValue, value); + return values; + } else { + v.put(stringValue, value); + // Recreate filter when multiple predicates can be merged + recreateFilter[0] = true; + return v; + } + }); } else if (childOperator.equals(FilterKind.IN.name())) { List operands = childFunction.getOperands(); Expression lhs = operands.get(0); - Set inPredicateValuesSet = new HashSet<>(); - int numOperands = operands.size(); - for (int i = 1; i < numOperands; i++) { - inPredicateValuesSet.add(operands.get(i)); - } - int numUniqueValues = inPredicateValuesSet.size(); - if (numUniqueValues == 1 || numUniqueValues != numOperands - 1) { - // Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), - // or values can be de-duplicated - recreateFilter = true; - } - Set values = valuesMap.get(lhs); - if (values == null) { - valuesMap.put(lhs, inPredicateValuesSet); - } else { - values.addAll(inPredicateValuesSet); - // Recreate filter when multiple predicates can be merged - recreateFilter = true; - } + valuesMap.compute(lhs, (k, v) -> { + if (v == null) { + Map values = getInValues(operands); + int numUniqueValues = values.size(); + if (numUniqueValues == 1 || numUniqueValues != operands.size() - 1) { + // Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or + // values can be de-duplicated + recreateFilter[0] = true; + } + return values; + } else { + int numOperands = operands.size(); + for (int i = 1; i < numOperands; i++) { + Expression value = operands.get(i); + // Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This + // is consistent with how server handles predicates. 
+ String stringValue = RequestContextUtils.getStringValue(value); + v.put(stringValue, value); + } + // Recreate filter when multiple predicates can be merged + recreateFilter[0] = true; + return v; + } + }); } else { newChildren.add(child); } } } - if (recreateFilter) { + if (recreateFilter[0]) { if (newChildren.isEmpty() && valuesMap.size() == 1) { // Single range without other filters - Map.Entry> entry = valuesMap.entrySet().iterator().next(); - return getFilterExpression(entry.getKey(), entry.getValue()); + Map.Entry> entry = valuesMap.entrySet().iterator().next(); + return getFilterExpression(entry.getKey(), entry.getValue().values()); } else { - for (Map.Entry> entry : valuesMap.entrySet()) { - newChildren.add(getFilterExpression(entry.getKey(), entry.getValue())); + for (Map.Entry> entry : valuesMap.entrySet()) { + newChildren.add(getFilterExpression(entry.getKey(), entry.getValue().values())); } function.setOperands(newChildren); return filterExpression; @@ -138,17 +150,12 @@ private Expression optimize(Expression filterExpression) { return filterExpression; } else if (operator.equals(FilterKind.IN.name())) { List operands = function.getOperands(); - Expression lhs = operands.get(0); - Set values = new HashSet<>(); - int numOperands = operands.size(); - for (int i = 1; i < numOperands; i++) { - values.add(operands.get(i)); - } + Map values = getInValues(operands); int numUniqueValues = values.size(); - if (numUniqueValues == 1 || numUniqueValues != numOperands - 1) { - // Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or values - // can be de-duplicated - return getFilterExpression(lhs, values); + if (numUniqueValues == 1 || numUniqueValues != operands.size() - 1) { + // Recreate filter when the IN predicate contains only 1 value (can be rewritten to EQ predicate), or values can + // be de-duplicated + return getFilterExpression(operands.get(0), values.values()); } else { return filterExpression; } @@ 
-157,10 +164,27 @@ private Expression optimize(Expression filterExpression) { } } + /** + * Helper method to get the values from the IN predicate. Returns a map from string representation of the value to the + * value. + */ + private Map getInValues(List operands) { + int numOperands = operands.size(); + Map values = Maps.newHashMapWithExpectedSize(numOperands - 1); + for (int i = 1; i < numOperands; i++) { + Expression value = operands.get(i); + // Use string value to de-duplicate the values to prevent the overhead of Expression.hashCode(). This is + // consistent with how server handles predicates. + String stringValue = RequestContextUtils.getStringValue(value); + values.put(stringValue, value); + } + return values; + } + /** * Helper method to construct a EQ or IN predicate filter Expression from the given lhs and values. */ - private static Expression getFilterExpression(Expression lhs, Set values) { + private static Expression getFilterExpression(Expression lhs, Collection values) { int numValues = values.size(); if (numValues == 1) { return RequestUtils.getFunctionExpression(FilterKind.EQUALS.name(), lhs, values.iterator().next());