diff --git a/dbos_transact/application_database.py b/dbos_transact/application_database.py index c66d3b6b..c88cbbb4 100644 --- a/dbos_transact/application_database.py +++ b/dbos_transact/application_database.py @@ -84,7 +84,7 @@ def record_transaction_output( pg.insert(ApplicationSchema.transaction_outputs).values( workflow_uuid=output["workflow_uuid"], function_id=output["function_id"], - output=output["output"] if output["output"] else None, + output=output["output"], error=None, txn_id=sa.text("(select pg_current_xact_id_if_assigned()::text)"), txn_snapshot=output["txn_snapshot"], @@ -106,7 +106,7 @@ def record_transaction_error(self, output: TransactionResultInternal) -> None: workflow_uuid=output["workflow_uuid"], function_id=output["function_id"], output=None, - error=output["error"] if output["error"] else None, + error=output["error"], txn_id=sa.text( "(select pg_current_xact_id_if_assigned()::text)" ), diff --git a/dbos_transact/communicator.py b/dbos_transact/communicator.py new file mode 100644 index 00000000..52e323cc --- /dev/null +++ b/dbos_transact/communicator.py @@ -0,0 +1,3 @@ +class CommunicatorContext: + def __init__(self, function_id: int): + self.function_id = function_id diff --git a/dbos_transact/dbos.py b/dbos_transact/dbos.py index 9882d6c0..1e993995 100644 --- a/dbos_transact/dbos.py +++ b/dbos_transact/dbos.py @@ -10,6 +10,7 @@ from typing import ParamSpec, TypeAlias import dbos_transact.utils as utils +from dbos_transact.communicator import CommunicatorContext from dbos_transact.error import DBOSRecoveryError, DBOSWorkflowConflictUUIDError from dbos_transact.transaction import TransactionContext from dbos_transact.workflows import WorkflowContext, WorkflowHandle @@ -17,7 +18,12 @@ from .application_database import ApplicationDatabase, TransactionResultInternal from .dbos_config import ConfigFile, load_config from .logger import config_logger, dbos_logger -from .system_database import SystemDatabase, WorkflowInputs, WorkflowStatusInternal +from .system_database import ( + OperationResultInternal, + SystemDatabase, + WorkflowInputs, + WorkflowStatusInternal, +) P = ParamSpec("P") R = TypeVar("R", covariant=True) @@ -43,6 +49,15 @@ def __call__(self, ctx: TransactionContext, *args: Any, **kwargs: Any) -> Any: . Transaction = TypeVar("Transaction", bound=TransactionProtocol) +class CommunicatorProtocol(Protocol): + __qualname__: str + + def __call__(self, ctx: CommunicatorContext, *args: Any, **kwargs: Any) -> Any: ... + + +Communicator = TypeVar("Communicator", bound=CommunicatorProtocol) + + class WorkflowInputContext(TypedDict): workflow_uuid: str @@ -228,6 +243,46 @@ def wrapper(_ctxt: TransactionContext, *args: Any, **kwargs: Any) -> Any: return decorator + def communicator(self) -> Callable[[Communicator], Communicator]: + def decorator(func: Communicator) -> Communicator: + @wraps(func) + def wrapper(_ctxt: CommunicatorContext, *args: Any, **kwargs: Any) -> Any: + input_ctxt = cast(WorkflowContext, _ctxt) + input_ctxt.function_id += 1 + comm_output: OperationResultInternal = { + "workflow_uuid": input_ctxt.workflow_uuid, + "function_id": input_ctxt.function_id, + "output": None, + "error": None, + } + comm_ctxt = CommunicatorContext(input_ctxt.function_id) + recorded_output = self.sys_db.check_operation_execution( + input_ctxt.workflow_uuid, comm_ctxt.function_id + ) + if recorded_output: + if recorded_output["error"]: + deserialized_error = utils.deserialize(recorded_output["error"]) + raise deserialized_error + elif recorded_output["output"]: + return utils.deserialize(recorded_output["output"]) + else: + raise Exception("Output and error are both None") + output = None + try: + # TODO: support configurable retries + output = func(comm_ctxt, *args, **kwargs) + comm_output["output"] = utils.serialize(output) + except Exception as error: + comm_output["error"] = utils.serialize(error) + raise error + finally: + self.sys_db.record_operation_result(comm_output) + return output + + return cast(Communicator, wrapper) + + return decorator + def execute_workflow_uuid(self, workflow_uuid: str) -> None: """ This function is used to execute a workflow by a UUID for recovery. diff --git a/dbos_transact/system_database.py b/dbos_transact/system_database.py index 10fdb68b..d824ccdb 100644 --- a/dbos_transact/system_database.py +++ b/dbos_transact/system_database.py @@ -8,6 +8,7 @@ from alembic.config import Config import dbos_transact.utils as utils +from dbos_transact.error import DBOSWorkflowConflictUUIDError from .dbos_config import ConfigFile from .schemas.system_database import SystemSchema @@ -39,6 +40,18 @@ class WorkflowStatusInternal(TypedDict): error: Optional[str] # Base64-encoded pickle +class RecordedResult(TypedDict): + output: Optional[str] # Base64-encoded pickle + error: Optional[str] # Base64-encoded pickle + + +class OperationResultInternal(TypedDict): + workflow_uuid: str + function_id: int + output: Optional[str] # Base64-encoded pickle + error: Optional[str] # Base64-encoded pickle + + class SystemDatabase: def __init__(self, config: ConfigFile): @@ -170,3 +183,43 @@ def get_pending_workflows(self) -> list[str]: ) ).fetchall() return [row[0] for row in rows] + + def record_operation_result(self, result: OperationResultInternal) -> None: + error = result["error"] + output = result["output"] + assert error is None or output is None, "Only one of error or output can be set" + with self.engine.begin() as c: + try: + c.execute( + pg.insert(SystemSchema.operation_outputs).values( + workflow_uuid=result["workflow_uuid"], + function_id=result["function_id"], + output=output, + error=error, + ) + ) + except sa.exc.IntegrityError: + raise DBOSWorkflowConflictUUIDError(result["workflow_uuid"]) + except Exception as e: + raise e + + def check_operation_execution( + self, workflow_uuid: str, function_id: int + ) -> Optional[RecordedResult]: + with self.engine.begin() as c: + rows = c.execute( + sa.select( + SystemSchema.operation_outputs.c.output, + SystemSchema.operation_outputs.c.error, + ).where( + SystemSchema.operation_outputs.c.workflow_uuid == workflow_uuid, + SystemSchema.operation_outputs.c.function_id == function_id, + ) + ).all() + if len(rows) == 0: + return None + result: RecordedResult = { + "output": rows[0][0], + "error": rows[0][1], + } + return result diff --git a/dbos_transact/workflows.py b/dbos_transact/workflows.py index b2490276..93d6f0be 100644 --- a/dbos_transact/workflows.py +++ b/dbos_transact/workflows.py @@ -1,6 +1,7 @@ from concurrent.futures import Future from typing import Generic, TypeVar, cast +from dbos_transact.communicator import CommunicatorContext from dbos_transact.transaction import TransactionContext R = TypeVar("R") @@ -15,6 +16,9 @@ def __init__(self, workflow_uuid: str): def txn_ctx(self) -> TransactionContext: return cast(TransactionContext, self) + def comm_ctx(self) -> CommunicatorContext: + return cast(CommunicatorContext, self) + class WorkflowHandle(Generic[R]): diff --git a/templates/hello/main.py b/templates/hello/main.py index 64c6ca53..0cb08485 100644 --- a/templates/hello/main.py +++ b/templates/hello/main.py @@ -1,6 +1,7 @@ import sqlalchemy as sa from dbos_transact import DBOS, WorkflowContext +from dbos_transact.communicator import CommunicatorContext from dbos_transact.transaction import TransactionContext dbos = DBOS() @@ -8,7 +9,9 @@ @dbos.workflow() def example_workflow(ctx: WorkflowContext, var: str) -> str: - return example_transaction(ctx.txn_ctx(), var) + res1 = example_transaction(ctx.txn_ctx(), var) + res2 = example_communicator(ctx.comm_ctx(), var) + return res1 + res2 @dbos.transaction() @@ -17,5 +20,10 @@ def example_transaction(ctx: TransactionContext, var: str) -> str: return var + str(rows[0][0]) +@dbos.communicator() +def example_communicator(ctx: CommunicatorContext, var: str) -> str: + return var + "2" + + if __name__ == "__main__": - assert example_workflow(dbos.wf_ctx(), "mike") == "mike1" + assert example_workflow(dbos.wf_ctx(), "mike") == "mike1mike2" diff --git a/tests/test_dbos.py b/tests/test_dbos.py index dbfaf559..ed0babb8 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -3,6 +3,7 @@ import pytest import sqlalchemy as sa +from dbos_transact.communicator import CommunicatorContext from dbos_transact.dbos import DBOS from dbos_transact.transaction import TransactionContext from dbos_transact.workflows import WorkflowContext @@ -11,13 +12,15 @@ def test_simple_workflow(dbos: DBOS) -> None: txn_counter: int = 0 wf_counter: int = 0 + comm_counter: int = 0 @dbos.workflow() def test_workflow(ctx: WorkflowContext, var: str, var2: str) -> str: nonlocal wf_counter wf_counter += 1 res = test_transaction(ctx.txn_ctx(), var2) - return res + var + res2 = test_communicator(ctx.comm_ctx(), var) + return res + res2 @dbos.transaction() def test_transaction(ctx: TransactionContext, var2: str) -> str: @@ -26,6 +29,12 @@ def test_transaction(ctx: TransactionContext, var2: str) -> str: txn_counter += 1 return var2 + str(rows[0][0]) + @dbos.communicator() + def test_communicator(ctx: CommunicatorContext, var: str) -> str: + nonlocal comm_counter + comm_counter += 1 + return var + assert test_workflow(dbos.wf_ctx(), "bob", "bob") == "bob1bob" # Test OAOO @@ -33,6 +42,7 @@ def test_transaction(ctx: TransactionContext, var2: str) -> str: assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice" assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice" assert txn_counter == 2 # Only increment once + assert comm_counter == 2 # Only increment once # Test we can execute the workflow by uuid dbos.execute_workflow_uuid(wfuuid) @@ -42,6 +52,7 @@ def test_transaction(ctx: TransactionContext, var2: str) -> str: def test_exception_workflow(dbos: DBOS) -> None: txn_counter: int = 0 wf_counter: int = 0 + comm_counter: int = 0 @dbos.transaction() def exception_transaction(ctx: TransactionContext, var: str) -> str: @@ -49,14 +60,29 @@ def exception_transaction(ctx: TransactionContext, var: str) -> str: txn_counter += 1 raise Exception(var) + @dbos.communicator() + def exception_communicator(ctx: CommunicatorContext, var: str) -> str: + nonlocal comm_counter + comm_counter += 1 + raise Exception(var) + @dbos.workflow() def exception_workflow(ctx: WorkflowContext) -> None: nonlocal wf_counter wf_counter += 1 + err1 = None + err2 = None try: exception_transaction(ctx.txn_ctx(), "test error") except Exception as e: - raise e + err1 = e + + try: + exception_communicator(ctx.comm_ctx(), "test error") + except Exception as e: + err2 = e + assert err1 == err2 and err1 is not None + raise err1 with pytest.raises(Exception) as exc_info: exception_workflow(dbos.wf_ctx()) @@ -73,6 +99,7 @@ def exception_workflow(ctx: WorkflowContext) -> None: exception_workflow(dbos.wf_ctx(wfuuid)) assert "test error" in str(exc_info.value) assert txn_counter == 2 # Only increment once + assert comm_counter == 2 # Only increment once # Test we can execute the workflow by uuid, shouldn't throw errors dbos.execute_workflow_uuid(wfuuid)