Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
maxdml committed Jan 21, 2025
1 parent 059143d commit 9432f85
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 122 deletions.
11 changes: 11 additions & 0 deletions dbos/_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import traceback
from typing import TYPE_CHECKING, Optional, TypedDict

from psycopg import errors
from sqlalchemy.exc import OperationalError

from ._core import P, R, execute_workflow_by_id, start_workflow

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,6 +73,14 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
wf_ids = dbos._sys_db.start_queued_workflows(queue, dbos._executor_id)
for id in wf_ids:
execute_workflow_by_id(dbos, id)
except OperationalError as e:
# Ignore serialization error
if isinstance(e.orig, errors.SerializationFailure):
return []
else:
dbos.logger.warning(
f"Exception encountered in queue thread: {traceback.format_exc()}"
)
except Exception:
dbos.logger.warning(
f"Exception encountered in queue thread: {traceback.format_exc()}"
Expand Down
192 changes: 70 additions & 122 deletions dbos/_sys_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Optional,
Sequence,
Set,
Tuple,
TypedDict,
cast,
)
Expand All @@ -23,8 +22,8 @@
import sqlalchemy.dialects.postgresql as pg
from alembic import command
from alembic.config import Config
from psycopg import errors
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy import or_
from sqlalchemy.exc import DBAPIError

from . import _serialization
from ._dbos_config import ConfigFile
Expand Down Expand Up @@ -1114,137 +1113,86 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))

ret_ids: list[str] = []
try:
# If there is a limiter, compute how many functions have started in its period.
if queue.limiter is not None:
query = (
sa.select(sa.func.count())
.select_from(SystemSchema.workflow_queue)
.where(SystemSchema.workflow_queue.c.queue_name == queue.name)
.where(
SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot(
None
)
)
.where(
SystemSchema.workflow_queue.c.started_at_epoch_ms
> start_time_ms - limiter_period_ms
)
)
num_recent_queries = c.execute(query).fetchone()[0] # type: ignore
if num_recent_queries >= queue.limiter["limit"]:
return []

# Select not-yet-completed functions in the queue ordered by the
# time at which they were enqueued.
# If there is a concurrency limit N, select only the N oldest enqueued
# functions, else select all of them.
# If there is a limiter, compute how many functions have started in its period.
if queue.limiter is not None:
query = (
sa.select(
SystemSchema.workflow_queue.c.workflow_uuid,
SystemSchema.workflow_queue.c.started_at_epoch_ms,
SystemSchema.workflow_queue.c.executor_id,
)
sa.select(sa.func.count())
.select_from(SystemSchema.workflow_queue)
.where(SystemSchema.workflow_queue.c.queue_name == queue.name)
.where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
.order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
.where(
SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot(None)
)
.where(
SystemSchema.workflow_queue.c.started_at_epoch_ms
> start_time_ms - limiter_period_ms
)
)
if queue.concurrency is not None:
query = query.limit(queue.concurrency)

rows = c.execute(query).fetchall()
dbos_logger.debug(f"[{queue.name}] dequeued {len(rows)} task(s)")
if len(rows) == 0:
num_recent_queries = c.execute(query).fetchone()[0] # type: ignore
if num_recent_queries >= queue.limiter["limit"]:
return []

# First, get the IDs of functions that have already been started
# We will use these to calculate how many more functions this worker can start
number_of_tasks_already_started: int = len(
[row[0] for row in rows if row[1] is not None]
)
dbos_logger.debug(
f"[{queue.name}] {number_of_tasks_already_started} task(s) already started"
)

# queue length >= queue.concurrency >= len(rows) >= number_of_tasks_already_started > 0
number_of_eligible_tasks: int = (
len(rows) - number_of_tasks_already_started
)
dbos_logger.debug(
f"[{queue.name}] {number_of_eligible_tasks} task(s) eligible for dequeue"
# Dequeue functions eligible for this worker and ordered by the time at which they were enqueued.
# If there is a global or local concurrency limit N, select only the N oldest enqueued
# functions, else select all of them.
query = (
sa.select(
SystemSchema.workflow_queue.c.workflow_uuid,
SystemSchema.workflow_queue.c.started_at_epoch_ms,
SystemSchema.workflow_queue.c.executor_id,
)
if number_of_eligible_tasks == 0:
return []

tasks_this_worker_is_already_working_on: int = len(
[row[0] for row in rows if len(row) == 3 and row[2] == executor_id]
.where(SystemSchema.workflow_queue.c.queue_name == queue.name)
.where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
.where(
# Only select functions that have not been started yet or have been started by this worker
or_(
SystemSchema.workflow_queue.c.executor_id == None,
SystemSchema.workflow_queue.c.executor_id == executor_id,
)
)
.order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
)
# Set a dequeue limit if necessary
if queue.worker_concurrency is not None:
query = query.limit(queue.worker_concurrency)
elif queue.concurrency is not None:
query = query.limit(queue._concurrency)

rows = c.execute(query).fetchall()
dbos_logger.debug(f"[{queue.name}] dequeued {len(rows)} task(s)")

# Now, get the workflow IDs of functions that have not yet been started
dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None]
dbos_logger.debug(f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)")
for id in dequeued_ids:

# If we have a limiter, stop starting functions when the number
# of functions started this period exceeds the limit.
if queue.limiter is not None:
if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]:
break

# This worker can dequeue up to whatever is smaller between the eligible tasks and its set concurrency
# Of course we must account for tasks this worker is already working on, dequeued during a previous pass of this function
max_tasks_this_worker_can_dequeue = int(
min(
number_of_eligible_tasks,
(
queue.worker_concurrency
if queue.worker_concurrency is not None
else float("inf")
),
# To start a function, first set its status to PENDING and update its executor ID
c.execute(
SystemSchema.workflow_status.update()
.where(SystemSchema.workflow_status.c.workflow_uuid == id)
.where(
SystemSchema.workflow_status.c.status
== WorkflowStatusString.ENQUEUED.value
)
.values(
status=WorkflowStatusString.PENDING.value,
executor_id=executor_id,
)
- tasks_this_worker_is_already_working_on
)
assert (
max_tasks_this_worker_can_dequeue >= 0
) # TODO: remove this assert after sufficient tests are implemented

# Now, get the workflow IDs of functions that have not yet been started
# Limit the list by the maximum concurrency for this worker
dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None][
:max_tasks_this_worker_can_dequeue
]
dbos_logger.debug(
f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)"
)
for id in dequeued_ids:

# If we have a limiter, stop starting functions when the number
# of functions started this period exceeds the limit.
if queue.limiter is not None:
if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]:
break

# To start a function, first set its status to PENDING and update its executor ID
c.execute(
SystemSchema.workflow_status.update()
.where(SystemSchema.workflow_status.c.workflow_uuid == id)
.where(
SystemSchema.workflow_status.c.status
== WorkflowStatusString.ENQUEUED.value
)
.values(
status=WorkflowStatusString.PENDING.value,
executor_id=executor_id,
)
)
# Then give it a start time and assign the executor ID
c.execute(
SystemSchema.workflow_queue.update()
.where(SystemSchema.workflow_queue.c.workflow_uuid == id)
.values(started_at_epoch_ms=start_time_ms, executor_id=executor_id)
)
ret_ids.append(id)

# Then give it a start time and assign the executor ID
c.execute(
SystemSchema.workflow_queue.update()
.where(SystemSchema.workflow_queue.c.workflow_uuid == id)
.values(
started_at_epoch_ms=start_time_ms, executor_id=executor_id
)
)
ret_ids.append(id)
except OperationalError as e:
# Abandon the queue items in case of serialization error
if isinstance(e.orig, errors.SerializationFailure):
dbos_logger.warning("Serialization failure. Abandoning queue items")
return []
else:
# Re-raise other OperationalError types if unrelated
raise
finally:
# If we have a limiter, garbage-collect all completed functions started
# before the period. If there's no limiter, there's no need--they were
# deleted on completion.
Expand Down

0 comments on commit 9432f85

Please sign in to comment.