From 9432f85f9e578b9a896990045dc57a812cbceb3d Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 21 Jan 2025 13:24:51 -0800 Subject: [PATCH] simplify --- dbos/_queue.py | 11 +++ dbos/_sys_db.py | 192 ++++++++++++++++++------------------------------ 2 files changed, 81 insertions(+), 122 deletions(-) diff --git a/dbos/_queue.py b/dbos/_queue.py index e5599243..e51601d4 100644 --- a/dbos/_queue.py +++ b/dbos/_queue.py @@ -2,6 +2,9 @@ import traceback from typing import TYPE_CHECKING, Optional, TypedDict +from psycopg import errors +from sqlalchemy.exc import OperationalError + from ._core import P, R, execute_workflow_by_id, start_workflow if TYPE_CHECKING: @@ -70,6 +73,14 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None: wf_ids = dbos._sys_db.start_queued_workflows(queue, dbos._executor_id) for id in wf_ids: execute_workflow_by_id(dbos, id) + except OperationalError as e: + # Ignore serialization error + if isinstance(e.orig, errors.SerializationFailure): + return [] + else: + dbos.logger.warning( + f"Exception encountered in queue thread: {traceback.format_exc()}" + ) except Exception: dbos.logger.warning( f"Exception encountered in queue thread: {traceback.format_exc()}" diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py index c4125105..a1a9a6f0 100644 --- a/dbos/_sys_db.py +++ b/dbos/_sys_db.py @@ -13,7 +13,6 @@ Optional, Sequence, Set, - Tuple, TypedDict, cast, ) @@ -23,8 +22,8 @@ import sqlalchemy.dialects.postgresql as pg from alembic import command from alembic.config import Config -from psycopg import errors -from sqlalchemy.exc import DBAPIError, OperationalError +from sqlalchemy import or_ +from sqlalchemy.exc import DBAPIError from . import _serialization from ._dbos_config import ConfigFile @@ -1114,137 +1113,86 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]: c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")) ret_ids: list[str] = [] - try: - # If there is a limiter, compute how many functions have started in its period. - if queue.limiter is not None: - query = ( - sa.select(sa.func.count()) - .select_from(SystemSchema.workflow_queue) - .where(SystemSchema.workflow_queue.c.queue_name == queue.name) - .where( - SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot( - None - ) - ) - .where( - SystemSchema.workflow_queue.c.started_at_epoch_ms - > start_time_ms - limiter_period_ms - ) - ) - num_recent_queries = c.execute(query).fetchone()[0] # type: ignore - if num_recent_queries >= queue.limiter["limit"]: - return [] - - # Select not-yet-completed functions in the queue ordered by the - # time at which they were enqueued. - # If there is a concurrency limit N, select only the N oldest enqueued - # functions, else select all of them. + # If there is a limiter, compute how many functions have started in its period. + if queue.limiter is not None: query = ( - sa.select( - SystemSchema.workflow_queue.c.workflow_uuid, - SystemSchema.workflow_queue.c.started_at_epoch_ms, - SystemSchema.workflow_queue.c.executor_id, - ) + sa.select(sa.func.count()) + .select_from(SystemSchema.workflow_queue) .where(SystemSchema.workflow_queue.c.queue_name == queue.name) - .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None) - .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc()) + .where( + SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot(None) + ) + .where( + SystemSchema.workflow_queue.c.started_at_epoch_ms + > start_time_ms - limiter_period_ms + ) ) - if queue.concurrency is not None: - query = query.limit(queue.concurrency) - - rows = c.execute(query).fetchall() - dbos_logger.debug(f"[{queue.name}] dequeued {len(rows)} task(s)") - if len(rows) == 0: + num_recent_queries = c.execute(query).fetchone()[0] # type: ignore + if num_recent_queries >= queue.limiter["limit"]: return [] - # First, get the IDs of functions that have already been started - # We will use these to calculate how many more functions this worker can start - number_of_tasks_already_started: int = len( - [row[0] for row in rows if row[1] is not None] - ) - dbos_logger.debug( - f"[{queue.name}] {number_of_tasks_already_started} task(s) already started" - ) - - # queue length >= queue.concurrency >= len(rows) >= number_of_tasks_already_started > 0 - number_of_eligible_tasks: int = ( - len(rows) - number_of_tasks_already_started - ) - dbos_logger.debug( - f"[{queue.name}] {number_of_eligible_tasks} task(s) eligible for dequeue" + # Dequeue functions eligible for this worker and ordered by the time at which they were enqueued. + # If there is a global or local concurrency limit N, select only the N oldest enqueued + # functions, else select all of them. + query = ( + sa.select( + SystemSchema.workflow_queue.c.workflow_uuid, + SystemSchema.workflow_queue.c.started_at_epoch_ms, + SystemSchema.workflow_queue.c.executor_id, ) - if number_of_eligible_tasks == 0: - return [] - - tasks_this_worker_is_already_working_on: int = len( - [row[0] for row in rows if len(row) == 3 and row[2] == executor_id] + .where(SystemSchema.workflow_queue.c.queue_name == queue.name) + .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None) + .where( + # Only select functions that have not been started yet or have been started by this worker + or_( + SystemSchema.workflow_queue.c.executor_id == None, + SystemSchema.workflow_queue.c.executor_id == executor_id, + ) ) + .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc()) + ) + # Set a dequeue limit if necessary + if queue.worker_concurrency is not None: + query = query.limit(queue.worker_concurrency) + elif queue.concurrency is not None: + query = query.limit(queue._concurrency) + + rows = c.execute(query).fetchall() + dbos_logger.debug(f"[{queue.name}] dequeued {len(rows)} task(s)") + + # Now, get the workflow IDs of functions that have not yet been started + dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None] + dbos_logger.debug(f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)") + for id in dequeued_ids: + + # If we have a limiter, stop starting functions when the number + # of functions started this period exceeds the limit. + if queue.limiter is not None: + if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]: + break - # This worker can dequeue up to whatever is smaller between the eligible tasks and its set concurrency - # Of course we must account for tasks this worker is already working on, dequeued during a previous pass of this function - max_tasks_this_worker_can_dequeue = int( - min( - number_of_eligible_tasks, - ( - queue.worker_concurrency - if queue.worker_concurrency is not None - else float("inf") - ), + # To start a function, first set its status to PENDING and update its executor ID + c.execute( + SystemSchema.workflow_status.update() + .where(SystemSchema.workflow_status.c.workflow_uuid == id) + .where( + SystemSchema.workflow_status.c.status + == WorkflowStatusString.ENQUEUED.value + ) + .values( + status=WorkflowStatusString.PENDING.value, + executor_id=executor_id, ) - - tasks_this_worker_is_already_working_on - ) - assert ( - max_tasks_this_worker_can_dequeue >= 0 - ) # TODO: remove this assert after sufficient tests are implemented - - # Now, get the workflow IDs of functions that have not yet been started - # Limit the list by the maximum concurrency for this worker - dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None][ - :max_tasks_this_worker_can_dequeue - ] - dbos_logger.debug( - f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)" ) - for id in dequeued_ids: - - # If we have a limiter, stop starting functions when the number - # of functions started this period exceeds the limit. - if queue.limiter is not None: - if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]: - break - # To start a function, first set its status to PENDING and update its executor ID - c.execute( - SystemSchema.workflow_status.update() - .where(SystemSchema.workflow_status.c.workflow_uuid == id) - .where( - SystemSchema.workflow_status.c.status - == WorkflowStatusString.ENQUEUED.value - ) - .values( - status=WorkflowStatusString.PENDING.value, - executor_id=executor_id, - ) - ) + # Then give it a start time and assign the executor ID + c.execute( + SystemSchema.workflow_queue.update() + .where(SystemSchema.workflow_queue.c.workflow_uuid == id) + .values(started_at_epoch_ms=start_time_ms, executor_id=executor_id) + ) + ret_ids.append(id) - # Then give it a start time and assign the executor ID - c.execute( - SystemSchema.workflow_queue.update() - .where(SystemSchema.workflow_queue.c.workflow_uuid == id) - .values( - started_at_epoch_ms=start_time_ms, executor_id=executor_id - ) - ) - ret_ids.append(id) - except OperationalError as e: - # Abandon the queue items in case of serialization error - if isinstance(e.orig, errors.SerializationFailure): - dbos_logger.warning("Serialization failure. Abandoning queue items") - return [] - else: - # Re-raise other OperationalError types if unrelated - raise - finally: # If we have a limiter, garbage-collect all completed functions started # before the period. If there's no limiter, there's no need--they were # deleted on completion.