Skip to content

Commit

Permalink
Basic communicator support (#18)
Browse files Browse the repository at this point in the history
Added communicator.
  • Loading branch information
qianl15 authored Jul 30, 2024
1 parent 2ad13f5 commit f12f13e
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 7 deletions.
4 changes: 2 additions & 2 deletions dbos_transact/application_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def record_transaction_output(
pg.insert(ApplicationSchema.transaction_outputs).values(
workflow_uuid=output["workflow_uuid"],
function_id=output["function_id"],
output=output["output"] if output["output"] else None,
output=output["output"],
error=None,
txn_id=sa.text("(select pg_current_xact_id_if_assigned()::text)"),
txn_snapshot=output["txn_snapshot"],
Expand All @@ -106,7 +106,7 @@ def record_transaction_error(self, output: TransactionResultInternal) -> None:
workflow_uuid=output["workflow_uuid"],
function_id=output["function_id"],
output=None,
error=output["error"] if output["error"] else None,
error=output["error"],
txn_id=sa.text(
"(select pg_current_xact_id_if_assigned()::text)"
),
Expand Down
3 changes: 3 additions & 0 deletions dbos_transact/communicator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class CommunicatorContext:
def __init__(self, function_id: int):
self.function_id = function_id
57 changes: 56 additions & 1 deletion dbos_transact/dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
from typing import ParamSpec, TypeAlias

import dbos_transact.utils as utils
from dbos_transact.communicator import CommunicatorContext
from dbos_transact.error import DBOSRecoveryError, DBOSWorkflowConflictUUIDError
from dbos_transact.transaction import TransactionContext
from dbos_transact.workflows import WorkflowContext, WorkflowHandle

from .application_database import ApplicationDatabase, TransactionResultInternal
from .dbos_config import ConfigFile, load_config
from .logger import config_logger, dbos_logger
from .system_database import SystemDatabase, WorkflowInputs, WorkflowStatusInternal
from .system_database import (
OperationResultInternal,
SystemDatabase,
WorkflowInputs,
WorkflowStatusInternal,
)

P = ParamSpec("P")
R = TypeVar("R", covariant=True)
Expand All @@ -43,6 +49,15 @@ def __call__(self, ctx: TransactionContext, *args: Any, **kwargs: Any) -> Any: .
Transaction = TypeVar("Transaction", bound=TransactionProtocol)


class CommunicatorProtocol(Protocol):
__qualname__: str

def __call__(self, ctx: CommunicatorContext, *args: Any, **kwargs: Any) -> Any: ...


Communicator = TypeVar("Communicator", bound=CommunicatorProtocol)


class WorkflowInputContext(TypedDict):
workflow_uuid: str

Expand Down Expand Up @@ -228,6 +243,46 @@ def wrapper(_ctxt: TransactionContext, *args: Any, **kwargs: Any) -> Any:

return decorator

def communicator(self) -> Callable[[Communicator], Communicator]:
def decorator(func: Communicator) -> Communicator:
@wraps(func)
def wrapper(_ctxt: CommunicatorContext, *args: Any, **kwargs: Any) -> Any:
input_ctxt = cast(WorkflowContext, _ctxt)
input_ctxt.function_id += 1
comm_output: OperationResultInternal = {
"workflow_uuid": input_ctxt.workflow_uuid,
"function_id": input_ctxt.function_id,
"output": None,
"error": None,
}
comm_ctxt = CommunicatorContext(input_ctxt.function_id)
recorded_output = self.sys_db.check_operation_execution(
input_ctxt.workflow_uuid, comm_ctxt.function_id
)
if recorded_output:
if recorded_output["error"]:
deserialized_error = utils.deserialize(recorded_output["error"])
raise deserialized_error
elif recorded_output["output"]:
return utils.deserialize(recorded_output["output"])
else:
raise Exception("Output and error are both None")
output = None
try:
# TODO: support configurable retries
output = func(comm_ctxt, *args, **kwargs)
comm_output["output"] = utils.serialize(output)
except Exception as error:
comm_output["error"] = utils.serialize(error)
raise error
finally:
self.sys_db.record_operation_result(comm_output)
return output

return cast(Communicator, wrapper)

return decorator

def execute_workflow_uuid(self, workflow_uuid: str) -> None:
"""
This function is used to execute a workflow by a UUID for recovery.
Expand Down
53 changes: 53 additions & 0 deletions dbos_transact/system_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from alembic.config import Config

import dbos_transact.utils as utils
from dbos_transact.error import DBOSWorkflowConflictUUIDError

from .dbos_config import ConfigFile
from .schemas.system_database import SystemSchema
Expand Down Expand Up @@ -39,6 +40,18 @@ class WorkflowStatusInternal(TypedDict):
error: Optional[str] # Base64-encoded pickle


class RecordedResult(TypedDict):
output: Optional[str] # Base64-encoded pickle
error: Optional[str] # Base64-encoded pickle


class OperationResultInternal(TypedDict):
workflow_uuid: str
function_id: int
output: Optional[str] # Base64-encoded pickle
error: Optional[str] # Base64-encoded pickle


class SystemDatabase:

def __init__(self, config: ConfigFile):
Expand Down Expand Up @@ -170,3 +183,43 @@ def get_pending_workflows(self) -> list[str]:
)
).fetchall()
return [row[0] for row in rows]

def record_operation_result(self, result: OperationResultInternal) -> None:
error = result["error"]
output = result["output"]
assert error is None or output is None, "Only one of error or output can be set"
with self.engine.begin() as c:
try:
c.execute(
pg.insert(SystemSchema.operation_outputs).values(
workflow_uuid=result["workflow_uuid"],
function_id=result["function_id"],
output=output,
error=error,
)
)
except sa.exc.IntegrityError:
raise DBOSWorkflowConflictUUIDError(result["workflow_uuid"])
except Exception as e:
raise e

def check_operation_execution(
self, workflow_uuid: str, function_id: int
) -> Optional[RecordedResult]:
with self.engine.begin() as c:
rows = c.execute(
sa.select(
SystemSchema.operation_outputs.c.output,
SystemSchema.operation_outputs.c.error,
).where(
SystemSchema.operation_outputs.c.workflow_uuid == workflow_uuid,
SystemSchema.operation_outputs.c.function_id == function_id,
)
).all()
if len(rows) == 0:
return None
result: RecordedResult = {
"output": rows[0][0],
"error": rows[0][1],
}
return result
4 changes: 4 additions & 0 deletions dbos_transact/workflows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from concurrent.futures import Future
from typing import Generic, TypeVar, cast

from dbos_transact.communicator import CommunicatorContext
from dbos_transact.transaction import TransactionContext

R = TypeVar("R")
Expand All @@ -15,6 +16,9 @@ def __init__(self, workflow_uuid: str):
def txn_ctx(self) -> TransactionContext:
return cast(TransactionContext, self)

def comm_ctx(self) -> CommunicatorContext:
return cast(CommunicatorContext, self)


class WorkflowHandle(Generic[R]):

Expand Down
12 changes: 10 additions & 2 deletions templates/hello/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import sqlalchemy as sa

from dbos_transact import DBOS, WorkflowContext
from dbos_transact.communicator import CommunicatorContext
from dbos_transact.transaction import TransactionContext

dbos = DBOS()


@dbos.workflow()
def example_workflow(ctx: WorkflowContext, var: str) -> str:
return example_transaction(ctx.txn_ctx(), var)
res1 = example_transaction(ctx.txn_ctx(), var)
res2 = example_communicator(ctx.comm_ctx(), var)
return res1 + res2


@dbos.transaction()
Expand All @@ -17,5 +20,10 @@ def example_transaction(ctx: TransactionContext, var: str) -> str:
return var + str(rows[0][0])


@dbos.communicator()
def example_communicator(ctx: CommunicatorContext, var: str) -> str:
return var + "2"


if __name__ == "__main__":
assert example_workflow(dbos.wf_ctx(), "mike") == "mike1"
assert example_workflow(dbos.wf_ctx(), "mike") == "mike1mike2"
31 changes: 29 additions & 2 deletions tests/test_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import sqlalchemy as sa

from dbos_transact.communicator import CommunicatorContext
from dbos_transact.dbos import DBOS
from dbos_transact.transaction import TransactionContext
from dbos_transact.workflows import WorkflowContext
Expand All @@ -11,13 +12,15 @@
def test_simple_workflow(dbos: DBOS) -> None:
txn_counter: int = 0
wf_counter: int = 0
comm_counter: int = 0

@dbos.workflow()
def test_workflow(ctx: WorkflowContext, var: str, var2: str) -> str:
nonlocal wf_counter
wf_counter += 1
res = test_transaction(ctx.txn_ctx(), var2)
return res + var
res2 = test_communicator(ctx.comm_ctx(), var)
return res + res2

@dbos.transaction()
def test_transaction(ctx: TransactionContext, var2: str) -> str:
Expand All @@ -26,13 +29,20 @@ def test_transaction(ctx: TransactionContext, var2: str) -> str:
txn_counter += 1
return var2 + str(rows[0][0])

@dbos.communicator()
def test_communicator(ctx: CommunicatorContext, var: str) -> str:
nonlocal comm_counter
comm_counter += 1
return var

assert test_workflow(dbos.wf_ctx(), "bob", "bob") == "bob1bob"

# Test OAOO
wfuuid = str(uuid.uuid4())
assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice"
assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice"
assert txn_counter == 2 # Only increment once
assert comm_counter == 2 # Only increment once

# Test we can execute the workflow by uuid
dbos.execute_workflow_uuid(wfuuid)
Expand All @@ -42,21 +52,37 @@ def test_transaction(ctx: TransactionContext, var2: str) -> str:
def test_exception_workflow(dbos: DBOS) -> None:
txn_counter: int = 0
wf_counter: int = 0
comm_counter: int = 0

@dbos.transaction()
def exception_transaction(ctx: TransactionContext, var: str) -> str:
nonlocal txn_counter
txn_counter += 1
raise Exception(var)

@dbos.communicator()
def exception_communicator(ctx: CommunicatorContext, var: str) -> str:
nonlocal comm_counter
comm_counter += 1
raise Exception(var)

@dbos.workflow()
def exception_workflow(ctx: WorkflowContext) -> None:
nonlocal wf_counter
wf_counter += 1
err1 = None
err2 = None
try:
exception_transaction(ctx.txn_ctx(), "test error")
except Exception as e:
raise e
err1 = e

try:
exception_communicator(ctx.comm_ctx(), "test error")
except Exception as e:
err2 = e
assert err1 == err2 and err1 is not None
raise err1

with pytest.raises(Exception) as exc_info:
exception_workflow(dbos.wf_ctx())
Expand All @@ -73,6 +99,7 @@ def exception_workflow(ctx: WorkflowContext) -> None:
exception_workflow(dbos.wf_ctx(wfuuid))
assert "test error" in str(exc_info.value)
assert txn_counter == 2 # Only increment once
assert comm_counter == 2 # Only increment once

# Test we can execute the workflow by uuid, shouldn't throw errors
dbos.execute_workflow_uuid(wfuuid)
Expand Down

0 comments on commit f12f13e

Please sign in to comment.