From f0b0c1f87cd52378ea49bfda6ec4d42c9f30af2c Mon Sep 17 00:00:00 2001
From: Yash Mayya <yash.mayya@gmail.com>
Date: Mon, 20 Jan 2025 16:48:03 +0530
Subject: [PATCH] Rework MSE query throttling to take into account estimated
 number of threads used by a query

---
 .../MultiStageBrokerRequestHandler.java       |   5 +-
 .../MultiStageQueryThrottler.java             |  82 +++++----
 .../MultiStageQueryThrottlerTest.java         | 171 ++++++++++--------
 .../concurrency/AdjustableSemaphore.java      |  28 ++-
 .../MultiStageEngineIntegrationTest.java      |  27 ++-
 .../planner/physical/DispatchableSubPlan.java |  35 ++++
 .../pinot/spi/utils/CommonConstants.java      |   6 +-
 7 files changed, 221 insertions(+), 133 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 2e75b6dd9018..e0661eb62c00 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -231,10 +231,11 @@ protected BrokerResponse handleRequest(long requestId, String query, SqlNodeAndO
     }
 
     Timer queryTimer = new Timer(queryTimeoutMs);
+    int estimatedNumQueryThreads = dispatchableSubPlan.getEstimatedNumQueryThreads();
     try {
       // It's fine to block in this thread because we use a separate thread pool from the main Jersey server to process
       // these requests.
-      if (!_queryThrottler.tryAcquire(queryTimeoutMs, TimeUnit.MILLISECONDS)) {
+      if (!_queryThrottler.tryAcquire(estimatedNumQueryThreads, queryTimeoutMs, TimeUnit.MILLISECONDS)) {
         LOGGER.warn("Timed out waiting to execute request {}: {}", requestId, query);
         requestContext.setErrorCode(QueryException.EXECUTION_TIMEOUT_ERROR_CODE);
         return new BrokerResponseNative(QueryException.EXECUTION_TIMEOUT_ERROR);
@@ -311,7 +312,7 @@ protected BrokerResponse handleRequest(long requestId, String query, SqlNodeAndO
 
       return brokerResponse;
     } finally {
-      _queryThrottler.release();
+      _queryThrottler.release(estimatedNumQueryThreads);
     }
   }
 
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottler.java
index a6ca713b19f4..b9852425544c 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottler.java
@@ -37,8 +37,14 @@
 
 /**
  * This class helps limit the number of multi-stage queries being executed concurrently. Note that the cluster
- * configuration is a "per server" value and the broker currently simply assumes that a query will be across all
- * servers. Another assumption here is that queries are evenly distributed across brokers.
+ * configuration is a "per server" value and the broker currently computes the max server query threads as
+ * <em>CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS * numServers / numBrokers</em>. Note that the config value,
+ * number of servers, and number of brokers are all dynamically updated here.
+ * <p>
+ * Another assumption made here is that queries are evenly distributed across brokers.
+ * <p>
+ * This is designed to limit the number of multi-stage queries being concurrently executed across a cluster and is not
+ * intended to prevent individual large queries from being executed.
  */
 public class MultiStageQueryThrottler implements ClusterChangeHandler {
 
@@ -50,10 +56,10 @@ public class MultiStageQueryThrottler implements ClusterChangeHandler {
   private int _numBrokers;
   private int _numServers;
   /**
-   * If _maxConcurrentQueries is <= 0, it means that the cluster is not configured to limit the number of multi-stage
+   * If _maxServerQueryThreads is <= 0, it means that the cluster is not configured to limit the number of multi-stage
    * queries that can be executed concurrently. In this case, we should not block the query.
    */
-  private int _maxConcurrentQueries;
+  private int _maxServerQueryThreads;
   private AdjustableSemaphore _semaphore;
 
   @Override
@@ -63,11 +69,11 @@ public void init(HelixManager helixManager) {
     _helixConfigScope = new HelixConfigScopeBuilder(HelixConfigScope.ConfigScopeProperty.CLUSTER).forCluster(
         _helixManager.getClusterName()).build();
 
-    _maxConcurrentQueries = Integer.parseInt(
+    _maxServerQueryThreads = Integer.parseInt(
         _helixAdmin.getConfig(_helixConfigScope,
-                Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES))
-            .getOrDefault(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES,
-                CommonConstants.Helix.DEFAULT_MAX_CONCURRENT_MULTI_STAGE_QUERIES));
+                Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS))
+            .getOrDefault(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS,
+                CommonConstants.Helix.DEFAULT_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS));
 
     List<String> clusterInstances = _helixAdmin.getInstancesInCluster(_helixManager.getClusterName());
     _numBrokers = Math.max(1, (int) clusterInstances.stream()
@@ -77,36 +83,49 @@ public void init(HelixManager helixManager) {
         .filter(instance -> instance.startsWith(CommonConstants.Helix.PREFIX_OF_SERVER_INSTANCE))
         .count());
 
-    if (_maxConcurrentQueries > 0) {
-      _semaphore = new AdjustableSemaphore(Math.max(1, _maxConcurrentQueries * _numServers / _numBrokers), true);
+    if (_maxServerQueryThreads > 0) {
+      _semaphore = new AdjustableSemaphore(Math.max(1, _maxServerQueryThreads * _numServers / _numBrokers), true);
     }
   }
 
   /**
    * Returns true if the query can be executed (waiting until it can be executed if necessary), false otherwise.
    * <p>
-   * {@link #release()} should be called after the query is done executing. It is the responsibility of the caller to
-   * ensure that {@link #release()} is called exactly once for each call to this method.
+   * {@link #release(int)} should be called after the query is done executing. It is the responsibility of the caller to
+   * ensure that {@link #release(int)} is called exactly once for each call to this method.
    *
+   * @param numQueryThreads the estimated number of query server threads
    * @param timeout the maximum time to wait
    * @param unit the time unit of the timeout argument
+   *
    * @throws InterruptedException if the current thread is interrupted
+   * @throws RuntimeException if the query can never be dispatched due to the number of estimated query server threads
+   * being greater than the maximum number of server query threads calculated on the basis of
+   * <em>CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS * numServers / numBrokers</em>
    */
-  public boolean tryAcquire(long timeout, TimeUnit unit)
+  public boolean tryAcquire(int numQueryThreads, long timeout, TimeUnit unit)
       throws InterruptedException {
-    if (_maxConcurrentQueries <= 0) {
+    if (_maxServerQueryThreads <= 0) {
       return true;
     }
-    return _semaphore.tryAcquire(timeout, unit);
+
+    if (numQueryThreads > _semaphore.getTotalPermits()) {
+      throw new RuntimeException("Can't dispatch query because the estimated number of server threads for this query is "
+          + "too large for the configured value of '"
+          + CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS + "'. Consider increasing the "
+          + "value of this configuration");
+    }
+
+    return _semaphore.tryAcquire(numQueryThreads, timeout, unit);
   }
 
   /**
    * Should be called after the query is done executing. It is the responsibility of the caller to ensure that this
-   * method is called exactly once for each call to {@link #tryAcquire(long, TimeUnit)}.
+   * method is called exactly once for each call to {@link #tryAcquire(int, long, TimeUnit)}.
    */
-  public void release() {
-    if (_maxConcurrentQueries > 0) {
-      _semaphore.release();
+  public void release(int numQueryThreads) {
+    if (_maxServerQueryThreads > 0) {
+      _semaphore.release(numQueryThreads);
     }
   }
 
@@ -128,23 +147,22 @@ public void processClusterChange(HelixConstants.ChangeType changeType) {
       if (numBrokers != _numBrokers || numServers != _numServers) {
         _numBrokers = numBrokers;
         _numServers = numServers;
-        if (_maxConcurrentQueries > 0) {
-          _semaphore.setPermits(Math.max(1, _maxConcurrentQueries * _numServers / _numBrokers));
+        if (_maxServerQueryThreads > 0) {
+          _semaphore.setPermits(Math.max(1, _maxServerQueryThreads * _numServers / _numBrokers));
         }
       }
     } else {
-      int maxConcurrentQueries = Integer.parseInt(
-          _helixAdmin.getConfig(_helixConfigScope,
-                  Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES))
-              .getOrDefault(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES,
-                  CommonConstants.Helix.DEFAULT_MAX_CONCURRENT_MULTI_STAGE_QUERIES));
+      int maxServerQueryThreads = Integer.parseInt(_helixAdmin.getConfig(_helixConfigScope,
+              Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS))
+          .getOrDefault(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS,
+              CommonConstants.Helix.DEFAULT_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS));
 
-      if (_maxConcurrentQueries == maxConcurrentQueries) {
+      if (_maxServerQueryThreads == maxServerQueryThreads) {
         return;
       }
 
-      if (_maxConcurrentQueries <= 0 && maxConcurrentQueries > 0
-          || _maxConcurrentQueries > 0 && maxConcurrentQueries <= 0) {
+      if (_maxServerQueryThreads <= 0 && maxServerQueryThreads > 0
+          || _maxServerQueryThreads > 0 && maxServerQueryThreads <= 0) {
         // This operation isn't safe to do while queries are running so we require a restart of the broker for this
         // change to take effect.
         LOGGER.warn("Enabling or disabling limitation of the maximum number of multi-stage queries running "
@@ -152,10 +170,10 @@ public void processClusterChange(HelixConstants.ChangeType changeType) {
         return;
       }
 
-      if (maxConcurrentQueries > 0) {
-        _semaphore.setPermits(Math.max(1, maxConcurrentQueries * _numServers / _numBrokers));
+      if (maxServerQueryThreads > 0) {
+        _semaphore.setPermits(Math.max(1, maxServerQueryThreads * _numServers / _numBrokers));
       }
-      _maxConcurrentQueries = maxConcurrentQueries;
+      _maxServerQueryThreads = maxServerQueryThreads;
     }
   }
 
diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottlerTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottlerTest.java
index fe2a5a124006..fcd5e3d7d0ff 100644
--- a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottlerTest.java
+++ b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/MultiStageQueryThrottlerTest.java
@@ -52,8 +52,9 @@ public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
     when(_helixManager.getClusterManagmentTool()).thenReturn(_helixAdmin);
     when(_helixManager.getClusterName()).thenReturn("testCluster");
-    when(_helixAdmin.getConfig(any(), any())).thenReturn(
-        Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "4"));
+    when(_helixAdmin.getConfig(any(),
+        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS)))
+    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "4"));
     when(_helixAdmin.getInstancesInCluster(eq("testCluster"))).thenReturn(
         List.of("Broker_0", "Broker_1", "Server_0", "Server_1"));
   }
@@ -70,9 +71,9 @@ public void testBasicAcquireRelease()
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 3);
-    _multiStageQueryThrottler.release();
+    _multiStageQueryThrottler.release(1);
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
   }
 
@@ -80,30 +81,31 @@ public void testBasicAcquireRelease()
   public void testAcquireTimeout()
       throws Exception {
     when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))).thenReturn(
-        Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "2"));
+        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS)))
+    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "2"));
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 1);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
   public void testDisabledThrottling()
       throws Exception {
-    when(_helixAdmin.getConfig(any(), any())).thenReturn(
-        Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "-1"));
+    when(_helixAdmin.getConfig(any(),
+        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS)))
+    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "-1"));
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
     // If maxConcurrentQueries is <= 0, the throttling mechanism should be "disabled" and any attempt to acquire should
     // succeed
     for (int i = 0; i < 100; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(10, 100, TimeUnit.MILLISECONDS));
     }
   }
 
@@ -113,10 +115,10 @@ public void testIncreaseNumBrokers()
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Increase the number of brokers
@@ -126,13 +128,13 @@ public void testIncreaseNumBrokers()
 
     // Verify that the number of permits on this broker have been reduced to account for the new brokers
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), -2);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
 
-    for (int i = 0; i < 4; i++) {
-      _multiStageQueryThrottler.release();
+    for (int i = 0; i < 2; i++) {
+      _multiStageQueryThrottler.release(2);
     }
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 2);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
@@ -141,10 +143,10 @@ public void testDecreaseNumBrokers()
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Decrease the number of brokers
@@ -153,8 +155,8 @@ public void testDecreaseNumBrokers()
 
     // Ensure that the permits from the removed broker are added to this one.
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
-    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 3);
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(3, 100, TimeUnit.MILLISECONDS));
+    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 1);
   }
 
   @Test
@@ -163,10 +165,10 @@ public void testIncreaseNumServers()
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Increase the number of servers
@@ -176,8 +178,8 @@ public void testIncreaseNumServers()
 
     // Ensure that the permits on this broker are increased to account for the new server
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 2);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
-    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 1);
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
+    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
   }
 
   @Test
@@ -186,10 +188,10 @@ public void testDecreaseNumServers()
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Decrease the number of servers
@@ -198,63 +200,61 @@ public void testDecreaseNumServers()
 
     // Verify that the number of permits on this broker have been reduced to account for the removed server
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), -2);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
 
-    for (int i = 0; i < 4; i++) {
-      _multiStageQueryThrottler.release();
+    for (int i = 0; i < 2; i++) {
+      _multiStageQueryThrottler.release(2);
     }
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 2);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
-  public void testIncreaseMaxConcurrentQueries()
+  public void testIncreaseMaxServerQueryThreads()
       throws Exception {
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Increase the value of cluster config maxConcurrentQueries
-    when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES))))
-        .thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "8"));
+    when(_helixAdmin.getConfig(any(), any()))
+        .thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "8"));
     _multiStageQueryThrottler.processClusterChange(HelixConstants.ChangeType.CLUSTER_CONFIG);
 
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
-  public void testDecreaseMaxConcurrentQueries()
+  public void testDecreaseMaxServerQueryThreads()
       throws Exception {
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    for (int i = 0; i < 2; i++) {
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     }
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
 
     // Decrease the value of cluster config maxConcurrentQueries
-    when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))
-    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "3"));
+    when(_helixAdmin.getConfig(any(), any())).thenReturn(
+        Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "3"));
     _multiStageQueryThrottler.processClusterChange(HelixConstants.ChangeType.CLUSTER_CONFIG);
 
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), -1);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
 
-    for (int i = 0; i < 4; i++) {
-      _multiStageQueryThrottler.release();
+    for (int i = 0; i < 2; i++) {
+      _multiStageQueryThrottler.release(2);
     }
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 3);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
@@ -266,63 +266,78 @@ public void testEnabledToDisabledTransitionDisallowed()
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
 
     // Disable the throttling mechanism via cluster config change
-    when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))
-    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "-1"));
+    when(_helixAdmin.getConfig(any(), any())).thenReturn(
+        Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "-1"));
     _multiStageQueryThrottler.processClusterChange(HelixConstants.ChangeType.CLUSTER_CONFIG);
 
     // Should not be allowed to disable the throttling mechanism if it is enabled during startup
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
 
     for (int i = 0; i < 4; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
     }
     Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(1, 100, TimeUnit.MILLISECONDS));
   }
 
   @Test
   public void testDisabledToEnabledTransitionDisallowed()
       throws Exception {
     when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))
-    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "-1"));
+        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS)))
+    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "-1"));
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    // If maxConcurrentQueries is <= 0, the throttling mechanism should be "disabled" and any attempt to acquire should
+    // If maxServerQueryThreads is <= 0, the throttling mechanism should be "disabled" and any attempt to acquire should
     // succeed
     for (int i = 0; i < 100; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(10, 100, TimeUnit.MILLISECONDS));
     }
 
     // Enable the throttling mechanism via cluster config change
     when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))
-    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "4"));
+        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS)))
+    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "4"));
     _multiStageQueryThrottler.processClusterChange(HelixConstants.ChangeType.CLUSTER_CONFIG);
 
     // Should not be allowed to enable the throttling mechanism if it is disabled during startup
     for (int i = 0; i < 100; i++) {
-      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+      Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(10, 100, TimeUnit.MILLISECONDS));
     }
   }
 
   @Test
-  public void testMaxConcurrentQueriesSmallerThanNumBrokers()
+  public void testLowMaxServerQueryThreads() {
+    _multiStageQueryThrottler = new MultiStageQueryThrottler();
+    _multiStageQueryThrottler.init(_helixManager);
+
+    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
+    // Thrown if the estimated number of query threads is greater than the number of available permits to this broker
+    Assert.assertThrows(RuntimeException.class,
+        () -> _multiStageQueryThrottler.tryAcquire(10, 100, TimeUnit.MILLISECONDS));
+  }
+
+  @Test
+  public void testAcquireReleaseWithDifferentQuerySizes()
       throws Exception {
-    when(_helixAdmin.getConfig(any(),
-        eq(Collections.singletonList(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES)))
-    ).thenReturn(Map.of(CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES, "2"));
-    when(_helixAdmin.getInstancesInCluster(eq("testCluster"))).thenReturn(
-        List.of("Broker_0", "Broker_1", "Broker_2", "Broker_3", "Server_0", "Server_1"));
     _multiStageQueryThrottler = new MultiStageQueryThrottler();
     _multiStageQueryThrottler.init(_helixManager);
 
-    // The total permits should be capped at 1 even though maxConcurrentQueries * numServers / numBrokers is 0.
-    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 1);
-    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
-    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 0);
-    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(100, TimeUnit.MILLISECONDS));
+    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 4);
+
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
+    Assert.assertEquals(_multiStageQueryThrottler.availablePermits(), 2);
+
+    // A query with more than 2 threads shouldn't be permitted but a query with 2 threads should be permitted
+    Assert.assertFalse(_multiStageQueryThrottler.tryAcquire(3, 100, TimeUnit.MILLISECONDS));
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(2, 100, TimeUnit.MILLISECONDS));
+
+    // Release the permits
+    _multiStageQueryThrottler.release(2);
+    _multiStageQueryThrottler.release(2);
+
+    // The query with more than 2 threads should now be permitted
+    Assert.assertTrue(_multiStageQueryThrottler.tryAcquire(3, 100, TimeUnit.MILLISECONDS));
   }
 }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/concurrency/AdjustableSemaphore.java b/pinot-common/src/main/java/org/apache/pinot/common/concurrency/AdjustableSemaphore.java
index 2bbc25e42a0d..f0e405d7fc84 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/concurrency/AdjustableSemaphore.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/concurrency/AdjustableSemaphore.java
@@ -20,6 +20,7 @@
 
 import com.google.common.base.Preconditions;
 import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicInteger;
 
 
 /**
@@ -27,25 +28,36 @@
  */
 public class AdjustableSemaphore extends Semaphore {
 
-  private int _totalPermits;
+  private final AtomicInteger _totalPermits;
 
   public AdjustableSemaphore(int permits) {
     super(permits);
-    _totalPermits = permits;
+    _totalPermits = new AtomicInteger(permits);
   }
 
   public AdjustableSemaphore(int permits, boolean fair) {
     super(permits, fair);
-    _totalPermits = permits;
+    _totalPermits = new AtomicInteger(permits);
   }
 
+  /**
+   * Sets the total number of permits to the given value without blocking.
+   */
   public void setPermits(int permits) {
     Preconditions.checkArgument(permits > 0, "Permits must be a positive integer");
-    if (permits < _totalPermits) {
-      reducePermits(_totalPermits - permits);
-    } else if (permits > _totalPermits) {
-      release(permits - _totalPermits);
+    if (permits < _totalPermits.get()) {
+      reducePermits(_totalPermits.get() - permits);
+    } else if (permits > _totalPermits.get()) {
+      release(permits - _totalPermits.get());
     }
-    _totalPermits = permits;
+    _totalPermits.set(permits);
+  }
+
+  /**
+   * Returns the total number of permits (as opposed to just the number of available permits returned by
+   * {@link #availablePermits()}).
+   */
+  public int getTotalPermits() {
+    return _totalPermits.get();
   }
 }
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 74a477364e29..aa92a06dd569 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -87,13 +87,13 @@ public void setUp()
     startZk();
     startController();
 
-    // Set the max concurrent multi-stage queries to 5 for the cluster, so that we can test the query queueing logic
+    // Set the multi-stage max server query threads for the cluster, so that we can test the query queueing logic
     // in the MultiStageBrokerRequestHandler
     HelixConfigScope scope =
         new HelixConfigScopeBuilder(HelixConfigScope.ConfigScopeProperty.CLUSTER).forCluster(getHelixClusterName())
             .build();
-    _helixManager.getConfigAccessor().set(scope, CommonConstants.Helix.CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES,
-        "5");
+    _helixManager.getConfigAccessor()
+        .set(scope, CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS, "10");
 
     startBroker();
     startServer();
@@ -717,7 +717,8 @@ public void testMultiValueColumnGroupBy()
   }
 
   @Test
-  public void testVariadicFunction() throws Exception {
+  public void testVariadicFunction()
+      throws Exception {
     String sqlQuery = "SELECT ARRAY_TO_MV(VALUE_IN(RandomAirports, 'MFR', 'SUN', 'GTR')) as airport, count(*) "
         + "FROM mytable WHERE ARRAY_TO_MV(RandomAirports) IN ('MFR', 'SUN', 'GTR') GROUP BY airport";
     JsonNode jsonNode = postQuery(sqlQuery);
@@ -729,7 +730,8 @@ public void testVariadicFunction() throws Exception {
 
   @Test(dataProvider = "polymorphicScalarComparisonFunctionsDataProvider")
   public void testPolymorphicScalarComparisonFunctions(String type, String literal, String lesserLiteral,
-      Object expectedValue) throws Exception {
+      Object expectedValue)
+      throws Exception {
 
     // Queries written this way will trigger the PinotEvaluateLiteralRule which will call the scalar comparison function
     // on the literals. Simpler queries like SELECT ... WHERE 'test' = 'test' will not trigger the optimization rule
@@ -770,7 +772,8 @@ public void testPolymorphicScalarComparisonFunctions(String type, String literal
   }
 
   @Test
-  public void testPolymorphicScalarComparisonFunctionsDifferentType() throws Exception {
+  public void testPolymorphicScalarComparisonFunctionsDifferentType()
+      throws Exception {
     // Don't support comparison for literals with different types
     String sqlQueryPrefix = "WITH data as (SELECT 1 as \"foo\" FROM mytable) "
         + "SELECT * FROM data WHERE \"foo\" ";
@@ -816,8 +819,10 @@ Object[][] polymorphicScalarComparisonFunctionsDataProvider() {
     inputs.add(new Object[]{"FLOAT", "CAST(1.234 AS FLOAT)", "CAST(1.23 AS FLOAT)", "1.234"});
     inputs.add(new Object[]{"DOUBLE", "1.234", "1.23", "1.234"});
     inputs.add(new Object[]{"BOOLEAN", "CAST(true AS BOOLEAN)", "CAST(FALSE AS BOOLEAN)", "true"});
-    inputs.add(new Object[]{"TIMESTAMP", "CAST(1723593600000 AS TIMESTAMP)", "CAST (1623593600000 AS TIMESTAMP)",
-        new DateTime(1723593600000L, DateTimeZone.getDefault()).toString("yyyy-MM-dd HH:mm:ss.S")});
+    inputs.add(new Object[]{
+        "TIMESTAMP", "CAST(1723593600000 AS TIMESTAMP)", "CAST (1623593600000 AS TIMESTAMP)",
+        new DateTime(1723593600000L, DateTimeZone.getDefault()).toString("yyyy-MM-dd HH:mm:ss.S")
+    });
 
     return inputs.toArray(new Object[0][]);
   }
@@ -943,7 +948,8 @@ public void testSearch()
   }
 
   @Test
-  public void testLiteralFilterReduce() throws Exception {
+  public void testLiteralFilterReduce()
+      throws Exception {
     String sqlQuery = "SELECT * FROM (SELECT CASE WHEN AirTime > 0 THEN 'positive' ELSE 'negative' END AS AirTime "
         + "FROM mytable) WHERE AirTime IN ('positive', 'negative')";
     JsonNode jsonNode = postQuery(sqlQuery);
@@ -1096,7 +1102,8 @@ public void testNullIf()
   }
 
   @Test
-  public void testMVNumericCastInFilter() throws Exception {
+  public void testMVNumericCastInFilter()
+      throws Exception {
     String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
     JsonNode jsonNode = postQuery(sqlQuery);
     assertNoError(jsonNode);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
index c1cfb89ab7fc..5299b08ce7af 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
@@ -22,6 +22,7 @@
 import java.util.Map;
 import java.util.Set;
 import org.apache.calcite.runtime.PairList;
+import org.apache.pinot.core.util.QueryMultiThreadingUtils;
 
 
 /**
@@ -82,4 +83,38 @@ public Set<String> getTableNames() {
   public Map<String, Set<String>> getTableToUnavailableSegmentsMap() {
     return _tableToUnavailableSegmentsMap;
   }
+
+  /**
+   * Get the estimated total number of threads that will be spawned for this query (across all stages and servers).
+   */
+  public int getEstimatedNumQueryThreads() {
+    int estimatedNumQueryThreads = 0;
+    // Skip broker reduce root stage
+    for (DispatchablePlanFragment stage : _queryStageList.subList(1, _queryStageList.size())) {
+      // Non-leaf stage
+      if (stage.getWorkerIdToSegmentsMap().isEmpty()) {
+        estimatedNumQueryThreads += stage.getWorkerMetadataList().size();
+      } else {
+        // Leaf stage
+        for (Map<String, List<String>> segmentsMap : stage.getWorkerIdToSegmentsMap().values()) {
+          int numSegments = segmentsMap
+              .values()
+              .stream()
+              .mapToInt(List::size)
+              .sum();
+
+          // The leaf stage operator itself spawns a thread for each server query request
+          estimatedNumQueryThreads++;
+
+          // TODO: this isn't entirely accurate and can be improved. One issue is that the maxExecutionThreads can be
+          //       overridden in the query options and also in the server query executor configs.
+          //       Another issue is that not all leaf stage combine operators use the below method to calculate
+          //       the number of tasks / threads (the GroupByCombineOperator has some different logic for instance).
+          estimatedNumQueryThreads += QueryMultiThreadingUtils.getNumTasksForQuery(numSegments,
+              QueryMultiThreadingUtils.MAX_NUM_THREADS_PER_QUERY);
+        }
+      }
+    }
+    return estimatedNumQueryThreads;
+  }
 }
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index e3c3e0d48348..aac6b41fe829 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -242,9 +242,9 @@ public static class Instance {
     public static final boolean DEFAULT_MULTI_STAGE_ENGINE_TLS_ENABLED = false;
 
     // This is a "beta" config and can be changed or even removed in future releases.
-    public static final String CONFIG_OF_MAX_CONCURRENT_MULTI_STAGE_QUERIES =
-        "pinot.beta.multistage.engine.max.server.concurrent.queries";
-    public static final String DEFAULT_MAX_CONCURRENT_MULTI_STAGE_QUERIES = "-1";
+    public static final String CONFIG_OF_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS =
+        "pinot.beta.multistage.engine.max.server.query.threads";
+    public static final String DEFAULT_MULTI_STAGE_ENGINE_MAX_SERVER_QUERY_THREADS = "-1";
   }
 
   public static class Broker {