diff --git a/silverback/main.py b/silverback/main.py index 015ec1a5..b5061b96 100644 --- a/silverback/main.py +++ b/silverback/main.py @@ -1,6 +1,8 @@ import atexit +import inspect from collections import defaultdict from datetime import timedelta +from functools import wraps from typing import Any, Callable from ape.api.networks import LOCAL_NETWORK_NAME @@ -138,6 +140,7 @@ def __init__(self, settings: Settings | None = None): self.signer = settings.get_signer() self.new_block_timeout = settings.NEW_BLOCK_TIMEOUT + self.use_fork = settings.FORK_MODE signer_str = f"\n SIGNER={repr(self.signer)}" new_block_timeout_str = ( @@ -146,7 +149,7 @@ def __init__(self, settings: Settings | None = None): network_choice = f"{self.identifier.ecosystem}:{self.identifier.network}" logger.success( - f'Loaded Silverback Bot:\n NETWORK="{network_choice}"' + f'Loaded Silverback Bot:\n NETWORK="{network_choice}"\n FORK_MODE={self.use_fork}' f"{signer_str}{new_block_timeout_str}" ) @@ -284,6 +287,22 @@ def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask: self.tasks[task_type].append(TaskData(name=handler.__name__, labels=labels)) + if self.use_fork: + from ape import networks # NOTE: Defer import for load speed + + # Trigger worker-side handling using fork network by wrapping handler + is_awaitable = inspect.isawaitable(handler) + + @wraps(handler) + async def fork_handler(*args, **kwargs): + with networks.fork(): + if is_awaitable: + return await handler(*args, **kwargs) + else: + return handler(*args, **kwargs) + + handler = fork_handler + return self.broker.register_task( handler, task_name=handler.__name__, diff --git a/silverback/settings.py b/silverback/settings.py index 7b756484..acd2b6d4 100644 --- a/silverback/settings.py +++ b/silverback/settings.py @@ -28,6 +28,10 @@ class Settings(BaseSettings, ManagerAccessMixin): # A unique identifier for this silverback instance BOT_NAME: str = "bot" + # Execute every handler using an independent fork context + # NOTE: Requires fork-able provider installed and configured for network + FORK_MODE: bool = False + BROKER_CLASS: str = "taskiq:InMemoryBroker" BROKER_URI: str = "" # To be deprecated in 0.6 BROKER_KWARGS: dict[str, Any] = dict()