From 0ac1847ebc7ff7038b5946e468dfc0e80c325424 Mon Sep 17 00:00:00 2001 From: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:14:05 -0500 Subject: [PATCH] refactor!: upgrade subs to use web3py (#191) * chore: add official support for Python 3.13 * refactor: use web3py v7 persistent provider for subscriptions * fix: typing updates --- .github/workflows/test.yaml | 2 +- setup.py | 7 +- silverback/exceptions.py | 2 +- silverback/runner.py | 160 +++++++++++++++++------------- silverback/subscriptions.py | 189 ------------------------------------ 5 files changed, 98 insertions(+), 262 deletions(-) delete mode 100644 silverback/subscriptions.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b1fc4fb3..03a7d10b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -63,7 +63,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] # eventually add `windows-latest` - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 diff --git a/setup.py b/setup.py index 7c6393e7..b36dbc37 100644 --- a/setup.py +++ b/setup.py @@ -61,9 +61,9 @@ url="https://github.com/ApeWorX/silverback", include_package_data=True, install_requires=[ - "apepay>=0.3.2,<1", + "apepay>=0.3.3,<1", "click", # Use same version as eth-ape - "eth-ape>=0.8.19,<1.0", + "eth-ape>=0.8.24,<1", "ethpm-types>=0.6.10", # lower pin only, `eth-ape` governs upper pin "eth-pydantic-types", # Use same version as eth-ape "packaging", # Use same version as eth-ape @@ -71,7 +71,7 @@ "taskiq[metrics]>=0.11.9,<0.12", "tomlkit>=0.12,<1", # For reading/writing global platform profile "fief-client[cli]>=0.19,<1", # for platform auth/cluster login - "websockets>=14.1,<15", # For subscriptions + "web3>=7.7,<8", # TODO: Remove when Ape v0.9 is released (Ape v0.8 allows web3 v6) ], entry_points={ "console_scripts": ["silverback=silverback._cli:cli"], @@ -95,5 +95,6 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], ) diff --git a/silverback/exceptions.py b/silverback/exceptions.py index ffc6c6d3..59c8f138 100644 --- a/silverback/exceptions.py +++ b/silverback/exceptions.py @@ -33,7 +33,7 @@ class SilverbackException(ApeException): # TODO: `ExceptionGroup` added in Python 3.11 class StartupFailure(SilverbackException): - def __init__(self, *exceptions: Exception | str): + def __init__(self, *exceptions: BaseException | str | None): if len(exceptions) == 1 and isinstance(exceptions[0], str): super().__init__(exceptions[0]) elif error_str := "\n".join(str(e) for e in exceptions): diff --git a/silverback/runner.py b/silverback/runner.py index 46987fa4..dec9fcd5 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -1,28 +1,31 @@ import asyncio from abc import ABC, abstractmethod +from typing import Callable from ape import chain from ape.logging import logger from ape.utils import ManagerAccessMixin from ape_ethereum.ecosystem import keccak +from eth_utils import to_hex from ethpm_types import EventABI from packaging.specifiers import SpecifierSet from packaging.version import Version from taskiq import AsyncTaskiqTask from taskiq.kicker import AsyncKicker +from web3 import AsyncWeb3, WebSocketProvider +from web3.utils.subscriptions import ( + LogsSubscription, + LogsSubscriptionContext, + NewHeadsSubscription, + NewHeadsSubscriptionContext, +) from .exceptions import Halt, NoTasksAvailableError, NoWebsocketAvailableError, StartupFailure from .main import SilverbackBot, SystemConfig, TaskData from .recorder import BaseRecorder, TaskResult from .state import Datastore, StateSnapshot -from .subscriptions import SubscriptionType, Web3SubscriptionsManager from .types import TaskType -from .utils import ( - async_wrap_iter, - hexbytes_dict, - run_taskiq_task_group_wait_results, - run_taskiq_task_wait_result, -) +from .utils import async_wrap_iter, run_taskiq_task_group_wait_results, run_taskiq_task_wait_result class BaseRunner(ABC): @@ -88,18 +91,18 @@ async def _checkpoint( await self.datastore.save(result.return_value) @abstractmethod - async def _block_task(self, task_data: TaskData): + async def _block_task(self, task_data: TaskData) -> asyncio.Task | None: """ Handle a block_handler task """ @abstractmethod - async def _event_task(self, task_data: TaskData): + async def _event_task(self, task_data: TaskData) -> asyncio.Task | None: """ Handle an event handler task for the given contract event """ - async def run(self): + async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]): """ Run the task broker client for the assembled ``SilverbackBot`` bot. @@ -148,12 +151,12 @@ async def run(self): "Silverback no longer supports runner-based snapshotting, " "please upgrade your bot SDK version to latest to use snapshots." ) - startup_state = StateSnapshot( + startup_state: StateSnapshot | None = StateSnapshot( last_block_seen=-1, last_block_processed=-1, ) # Use empty snapshot - elif not (startup_state := await self.datastore.init(bot_id=self.bot.identifier)): + elif not (startup_state := await self.datastore.init(self.bot.identifier)): logger.warning("No state snapshot detected, using empty snapshot") startup_state = StateSnapshot( # TODO: Migrate these to parameters (remove explicitly from state) @@ -178,7 +181,7 @@ async def run(self): # Initialize recorder (if available) if self.recorder: - await self.recorder.init(bot_id=self.bot.identifier) + await self.recorder.init(self.bot.identifier) # Execute Silverback startup task before we init the rest startup_taskdata_result = await run_taskiq_task_wait_result( @@ -210,6 +213,7 @@ async def run(self): # NOTE: No need to handle results otherwise # Create our long-running event listeners + listener_tasks = [] new_block_taskdata_results = await run_taskiq_task_wait_result( self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK ) @@ -230,23 +234,22 @@ async def run(self): raise NoTasksAvailableError() # NOTE: Any propagated failure in here should be handled such that shutdown tasks also run - # TODO: `asyncio.TaskGroup` added in Python 3.11 - listener_tasks = ( - *( - asyncio.create_task(self._block_task(task_def)) - for task_def in new_block_taskdata_results.return_value - ), - *( - asyncio.create_task(self._event_task(task_def)) - for task_def in event_log_taskdata_results.return_value - ), - ) + for task_def in new_block_taskdata_results.return_value: + if (task := await self._block_task(task_def)) is not None: + listener_tasks.append(task) + + for task_def in event_log_taskdata_results.return_value: + if (task := await self._event_task(task_def)) is not None: + listener_tasks.append(task) + + listener_tasks.extend(t if isinstance(t, asyncio.Task) else t() for t in runtime_tasks) # NOTE: Safe to do this because no tasks were actually scheduled to run if len(listener_tasks) == 0: raise NoTasksAvailableError() # Run until one task bubbles up an exception that should stop execution + # TODO: `asyncio.TaskGroup` added in Python 3.11 tasks_with_errors, tasks_running = await asyncio.wait( listener_tasks, return_when=asyncio.FIRST_EXCEPTION ) @@ -310,19 +313,21 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs): self.ws_uri = ws_uri - async def _block_task(self, task_data: TaskData): + async def _block_task(self, task_data: TaskData) -> None: new_block_task_kicker = self._create_task_kicker(task_data) - sub_id = await self.subscriptions.subscribe(SubscriptionType.BLOCKS) - logger.debug(f"Handling blocks via {sub_id}") - - async for raw_block in self.subscriptions.get_subscription_data(sub_id): - block = self.provider.network.ecosystem.decode_block(hexbytes_dict(raw_block)) + async def block_handler(ctx: NewHeadsSubscriptionContext): + block = self.provider.network.ecosystem.decode_block(dict(ctx.result)) await self._checkpoint(last_block_seen=block.number) - await self._handle_task(await new_block_task_kicker.kiq(raw_block)) + await self._handle_task(await new_block_task_kicker.kiq(block)) await self._checkpoint(last_block_processed=block.number) - async def _event_task(self, task_data: TaskData): + sub_id = await self._web3.subscription_manager.subscribe( + NewHeadsSubscription(label=task_data.name, handler=block_handler) + ) + logger.debug(f"Handling blocks via {sub_id}") + + async def _event_task(self, task_data: TaskData) -> None: if not (contract_address := task_data.labels.get("contract_address")): raise StartupFailure("Contract instance required.") @@ -333,26 +338,37 @@ async def _event_task(self, task_data: TaskData): event_log_task_kicker = self._create_task_kicker(task_data) - sub_id = await self.subscriptions.subscribe( - SubscriptionType.EVENTS, - address=contract_address, - topics=["0x" + keccak(text=event_abi.selector).hex()], - ) - logger.debug(f"Handling '{contract_address}:{event_abi.name}' logs via {sub_id}") - - async for raw_event in self.subscriptions.get_subscription_data(sub_id): + async def log_handler(ctx: LogsSubscriptionContext): event = next( # NOTE: `next` is okay since it only has one item - self.provider.network.ecosystem.decode_logs([raw_event], event_abi) + self.provider.network.ecosystem.decode_logs([ctx.result], event_abi) ) - + # TODO: Fix upstream w/ web3py + event.transaction_hash = "0x" + event.transaction_hash.hex() await self._checkpoint(last_block_seen=event.block_number) await self._handle_task(await event_log_task_kicker.kiq(event)) await self._checkpoint(last_block_processed=event.block_number) - async def run(self): - async with Web3SubscriptionsManager(self.ws_uri) as subscriptions: - self.subscriptions = subscriptions - await super().run() + sub_id = await self._web3.subscription_manager.subscribe( + LogsSubscription( + label=task_data.name, + address=contract_address, + topics=[to_hex(keccak(text=event_abi.selector))], + handler=log_handler, + ) + ) + logger.debug(f"Handling '{contract_address}:{event_abi.name}' logs via {sub_id}") + + async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]): + async with AsyncWeb3(WebSocketProvider(self.ws_uri)) as web3: + self._web3 = web3 + + def run_subscriptions() -> asyncio.Task: + return asyncio.create_task( + web3.subscription_manager.handle_subscriptions(run_forever=True) + ) + + await super().run(*runtime_tasks, run_subscriptions) + await web3.subscription_manager.unsubscribe_all() class PollingRunner(BaseRunner, ManagerAccessMixin): @@ -370,7 +386,7 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs): "Do not use in production over long time periods unless you know what you're doing." ) - async def _block_task(self, task_data: TaskData): + async def _block_task(self, task_data: TaskData) -> asyncio.Task: new_block_task_kicker = self._create_task_kicker(task_data) if block_settings := self.bot.poll_settings.get("_blocks_"): @@ -381,17 +397,21 @@ async def _block_task(self, task_data: TaskData): new_block_timeout = ( new_block_timeout if new_block_timeout is not None else self.bot.new_block_timeout ) - async for block in async_wrap_iter( - chain.blocks.poll_blocks( - # NOTE: No start block because we should begin polling from head - new_block_timeout=new_block_timeout, - ) - ): - await self._checkpoint(last_block_seen=block.number) - await self._handle_task(await new_block_task_kicker.kiq(block)) - await self._checkpoint(last_block_processed=block.number) - async def _event_task(self, task_data: TaskData): + async def block_handler(): + async for block in async_wrap_iter( + chain.blocks.poll_blocks( + # NOTE: No start block because we should begin polling from head + new_block_timeout=new_block_timeout, + ) + ): + await self._checkpoint(last_block_seen=block.number) + await self._handle_task(await new_block_task_kicker.kiq(block)) + await self._checkpoint(last_block_processed=block.number) + + return asyncio.create_task(block_handler()) + + async def _event_task(self, task_data: TaskData) -> asyncio.Task: if not (contract_address := task_data.labels.get("contract_address")): raise StartupFailure("Contract instance required.") @@ -409,14 +429,18 @@ async def _event_task(self, task_data: TaskData): new_block_timeout = ( new_block_timeout if new_block_timeout is not None else self.bot.new_block_timeout ) - async for event in async_wrap_iter( - self.provider.poll_logs( - # NOTE: No start block because we should begin polling from head - address=contract_address, - new_block_timeout=new_block_timeout, - events=[event_abi], - ) - ): - await self._checkpoint(last_block_seen=event.block_number) - await self._handle_task(await event_log_task_kicker.kiq(event)) - await self._checkpoint(last_block_processed=event.block_number) + + async def log_handler(): + async for event in async_wrap_iter( + self.provider.poll_logs( + # NOTE: No start block because we should begin polling from head + address=contract_address, + new_block_timeout=new_block_timeout, + events=[event_abi], + ) + ): + await self._checkpoint(last_block_seen=event.block_number) + await self._handle_task(await event_log_task_kicker.kiq(event)) + await self._checkpoint(last_block_processed=event.block_number) + + return asyncio.create_task(log_handler()) diff --git a/silverback/subscriptions.py b/silverback/subscriptions.py deleted file mode 100644 index 23a1e64b..00000000 --- a/silverback/subscriptions.py +++ /dev/null @@ -1,189 +0,0 @@ -import asyncio -import json -from enum import Enum -from typing import AsyncGenerator, Optional - -from ape.logging import logger -from websockets import ConnectionClosedError -from websockets.asyncio import client as ws_client - - -class SubscriptionType(Enum): - BLOCKS = "newHeads" - EVENTS = "logs" - - -class Web3SubscriptionsManager: - websocket_reconnect_max_tries: int = 3 - rpc_response_timeout_count: int = 10 - subscription_polling_time: float = 0.1 # secs - - def __init__(self, ws_provider_uri: str): - # TODO: Temporary until a more permanent solution is added to ProviderAPI - if "infura" in ws_provider_uri and "ws/v3" not in ws_provider_uri: - ws_provider_uri = ws_provider_uri.replace("v3", "ws/v3") - - self._ws_provider_uri = ws_provider_uri - - # Stateful - self._connection: ws_client.ClientConnection | None = None - self._last_request: int = 0 - self._subscriptions: dict[str, asyncio.Queue] = {} - self._rpc_msg_buffer: list[dict] = [] - self._ws_lock = asyncio.Lock() - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} uri={self._ws_provider_uri}>" - - async def __aenter__(self) -> "Web3SubscriptionsManager": - self.connection = await ws_client.connect(self._ws_provider_uri) - return self - - def __aiter__(self) -> "Web3SubscriptionsManager": - return self - - async def __anext__(self) -> str: - if not self.connection: - raise StopAsyncIteration - - return await self._receive() - - async def _receive(self, timeout: Optional[int] = None) -> str: - """Receive (and wait if no timeout) for the next message from the - socket. - """ - if not self.connection: - raise ConnectionError("Connection not opened") - - message = await asyncio.wait_for(self.connection.recv(), timeout) - # TODO: Handle retries when connection breaks - - response = json.loads(message) - if response.get("method") == "eth_subscription": - sub_params: dict = response.get("params", {}) - if not (sub_id := sub_params.get("subscription")) or not isinstance(sub_id, str): - logger.warning(f"Corrupted subscription data: {response}") - return response - - if sub_id not in self._subscriptions: - self._subscriptions[sub_id] = asyncio.Queue() - - await self._subscriptions[sub_id].put(sub_params.get("result", {})) - - else: - self._rpc_msg_buffer.append(response) - - return response - - def _create_request(self, method: str, params: list) -> dict: - self._last_request += 1 - return { - "jsonrpc": "2.0", - "id": self._last_request, - "method": method, - "params": params, - } - - async def _get_response(self, request_id: int) -> dict: - if buffer := self._rpc_msg_buffer: - for idx, data in enumerate(buffer): - if data.get("id") == request_id: - self._rpc_msg_buffer.pop(idx) - return data - - async with self._ws_lock: - tries = 0 - while tries < self.rpc_response_timeout_count: - if self._rpc_msg_buffer and self._rpc_msg_buffer[-1].get("id") == request_id: - return self._rpc_msg_buffer.pop() - - # NOTE: Python <3.10 does not support `anext` function - await self.__anext__() # Keep pulling until we get a response - - raise RuntimeError("Timeout waiting for response.") - - async def subscribe(self, type: SubscriptionType, **filter_params) -> str: - if not self.connection: - raise ValueError("Connection required.") - - if type is SubscriptionType.BLOCKS and filter_params: - raise ValueError("blocks subscription doesn't accept filter params.") - - request = self._create_request( - "eth_subscribe", - [type.value, filter_params] if type is SubscriptionType.EVENTS else [type.value], - ) - await self.connection.send(json.dumps(request)) - response = await self._get_response(request.get("id") or self._last_request) - - sub_id = response.get("result") - if not sub_id: - # NOTE: Re-dumping message to avoid type-checking concerns. - raise ValueError(f"Missing subscription ID in response: {json.dumps(response)}.") - - return sub_id - - async def get_subscription_data(self, sub_id: str) -> AsyncGenerator[dict, None]: - """Iterate items from the subscription queue. If nothing is in the - queue, await. - """ - while True: - if not (queue := self._subscriptions.get(sub_id)) or queue.empty(): - async with self._ws_lock: - # Keep pulling until a message comes to process - # NOTE: Python <3.10 does not support `anext` function - await self.__anext__() - else: - yield await queue.get() - - async def get_subscription_data_nowait( - self, sub_id: str, timeout: Optional[int] = 15 - ) -> AsyncGenerator[dict, None]: - """Iterate items from the subscription queue. If nothing is in the - queue, return. - """ - while True: - if not (queue := self._subscriptions.get(sub_id)) or queue.empty(): - async with self._ws_lock: - try: - await self._receive(timeout=timeout) - except TimeoutError: - logger.warning(f"Receive call timed out ({sub_id}).") - return - else: - try: - yield queue.get_nowait() - except asyncio.QueueEmpty: - return - - async def unsubscribe(self, sub_id: str) -> bool: - if sub_id not in self._subscriptions: - raise ValueError(f"Unknown sub_id '{sub_id}'") - - if not self.connection: - # Nothing to unsubscribe. - return True - - request = self._create_request("eth_unsubscribe", [sub_id]) - await self.connection.send(json.dumps(request)) - - response = await self._get_response(request.get("id") or self._last_request) - if success := response.get("result", False): - del self._subscriptions[sub_id] # NOTE: Save memory - - return success - - async def __aexit__(self, exc_type, exc, tb): - try: - # Try to gracefully unsubscribe to all events - await asyncio.gather(*(self.unsubscribe(sub_id) for sub_id in self._subscriptions)) - - except ConnectionClosedError: - pass # Websocket already closed (ctrl+C and patiently waiting) - - finally: - # Disconnect and release websocket - try: - await self.connection.close() - except RuntimeError: - pass # No running event loop to disconnect from (multiple ctrl+C presses)