Queues max worker concurrency (#177)
This PR lets users control the maximum number of tasks from a queue that a
single DBOS Transact instance can execute concurrently. The knob is exposed
through the new `worker_concurrency` Queue initialization parameter.
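For illustration only (not part of this diff), a minimal usage sketch of the new parameter; the queue name, workflow, and limits below are made up, and DBOS configuration/launch is omitted:

from dbos import DBOS, Queue

# Hypothetical queue: at most 10 tasks in flight across all workers,
# and at most 2 in flight on any single DBOS Transact instance.
queue = Queue("example_queue", concurrency=10, worker_concurrency=2)

@DBOS.workflow()
def process_task() -> None:
    pass

handle = queue.enqueue(process_task)
handle.get_result()  # wait for the dequeued workflow to finish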

This is implemented by modifying the queue DB query to retrieve only
uncompleted tasks eligible for this worker and to limit the number of rows to
`worker_concurrency` if it is set, falling back to `concurrency` otherwise.
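Concretely, the per-worker limit takes precedence over the global limit when both are set; a condensed sketch of the selection logic, abridged from the `dbos/_sys_db.py` diff below:

query = (
    sa.select(SystemSchema.workflow_queue.c.workflow_uuid)
    .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
    .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
    .where(
        # not yet started, or started by this worker
        or_(
            SystemSchema.workflow_queue.c.executor_id == None,
            SystemSchema.workflow_queue.c.executor_id == executor_id,
        )
    )
    .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
)
if queue.worker_concurrency is not None:  # per-worker limit wins when set
    query = query.limit(queue.worker_concurrency)
elif queue.concurrency is not None:  # otherwise fall back to the global limit
    query = query.limit(queue.concurrency)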
maxdml authored Jan 22, 2025
1 parent 4a4a80a commit 04bc3a1
Showing 7 changed files with 178 additions and 12 deletions.
2 changes: 1 addition & 1 deletion alembic.ini
@@ -3,6 +3,6 @@
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
-script_location = dbos/migrations
script_location = dbos/_migrations

version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
2 changes: 1 addition & 1 deletion dbos/_dbos.py
@@ -83,7 +83,7 @@
)
from ._dbos_config import ConfigFile, load_config, set_env_vars
from ._error import DBOSException, DBOSNonExistentWorkflowError
-from ._logger import add_otlp_to_all_loggers, dbos_logger, init_logger
from ._logger import add_otlp_to_all_loggers, dbos_logger
from ._sys_db import SystemDatabase

# Most DBOS functions are just any callable F, so decorators / wrappers work on F
@@ -0,0 +1,34 @@
"""workflow_queues_executor_id
Revision ID: 04ca4f231047
Revises: d76646551a6c
Create Date: 2025-01-15 15:05:08.043190
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "04ca4f231047"
down_revision: Union[str, None] = "d76646551a6c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        "workflow_queue",
        sa.Column(
            "executor_id",
            sa.Text(),
            nullable=True,
        ),
        schema="dbos",
    )


def downgrade() -> None:
    op.drop_column("workflow_queue", "executor_id", schema="dbos")
19 changes: 19 additions & 0 deletions dbos/_queue.py
@@ -2,6 +2,9 @@
import traceback
from typing import TYPE_CHECKING, Optional, TypedDict

from psycopg import errors
from sqlalchemy.exc import OperationalError

from ._core import P, R, execute_workflow_by_id, start_workflow

if TYPE_CHECKING:
@@ -33,9 +36,19 @@ def __init__(
        name: str,
        concurrency: Optional[int] = None,
        limiter: Optional[QueueRateLimit] = None,
        worker_concurrency: Optional[int] = None,
    ) -> None:
        if (
            worker_concurrency is not None
            and concurrency is not None
            and worker_concurrency > concurrency
        ):
            raise ValueError(
                "worker_concurrency must be less than or equal to concurrency"
            )
        self.name = name
        self.concurrency = concurrency
        self.worker_concurrency = worker_concurrency
        self.limiter = limiter
        from ._dbos import _get_or_create_dbos_registry

@@ -60,6 +73,12 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
                wf_ids = dbos._sys_db.start_queued_workflows(queue, dbos._executor_id)
                for id in wf_ids:
                    execute_workflow_by_id(dbos, id)
            except OperationalError as e:
                # Ignore serialization error
                if not isinstance(e.orig, errors.SerializationFailure):
                    dbos.logger.warning(
                        f"Exception encountered in queue thread: {traceback.format_exc()}"
                    )
            except Exception:
                dbos.logger.warning(
                    f"Exception encountered in queue thread: {traceback.format_exc()}"
1 change: 1 addition & 0 deletions dbos/_schemas/system_database.py
@@ -154,6 +154,7 @@ class SystemSchema:
            nullable=False,
            primary_key=True,
        ),
        Column("executor_id", Text),
        Column("queue_name", Text, nullable=False),
        Column(
            "created_at_epoch_ms",
29 changes: 20 additions & 9 deletions dbos/_sys_db.py
@@ -13,7 +13,6 @@
    Optional,
    Sequence,
    Set,
-    Tuple,
    TypedDict,
    cast,
)
@@ -23,6 +22,7 @@
import sqlalchemy.dialects.postgresql as pg
from alembic import command
from alembic.config import Config
from sqlalchemy import or_
from sqlalchemy.exc import DBAPIError

from . import _serialization
@@ -1140,27 +1140,38 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
                if num_recent_queries >= queue.limiter["limit"]:
                    return []

-            # Select not-yet-completed functions in the queue ordered by the
-            # time at which they were enqueued.
-            # If there is a concurrency limit N, select only the N most recent
            # Dequeue functions eligible for this worker and ordered by the time at which they were enqueued.
            # If there is a global or local concurrency limit N, select only the N oldest enqueued
            # functions, else select all of them.
            query = (
                sa.select(
                    SystemSchema.workflow_queue.c.workflow_uuid,
                    SystemSchema.workflow_queue.c.started_at_epoch_ms,
                    SystemSchema.workflow_queue.c.executor_id,
                )
                .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
                .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
                .where(
                    # Only select functions that have not been started yet or have been started by this worker
                    or_(
                        SystemSchema.workflow_queue.c.executor_id == None,
                        SystemSchema.workflow_queue.c.executor_id == executor_id,
                    )
                )
                .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
            )
-            if queue.concurrency is not None:
            # Set a dequeue limit if necessary
            if queue.worker_concurrency is not None:
                query = query.limit(queue.worker_concurrency)
            elif queue.concurrency is not None:
                query = query.limit(queue.concurrency)

-            # From the functions retrieved, get the workflow IDs of the functions
-            # that have not yet been started so we can start them.
            rows = c.execute(query).fetchall()

            # Now, get the workflow IDs of functions that have not yet been started
            dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None]
            ret_ids: list[str] = []
            dbos_logger.debug(f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)")
            for id in dequeued_ids:

                # If we have a limiter, stop starting functions when the number
@@ -1183,11 +1194,11 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
                    )
                )

-                # Then give it a start time
                # Then give it a start time and assign the executor ID
                c.execute(
                    SystemSchema.workflow_queue.update()
                    .where(SystemSchema.workflow_queue.c.workflow_uuid == id)
-                    .values(started_at_epoch_ms=start_time_ms)
                    .values(started_at_epoch_ms=start_time_ms, executor_id=executor_id)
                )
                ret_ids.append(id)

103 changes: 102 additions & 1 deletion tests/test_queue.py
@@ -3,10 +3,11 @@
import threading
import time
import uuid
from multiprocessing import Process

import sqlalchemy as sa

-from dbos import DBOS, Queue, SetWorkflowID
from dbos import DBOS, ConfigFile, Queue, SetWorkflowID
from dbos._dbos import WorkflowHandle
from dbos._schemas.system_database import SystemSchema
from dbos._sys_db import WorkflowStatusString
@@ -367,3 +368,103 @@ def test_queue_workflow_in_recovered_workflow(dbos: DBOS) -> None:
    assert wfh.get_status().status == "SUCCESS"
    assert queue_entries_are_cleaned_up(dbos)
    return


###########################
# TEST WORKER CONCURRENCY #
###########################


def test_one_at_a_time_with_worker_concurrency(dbos: DBOS) -> None:
    wf_counter = 0
    flag = False
    workflow_event = threading.Event()
    main_thread_event = threading.Event()

    @DBOS.workflow()
    def workflow_one() -> None:
        nonlocal wf_counter
        wf_counter += 1
        main_thread_event.set()  # Signal main thread we got running
        workflow_event.wait()  # Wait to complete

    @DBOS.workflow()
    def workflow_two() -> None:
        nonlocal flag
        flag = True

    queue = Queue("test_queue", worker_concurrency=1)
    handle1 = queue.enqueue(workflow_one)
    handle2 = queue.enqueue(workflow_two)

    # Wait until the first task is dequeued
    main_thread_event.wait()
    # Let pass a few dequeuing intervals
    time.sleep(2)
    # 2nd task should not have been dequeued
    assert not flag
    # Unlock the first task
    workflow_event.set()
    # Both tasks should have completed
    assert handle1.get_result() == None
    assert handle2.get_result() == None
    assert flag
    assert wf_counter == 1, f"wf_counter={wf_counter}"
    assert queue_entries_are_cleaned_up(dbos)


# Declare a workflow globally (we need it to be registered across processes under a known name)
@DBOS.workflow()
def worker_concurrency_test_workflow() -> None:
    pass


def run_dbos_test_in_process(i: int) -> None:
    dbos_config: ConfigFile = {
        "name": "test-app",
        "language": "python",
        "database": {
            "hostname": "localhost",
            "port": 5432,
            "username": "postgres",
            "password": os.environ["PGPASSWORD"],
            "app_db_name": "dbostestpy",
        },
        "runtimeConfig": {
            "start": ["python3 main.py"],
            "admin_port": 8001 + i,
        },
        "telemetry": {},
        "env": {},
    }
    dbos = DBOS(config=dbos_config)
    DBOS.launch()

    Queue("test_queue", worker_concurrency=1)
    time.sleep(
        2
    )  # Give some time for the parent worker to enqueue and for this worker to dequeue

    queue_entries_are_cleaned_up(dbos)

    DBOS.destroy()


def test_worker_concurrency_with_n_dbos_instances(dbos: DBOS) -> None:

    # Start N processes to dequeue
    processes = []
    for i in range(0, 10):
        os.environ["DBOS__VMID"] = f"test-executor-{i}"
        process = Process(target=run_dbos_test_in_process, args=(i,))
        process.start()
        processes.append(process)

    # Enqueue N tasks but ensure this worker cannot dequeue

    queue = Queue("test_queue", limiter={"limit": 0, "period": 1})
    for i in range(0, 10):
        queue.enqueue(worker_concurrency_test_workflow)

    for process in processes:
        process.join()
