Skip to content

Commit

Permalink
refactor!: use task type to differentiate tasks in middleware, not name
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Apr 8, 2024
1 parent 144a087 commit 8ddb0d3
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions silverback/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,22 @@
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult

from silverback.persistence import HandlerResult
from silverback.types import SilverbackID, handler_id_block, handler_id_event
from silverback.types import SilverbackID, TaskType, handler_id_block, handler_id_event
from silverback.utils import hexbytes_dict


def resolve_task(message: TaskiqMessage) -> Tuple[str, Optional[int], Optional[int]]:
block_number = None
log_index = None
block_number = message.labels.get("number") or message.labels.get("block")
log_index = message.labels.get("log_index")
task_id = message.task_name

if task_id == "block":
block_number = message.args[0].number
task_id = handler_id_block(block_number)
elif "event" in task_id:
block_number = message.args[0].block_number
log_index = message.args[0].log_index
if log_index:
# TODO: Should standardize on event signature here instead of name in case of overloading
task_id = handler_id_event(message.args[0].contract_address, message.args[0].event_name)

elif block_number:
task_id = handler_id_block(block_number)

return task_id, block_number, log_index


Expand Down Expand Up @@ -66,34 +64,40 @@ def fix_dict(data: dict, recurse_count: int = 0) -> dict:
return message

def _create_label(self, message: TaskiqMessage) -> str:
if message.task_name == "block":
args = f"[block={message.args[0].hash.hex()}]"

elif "event" in message.task_name:
args = f"[txn={message.args[0].transaction_hash},log_index={message.args[0].log_index}]"
if labels_str := (
",".join(f"{k}={v}" for k, v in message.labels.items() if k != "task_name")
):
return f"{message.task_name}[{labels_str}]"

else:
args = ""

return f"{message.task_name}{args}"
return message.task_name

def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
if message.task_name == "block":
message.labels["task_name"] = message.task_name
task_type = message.labels.pop("task_type", "<unknown>")

# NOTE: Don't compare `str` to `TaskType` using `is`
if task_type == TaskType.NEW_BLOCKS:
# NOTE: Necessary because we don't know the exact block class
message.args[0] = self.provider.network.ecosystem.decode_block(
hexbytes_dict(message.args[0])
)
message.labels["number"] = str(message.args[0].number)
message.labels["hash"] = message.args[0].hash.hex()

elif "event" in message.task_name:
# NOTE: Just in case the user doesn't specify type as `ContractLog`
message.args[0] = ContractLog.model_validate(message.args[0])
message.labels["block"] = str(message.args[0].block_number)
message.labels["txn_id"] = message.args[0].transaction_hash
message.labels["log_index"] = str(message.args[0].log_index)

logger.info(f"{self._create_label(message)} - Started")
logger.debug(f"{self._create_label(message)} - Started")
return message

def post_execute(self, message: TaskiqMessage, result: TaskiqResult):
percentage_time = 100 * (result.execution_time / self.block_time)
logger.info(
logger.success(
f"{self._create_label(message)} "
f"- {result.execution_time:.3f}s ({percentage_time:.1f}%)"
)
Expand All @@ -119,4 +123,9 @@ async def on_error(
result: TaskiqResult,
exception: BaseException,
):
logger.error(f"{message.task_name} - {type(exception).__name__}: {exception}")
percentage_time = 100 * (result.execution_time / self.block_time)
logger.error(
f"{self._create_label(message)} "
f"- {result.execution_time:.3f}s ({percentage_time:.1f}%)"
)
# NOTE: Unless stdout is ignored, error traceback appears in stdout

0 comments on commit 8ddb0d3

Please sign in to comment.