diff --git a/alembic.ini b/alembic.ini
index 90649e00..ca9c26af 100644
--- a/alembic.ini
+++ b/alembic.ini
@@ -3,6 +3,6 @@
 [alembic]
 # path to migration scripts
 # Use forward slashes (/) also on windows to provide an os agnostic path
-script_location = dbos/migrations
+script_location = dbos/_migrations
 
 version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
diff --git a/dbos/_dbos.py b/dbos/_dbos.py
index 18b4c174..248e53be 100644
--- a/dbos/_dbos.py
+++ b/dbos/_dbos.py
@@ -83,7 +83,7 @@
 )
 from ._dbos_config import ConfigFile, load_config, set_env_vars
 from ._error import DBOSException, DBOSNonExistentWorkflowError
-from ._logger import add_otlp_to_all_loggers, dbos_logger, init_logger
+from ._logger import add_otlp_to_all_loggers, dbos_logger
 from ._sys_db import SystemDatabase
 
 # Most DBOS functions are just any callable F, so decorators / wrappers work on F
diff --git a/dbos/_migrations/versions/04ca4f231047_workflow_queues_executor_id.py b/dbos/_migrations/versions/04ca4f231047_workflow_queues_executor_id.py
new file mode 100644
index 00000000..2e3838cf
--- /dev/null
+++ b/dbos/_migrations/versions/04ca4f231047_workflow_queues_executor_id.py
@@ -0,0 +1,34 @@
+"""workflow_queues_executor_id
+
+Revision ID: 04ca4f231047
+Revises: d76646551a6c
+Create Date: 2025-01-15 15:05:08.043190
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "04ca4f231047"
+down_revision: Union[str, None] = "d76646551a6c"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "workflow_queue",
+        sa.Column(
+            "executor_id",
+            sa.Text(),
+            nullable=True,
+        ),
+        schema="dbos",
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("workflow_queue", "executor_id", schema="dbos")
diff --git a/dbos/_queue.py b/dbos/_queue.py
index 8ad75b47..5d981a75 100644
--- a/dbos/_queue.py
+++ b/dbos/_queue.py
@@ -2,6 +2,9 @@
 import traceback
 from typing import TYPE_CHECKING, Optional, TypedDict
 
+from psycopg import errors
+from sqlalchemy.exc import OperationalError
+
 from ._core import P, R, execute_workflow_by_id, start_workflow
 
 if TYPE_CHECKING:
@@ -33,9 +36,19 @@ def __init__(
         name: str,
         concurrency: Optional[int] = None,
         limiter: Optional[QueueRateLimit] = None,
+        worker_concurrency: Optional[int] = None,
     ) -> None:
+        if (
+            worker_concurrency is not None
+            and concurrency is not None
+            and worker_concurrency > concurrency
+        ):
+            raise ValueError(
+                "worker_concurrency must be less than or equal to concurrency"
+            )
         self.name = name
         self.concurrency = concurrency
+        self.worker_concurrency = worker_concurrency
         self.limiter = limiter
 
         from ._dbos import _get_or_create_dbos_registry
@@ -60,6 +73,12 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
                 wf_ids = dbos._sys_db.start_queued_workflows(queue, dbos._executor_id)
                 for id in wf_ids:
                     execute_workflow_by_id(dbos, id)
+            except OperationalError as e:
+                # Ignore serialization errors, which are expected when multiple workers dequeue concurrently
+                if not isinstance(e.orig, errors.SerializationFailure):
+                    dbos.logger.warning(
+                        f"Exception encountered in queue thread: {traceback.format_exc()}"
+                    )
             except Exception:
                 dbos.logger.warning(
                     f"Exception encountered in queue thread: {traceback.format_exc()}"
diff --git a/dbos/_schemas/system_database.py b/dbos/_schemas/system_database.py
index 23a0ec8c..7f399866 100644
--- a/dbos/_schemas/system_database.py
+++ b/dbos/_schemas/system_database.py
@@ -154,6 +154,7 @@ class SystemSchema:
             nullable=False,
             primary_key=True,
         ),
+        Column("executor_id", Text),
         Column("queue_name", Text, nullable=False),
         Column(
             "created_at_epoch_ms",
diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py
index 4ba989ef..f82e7ea9 100644
--- a/dbos/_sys_db.py
+++ b/dbos/_sys_db.py
@@ -13,7 +13,6 @@
     Optional,
     Sequence,
     Set,
-    Tuple,
     TypedDict,
     cast,
 )
@@ -23,6 +22,7 @@
 import sqlalchemy.dialects.postgresql as pg
 from alembic import command
 from alembic.config import Config
+from sqlalchemy import or_
 from sqlalchemy.exc import DBAPIError
 
 from . import _serialization
@@ -1140,27 +1140,38 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
                 if num_recent_queries >= queue.limiter["limit"]:
                     return []
 
-            # Select not-yet-completed functions in the queue ordered by the
-            # time at which they were enqueued.
-            # If there is a concurrency limit N, select only the N most recent
+            # Dequeue functions eligible for this worker, ordered by the time at which they were enqueued.
+            # If there is a global or local concurrency limit N, select only the N oldest enqueued
             # functions, else select all of them.
             query = (
                 sa.select(
                     SystemSchema.workflow_queue.c.workflow_uuid,
                     SystemSchema.workflow_queue.c.started_at_epoch_ms,
+                    SystemSchema.workflow_queue.c.executor_id,
                 )
                 .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
                 .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
+                .where(
+                    # Only select functions that have not been started yet or have been started by this worker
+                    or_(
+                        SystemSchema.workflow_queue.c.executor_id == None,
+                        SystemSchema.workflow_queue.c.executor_id == executor_id,
+                    )
+                )
                 .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
             )
-            if queue.concurrency is not None:
+            # Set a dequeue limit if necessary
+            if queue.worker_concurrency is not None:
+                query = query.limit(queue.worker_concurrency)
+            elif queue.concurrency is not None:
                 query = query.limit(queue.concurrency)
 
-            # From the functions retrieved, get the workflow IDs of the functions
-            # that have not yet been started so we can start them.
             rows = c.execute(query).fetchall()
+
+            # Now, get the workflow IDs of functions that have not yet been started
             dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None]
             ret_ids: list[str] = []
+            dbos_logger.debug(f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)")
 
             for id in dequeued_ids:
                 # If we have a limiter, stop starting functions when the number
@@ -1183,11 +1194,11 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
                     )
                 )
 
-                # Then give it a start time
+                # Then give it a start time and assign the executor ID
                 c.execute(
                     SystemSchema.workflow_queue.update()
                     .where(SystemSchema.workflow_queue.c.workflow_uuid == id)
-                    .values(started_at_epoch_ms=start_time_ms)
+                    .values(started_at_epoch_ms=start_time_ms, executor_id=executor_id)
                 )
                 ret_ids.append(id)
 
diff --git a/tests/test_queue.py b/tests/test_queue.py
index 1453ca09..60443b93 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -3,10 +3,11 @@
 import threading
 import time
 import uuid
+from multiprocessing import Process
 
 import sqlalchemy as sa
 
-from dbos import DBOS, Queue, SetWorkflowID
+from dbos import DBOS, ConfigFile, Queue, SetWorkflowID
 from dbos._dbos import WorkflowHandle
 from dbos._schemas.system_database import SystemSchema
 from dbos._sys_db import WorkflowStatusString
@@ -367,3 +368,103 @@ def test_queue_workflow_in_recovered_workflow(dbos: DBOS) -> None:
     assert wfh.get_status().status == "SUCCESS"
     assert queue_entries_are_cleaned_up(dbos)
     return
+
+
+###########################
+# TEST WORKER CONCURRENCY #
+###########################
+
+
+def test_one_at_a_time_with_worker_concurrency(dbos: DBOS) -> None:
+    wf_counter = 0
+    flag = False
+    workflow_event = threading.Event()
+    main_thread_event = threading.Event()
+
+    @DBOS.workflow()
+    def workflow_one() -> None:
+        nonlocal wf_counter
+        wf_counter += 1
+        main_thread_event.set()  # Signal the main thread that the workflow is running
+        workflow_event.wait()  # Wait until released by the main thread
+
+    @DBOS.workflow()
+    def workflow_two() -> None:
+        nonlocal flag
+        flag = True
+
+    queue = Queue("test_queue", worker_concurrency=1)
+    handle1 = queue.enqueue(workflow_one)
+    handle2 = queue.enqueue(workflow_two)
+
+    # Wait until the first task is dequeued
+    main_thread_event.wait()
+    # Let a few dequeuing intervals pass
+    time.sleep(2)
+    # The second task should not have been dequeued
+    assert not flag
+    # Unlock the first task
+    workflow_event.set()
+    # Both tasks should have completed
+    assert handle1.get_result() is None
+    assert handle2.get_result() is None
+    assert flag
+    assert wf_counter == 1, f"wf_counter={wf_counter}"
+    assert queue_entries_are_cleaned_up(dbos)
+
+
+# Declare a workflow globally (we need it to be registered across processes under a known name)
+@DBOS.workflow()
+def worker_concurrency_test_workflow() -> None:
+    pass
+
+
+def run_dbos_test_in_process(i: int) -> None:
+    dbos_config: ConfigFile = {
+        "name": "test-app",
+        "language": "python",
+        "database": {
+            "hostname": "localhost",
+            "port": 5432,
+            "username": "postgres",
+            "password": os.environ["PGPASSWORD"],
+            "app_db_name": "dbostestpy",
+        },
+        "runtimeConfig": {
+            "start": ["python3 main.py"],
+            "admin_port": 8001 + i,
+        },
+        "telemetry": {},
+        "env": {},
+    }
+    dbos = DBOS(config=dbos_config)
+    DBOS.launch()
+
+    Queue("test_queue", worker_concurrency=1)
+    time.sleep(
+        2
+    )  # Give some time for the parent worker to enqueue and for this worker to dequeue
+
+    queue_entries_are_cleaned_up(dbos)
+
+    DBOS.destroy()
+
+
+def test_worker_concurrency_with_n_dbos_instances(dbos: DBOS) -> None:
+
+    # Start N processes to dequeue tasks
+    processes = []
+    for i in range(0, 10):
+        os.environ["DBOS__VMID"] = f"test-executor-{i}"
+        process = Process(target=run_dbos_test_in_process, args=(i,))
+        process.start()
+        processes.append(process)
+
+    # Enqueue N tasks, using a zero-limit rate limiter
+    # so this worker cannot dequeue any of them itself
+    queue = Queue("test_queue", limiter={"limit": 0, "period": 1})
+    for i in range(0, 10):
+        queue.enqueue(worker_concurrency_test_workflow)
+
+    for process in processes:
+        process.join()
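
Reviewer note: a minimal, illustrative sketch (not part of the patch) of how the new worker_concurrency option composes with the existing global concurrency cap, using only the Queue API exercised in the diff and tests above. The queue name and workflow here are hypothetical, and the sketch assumes a DBOS app that is already configured and launched, as in the test fixtures.

    from dbos import DBOS, Queue

    # Global cap: at most 10 tasks from this queue in flight across all workers.
    # Per-worker cap: any single worker starts at most 2 of them at a time.
    queue = Queue("example_queue", concurrency=10, worker_concurrency=2)

    @DBOS.workflow()
    def example_task() -> None:
        pass

    # Each enqueued task is claimed by exactly one worker; the executor_id
    # column added by the migration records which worker claimed it.
    handles = [queue.enqueue(example_task) for _ in range(100)]
    for handle in handles:
        handle.get_result()  # Blocks until the task completes

    # An invalid configuration is rejected in Queue.__init__:
    # Queue("bad_queue", concurrency=1, worker_concurrency=2)  # raises ValueError

Because the dequeue query also selects rows already claimed by the calling worker (the or_ filter) before applying LIMIT, a worker's own in-flight tasks count against its worker_concurrency, so it cannot start new work beyond its per-worker cap even across successive polling intervals.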