diff --git a/fiber/chain/weights.py b/fiber/chain/weights.py index a567a20..daa61e8 100644 --- a/fiber/chain/weights.py +++ b/fiber/chain/weights.py @@ -3,7 +3,6 @@ from typing import Any, Callable from scalecodec import ScaleType -from scalecodec.types import GenericExtrinsic from substrateinterface import Keypair, SubstrateInterface from tenacity import retry, stop_after_attempt, wait_exponential @@ -115,35 +114,6 @@ def wrapper(*args, **kwargs): return wrapper -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1.5, min=2, max=5), - reraise=True, -) -@_log_and_reraise -def _send_extrinsic( - substrate: SubstrateInterface, - extrinsic_to_send: GenericExtrinsic, - wait_for_inclusion: bool = False, - wait_for_finalization: bool = False, -) -> tuple[bool, str | None]: - ## Context manager here so if we need to reconnect, the retry loop will catch it - with substrate as si: - response = si.submit_extrinsic( - extrinsic_to_send, - wait_for_inclusion=wait_for_inclusion, - wait_for_finalization=wait_for_finalization, - ) - if not wait_for_finalization and not wait_for_inclusion: - return True, "Not waiting for finalization or inclusion." - response.process_events() - - if response.is_success: - return True, "Successfully set weights." - - return False, _format_error_message(response.error_message) - - def _can_set_weights(substrate: SubstrateInterface, netuid: int, validator_node_id: int) -> bool: blocks_since_update = _blocks_since_last_update(substrate, netuid, validator_node_id) min_interval = _min_interval_to_set_weights(substrate, netuid) @@ -152,6 +122,54 @@ def _can_set_weights(substrate: SubstrateInterface, netuid: int, validator_node_ return blocks_since_update is not None and blocks_since_update > min_interval +def _send_weights_to_chain( + substrate: SubstrateInterface, + keypair: Keypair, + node_ids: list[int], + node_weights: list[float], + netuid: int, + version_key: int = 0, + wait_for_inclusion: bool = False, + wait_for_finalization: bool = False, +) -> tuple[bool, str | None]: + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1.5, min=2, max=5), + reraise=True, + ) + @_log_and_reraise + def _set_weights(): + with substrate as si: + rpc_call = si.compose_call( + call_module="SubtensorModule", + call_function="set_weights", + call_params={ + "dests": node_ids, + "weights": node_weights, + "netuid": netuid, + "version_key": version_key, + }, + ) + extrinsic_to_send = si.create_signed_extrinsic(call=rpc_call, keypair=keypair, era={"period": 5}) + + response = si.submit_extrinsic( + extrinsic_to_send, + wait_for_inclusion=wait_for_inclusion, + wait_for_finalization=wait_for_finalization, + ) + + if not wait_for_finalization and not wait_for_inclusion: + return True, "Not waiting for finalization or inclusion." + response.process_events() + + if response.is_success: + return True, "Successfully set weights." + + return False, _format_error_message(response.error_message) + + return _set_weights() + + def set_node_weights( substrate: SubstrateInterface, keypair: Keypair, @@ -169,18 +187,6 @@ def set_node_weights( # Fetch a new substrate object to reset the connection substrate = get_substrate(subtensor_address=substrate.url) - rpc_call = substrate.compose_call( - call_module="SubtensorModule", - call_function="set_weights", - call_params={ - "dests": node_ids_formatted, - "weights": node_weights_formatted, - "netuid": netuid, - "version_key": version_key, - }, - ) - extrinsic_to_send = substrate.create_signed_extrinsic(call=rpc_call, keypair=keypair, era={"period": 5}) - weights_can_be_set = False for attempt in range(1, max_attempts + 1): if not _can_set_weights(substrate, netuid, validator_node_id): @@ -196,12 +202,15 @@ def set_node_weights( return False logger.info("Attempting to set weights...") - - success, error_message = _send_extrinsic( - substrate=substrate, - extrinsic_to_send=extrinsic_to_send, - wait_for_inclusion=wait_for_inclusion, - wait_for_finalization=wait_for_finalization, + success, error_message = _send_weights_to_chain( + substrate, + keypair, + node_ids_formatted, + node_weights_formatted, + netuid, + version_key, + wait_for_inclusion, + wait_for_finalization, ) if not wait_for_finalization and not wait_for_inclusion: