Skip to content

Commit

Permalink
OAOO For Transactions (#15)
Browse files Browse the repository at this point in the history
It works for both normal outputs and exceptions.
  • Loading branch information
qianl15 authored Jul 26, 2024
1 parent f581a38 commit 4b6acb5
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 53 deletions.
78 changes: 60 additions & 18 deletions dbos_transact/application_database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, TypedDict
from typing import Optional, TypedDict, cast

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
import sqlalchemy.exc as sa_exc
from sqlalchemy.orm import Session, sessionmaker

from dbos_transact.error import DBOSWorkflowConflictUUIDError
from dbos_transact.schemas.application_database import ApplicationSchema

from .dbos_config import ConfigFile
Expand All @@ -19,6 +21,11 @@ class TransactionResultInternal(TypedDict):
executor_id: Optional[str]


class RecordedResult(TypedDict):
output: Optional[str] # Base64-encoded pickle
error: Optional[str] # Base64-encoded pickle


class ApplicationDatabase:

def __init__(self, config: ConfigFile):
Expand Down Expand Up @@ -72,30 +79,65 @@ def destroy(self) -> None:
def record_transaction_output(
session: Session, output: TransactionResultInternal
) -> None:
session.execute(
pg.insert(ApplicationSchema.transaction_outputs).values(
workflow_uuid=output["workflow_uuid"],
function_id=output["function_id"],
output=output["output"] if output["output"] else None,
error=None,
txn_id=sa.text("(select pg_current_xact_id_if_assigned()::text)"),
txn_snapshot=output["txn_snapshot"],
executor_id=output["executor_id"] if output["executor_id"] else None,
)
)

def record_transaction_error(self, output: TransactionResultInternal) -> None:
with self.engine.begin() as conn:
conn.execute(
try:
session.execute(
pg.insert(ApplicationSchema.transaction_outputs).values(
workflow_uuid=output["workflow_uuid"],
function_id=output["function_id"],
output=None,
error=output["error"] if output["error"] else None,
output=output["output"] if output["output"] else None,
error=None,
txn_id=sa.text("(select pg_current_xact_id_if_assigned()::text)"),
txn_snapshot=output["txn_snapshot"],
executor_id=(
output["executor_id"] if output["executor_id"] else None
),
)
)
except sa_exc.IntegrityError:
raise DBOSWorkflowConflictUUIDError(output["workflow_uuid"])
except Exception as e:
raise e

def record_transaction_error(self, output: TransactionResultInternal) -> None:
try:
with self.engine.begin() as conn:
conn.execute(
pg.insert(ApplicationSchema.transaction_outputs).values(
workflow_uuid=output["workflow_uuid"],
function_id=output["function_id"],
output=None,
error=output["error"] if output["error"] else None,
txn_id=sa.text(
"(select pg_current_xact_id_if_assigned()::text)"
),
txn_snapshot=output["txn_snapshot"],
executor_id=(
output["executor_id"] if output["executor_id"] else None
),
)
)
except sa_exc.IntegrityError:
raise DBOSWorkflowConflictUUIDError(output["workflow_uuid"])
except Exception as e:
raise e

@staticmethod
def check_transaction_execution(
session: Session, workflow_uuid: str, function_id: int
) -> Optional[RecordedResult]:
rows = session.execute(
sa.select(
ApplicationSchema.transaction_outputs.c.output,
ApplicationSchema.transaction_outputs.c.error,
).where(
ApplicationSchema.transaction_outputs.c.workflow_uuid == workflow_uuid,
ApplicationSchema.transaction_outputs.c.function_id == function_id,
)
).all()
if len(rows) == 0:
return None
result: RecordedResult = {
"output": rows[0][0],
"error": rows[0][1],
}
return result
97 changes: 64 additions & 33 deletions dbos_transact/dbos.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import uuid
from functools import wraps
from typing import Any, Callable, Optional, Protocol, TypeVar, cast
from typing import Any, Callable, Optional, Protocol, TypedDict, TypeVar, cast

import dbos_transact.utils as utils
from dbos_transact.error import DBOSWorkflowConflictUUIDError
from dbos_transact.transaction import TransactionContext
from dbos_transact.workflows import WorkflowContext

from .application_database import ApplicationDatabase, TransactionResultInternal
from .dbos_config import ConfigFile, load_config
from .logger import config_logger, dbos_logger
from .system_database import (
SystemDatabase,
WorkflowInputs,
WorkflowStatusInternal,
WorkflowStatusString,
)
from .system_database import SystemDatabase, WorkflowInputs, WorkflowStatusInternal


class WorkflowProtocol(Protocol):
Expand All @@ -35,6 +31,10 @@ def __call__(self, ctx: TransactionContext, *args: Any, **kwargs: Any) -> Any: .
Transaction = TypeVar("Transaction", bound=TransactionProtocol)


class WorkflowInputContext(TypedDict):
workflow_uuid: str


class DBOS:
def __init__(self, config: Optional[ConfigFile] = None) -> None:
if config is None:
Expand All @@ -52,12 +52,12 @@ def destroy(self) -> None:
def workflow(self) -> Callable[[Workflow], Workflow]:
def decorator(func: Workflow) -> Workflow:
@wraps(func)
def wrapper(_: WorkflowContext, *args: Any, **kwargs: Any) -> Any:
workflow_uuid = str(uuid.uuid4())

def wrapper(_ctxt: WorkflowContext, *args: Any, **kwargs: Any) -> Any:
input_ctxt = cast(WorkflowInputContext, _ctxt)
workflow_uuid = input_ctxt["workflow_uuid"]
status: WorkflowStatusInternal = {
"workflow_uuid": workflow_uuid,
"status": WorkflowStatusString.PENDING.value,
"status": "PENDING",
"name": func.__qualname__,
"output": None,
"error": None,
Expand All @@ -76,13 +76,16 @@ def wrapper(_: WorkflowContext, *args: Any, **kwargs: Any) -> Any:

try:
output = func(ctx, *args, **kwargs)
except DBOSWorkflowConflictUUIDError as wferror:
# TODO: handle this properly by waiting/returning the output
raise wferror
except Exception as error:
status["status"] = WorkflowStatusString.ERROR.value
status["status"] = "ERROR"
status["error"] = utils.serialize(error)
self.sys_db.update_workflow_status(status)
raise error

status["status"] = WorkflowStatusString.SUCCESS.value
status["status"] = "SUCCESS"
status["output"] = utils.serialize(output)
self.sys_db.update_workflow_status(status)
return output
Expand All @@ -91,41 +94,69 @@ def wrapper(_: WorkflowContext, *args: Any, **kwargs: Any) -> Any:

return decorator

def wf_ctx(self) -> WorkflowContext:
return cast(WorkflowContext, None)
def wf_ctx(self, workflow_uuid: Optional[str] = None) -> WorkflowContext:
workflow_uuid = workflow_uuid if workflow_uuid else str(uuid.uuid4())
input_ctxt: WorkflowInputContext = {"workflow_uuid": workflow_uuid}
return cast(WorkflowContext, input_ctxt)

def transaction(self) -> Callable[[Transaction], Transaction]:
def decorator(func: Transaction) -> Transaction:
@wraps(func)
def wrapper(wf_ctxt: TransactionContext, *args: Any, **kwargs: Any) -> Any:
ctxt = cast(WorkflowContext, wf_ctxt)
ctxt.function_id += 1
def wrapper(_ctxt: TransactionContext, *args: Any, **kwargs: Any) -> Any:
input_ctxt = cast(WorkflowContext, _ctxt)
input_ctxt.function_id += 1
with self.app_db.sessionmaker() as session:
txn_output: TransactionResultInternal = {
"workflow_uuid": input_ctxt.workflow_uuid,
"function_id": input_ctxt.function_id,
"output": None,
"error": None,
"txn_snapshot": "", # TODO: add actual snapshot
"executor_id": None,
"txn_id": None,
}
has_recorded_error = False
try:
# TODO: support multiple isolation levels
# TODO: handle serialization errors properly
with session.begin():
# This must be the first statement in the transaction!
session.connection(
execution_options={"isolation_level": "SERIALIZABLE"}
execution_options={"isolation_level": "REPEATABLE READ"}
)
txn_ctxt = TransactionContext(session, ctxt.function_id)
# TODO: Check transaction output for OAOO
txn_output: TransactionResultInternal = {
"workflow_uuid": ctxt.workflow_uuid,
"function_id": ctxt.function_id,
"output": None,
"error": None,
"txn_snapshot": "",
"executor_id": None,
"txn_id": None,
}
txn_ctxt = TransactionContext(
session, input_ctxt.function_id
)
# Check recorded output for OAOO
recorded_output = (
ApplicationDatabase.check_transaction_execution(
session,
input_ctxt.workflow_uuid,
txn_ctxt.function_id,
)
)
if recorded_output:
if recorded_output["error"]:
deserialized_error = utils.deserialize(
recorded_output["error"]
)
has_recorded_error = True
raise deserialized_error
elif recorded_output["output"]:
return utils.deserialize(recorded_output["output"])
else:
raise Exception("Output and error are both None")
output = func(txn_ctxt, *args, **kwargs)
txn_output["output"] = utils.serialize(output)
ApplicationDatabase.record_transaction_output(
txn_ctxt.session, txn_output
)

except Exception as error:
# TODO: handle serialization errors properly
txn_output["error"] = utils.serialize(error)
self.app_db.record_transaction_error(txn_output)
# Don't record the error if it was already recorded
if not has_recorded_error:
txn_output["error"] = utils.serialize(error)
self.app_db.record_transaction_error(txn_output)
raise error
return output

Expand Down
23 changes: 23 additions & 0 deletions dbos_transact/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional


class DBOSException(Exception):
def __init__(self, message: str, dbos_error_code: Optional[int] = None):
self.message = message
self.dbos_error_code = dbos_error_code
super().__init__(self.message)

def __str__(self) -> str:
if self.dbos_error_code:
return f"DBOS Error {self.dbos_error_code}: {self.message}"
return f"DBOS Error: {self.message}"


ConflictingUUIDError = 1


class DBOSWorkflowConflictUUIDError(DBOSException):
def __init__(self, workflow_uuid: str):
super().__init__(
f"Conflicting UUID {workflow_uuid}", dbos_error_code=ConflictingUUIDError
)
9 changes: 7 additions & 2 deletions dbos_transact/system_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from enum import Enum
from typing import Any, Optional, TypedDict
from typing import Any, Literal, Optional, TypedDict

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
Expand All @@ -21,14 +21,19 @@ class WorkflowStatusString(Enum):
CANCELLED = "CANCELLED"


WorkflowStatuses = Literal[
"PENDING", "SUCCESS", "ERROR", "RETRIES_EXCEEDED", "CANCELLED"
]


class WorkflowInputs(TypedDict):
args: Any
kwargs: Any


class WorkflowStatusInternal(TypedDict):
workflow_uuid: str
status: str
status: WorkflowStatuses
name: str
output: Optional[str] # Base64-encoded pickle
error: Optional[str] # Base64-encoded pickle
Expand Down
25 changes: 25 additions & 0 deletions tests/test_dbos.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import uuid

import pytest
import sqlalchemy as sa

Expand All @@ -7,6 +9,7 @@


def test_simple_workflow(dbos: DBOS) -> None:
txn_counter: int = 0

@dbos.workflow()
def test_workflow(ctx: WorkflowContext, var: str, var2: str) -> str:
Expand All @@ -16,15 +19,26 @@ def test_workflow(ctx: WorkflowContext, var: str, var2: str) -> str:
@dbos.transaction()
def test_transaction(ctx: TransactionContext, var2: str) -> str:
rows = ctx.session.execute(sa.text("SELECT 1")).fetchall()
nonlocal txn_counter
txn_counter += 1
return var2 + str(rows[0][0])

assert test_workflow(dbos.wf_ctx(), "bob", "bob") == "bob1bob"

# Test OAOO
wfuuid = str(uuid.uuid4())
assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice"
assert test_workflow(dbos.wf_ctx(wfuuid), "alice", "alice") == "alice1alice"
assert txn_counter == 2 # Only increment once


def test_exception_workflow(dbos: DBOS) -> None:
txn_counter: int = 0

@dbos.transaction()
def exception_transaction(ctx: TransactionContext, var: str) -> str:
nonlocal txn_counter
txn_counter += 1
raise Exception(var)

@dbos.workflow()
Expand All @@ -38,3 +52,14 @@ def exception_workflow(ctx: WorkflowContext) -> None:
exception_workflow(dbos.wf_ctx())

assert "test error" in str(exc_info.value)

# Test OAOO
wfuuid = str(uuid.uuid4())
with pytest.raises(Exception) as exc_info:
exception_workflow(dbos.wf_ctx(wfuuid))
assert "test error" in str(exc_info.value)

with pytest.raises(Exception) as exc_info:
exception_workflow(dbos.wf_ctx(wfuuid))
assert "test error" in str(exc_info.value)
assert txn_counter == 2 # Only increment once

0 comments on commit 4b6acb5

Please sign in to comment.