From edd328c2eabdbe30a8ffdd1f8190e1194b9f5f5d Mon Sep 17 00:00:00 2001 From: namoray <152790419+namoray@users.noreply.github.com> Date: Sun, 22 Dec 2024 19:07:04 +0000 Subject: [PATCH] CR3 (#27) --- dev_utils/chain/examples.py | 17 +- fiber/chain/chain_utils.py | 16 +- fiber/chain/weights.py | 302 ++++++++++++++++++++++++------------ fiber/logging_utils.py | 2 +- pyproject.toml | 5 +- 5 files changed, 230 insertions(+), 112 deletions(-) diff --git a/dev_utils/chain/examples.py b/dev_utils/chain/examples.py index a2a25e0..3987cfa 100644 --- a/dev_utils/chain/examples.py +++ b/dev_utils/chain/examples.py @@ -20,23 +20,23 @@ async def metagraph_example(): logger.info(f"Found nodes: {nodes}") -async def set_weights_example(): +async def set_weights_example(netuid: int = 267): substrate = interface.get_substrate(subtensor_network="test") - nodes = get_nodes_for_netuid(substrate=substrate, netuid=176) + nodes = get_nodes_for_netuid(substrate=substrate, netuid=netuid) keypair = chain_utils.load_hotkey_keypair(wallet_name="default", hotkey_name="default") - validator_node_id = substrate.query("SubtensorModule", "Uids", [176, keypair.ss58_address]).value - version_key = substrate.query("SubtensorModule", "WeightsVersionKey", [176]).value + validator_node_id = substrate.query("SubtensorModule", "Uids", [netuid, keypair.ss58_address]).value + version_key = substrate.query("SubtensorModule", "WeightsVersionKey", [netuid]).value weights.set_node_weights( substrate=substrate, keypair=keypair, node_ids=[node.node_id for node in nodes], - node_weights=[node.incentive for node in nodes], - netuid=176, + node_weights=[node.incentive + 1 for node in nodes], + netuid=netuid, validator_node_id=validator_node_id, version_key=version_key, wait_for_inclusion=True, - wait_for_finalization=True, - ) + wait_for_finalization=True ) + # NOTE this is also a script in /scropts/post_ip_to_chain and you can use it on the cli with fiber-post-ip async def post_ip_to_chain_example(): @@ -68,5 +68,6 @@ async def main(): await set_weights_example() await post_ip_to_chain_example() + if __name__ == "__main__": asyncio.run(main()) diff --git a/fiber/chain/chain_utils.py b/fiber/chain/chain_utils.py index 6342cb3..94c7082 100644 --- a/fiber/chain/chain_utils.py +++ b/fiber/chain/chain_utils.py @@ -107,22 +107,28 @@ def sign_message(keypair: Keypair, message: str | None) -> str | None: return f"0x{keypair.sign(message).hex()}" - def query_substrate( - substrate: SubstrateInterface, module: str, method: str, params: list[Any], return_value: bool = True + substrate: SubstrateInterface, + module: str, + method: str, + params: list[Any], + return_value: bool = True, + block: int | None = None, ) -> tuple[SubstrateInterface, Any]: try: - query_result = substrate.query(module, method, params) + block_hash = substrate.get_block_hash(block) if block is not None else None + query_result = substrate.query(module, method, params, block_hash=block_hash) return_val = query_result.value if return_value else query_result return substrate, return_val except Exception as e: - logger.error(f"Query failed with error: {e}. Reconnecting and retrying.") + logger.debug(f"Substrate query failed with error: {e}. Reconnecting and retrying.") substrate = SubstrateInterface(url=substrate.url) - query_result = substrate.query(module, method, params) + block_hash = substrate.get_block_hash(block) if block is not None else None + query_result = substrate.query(module, method, params, block_hash=block_hash) return_val = query_result.value if return_value else query_result diff --git a/fiber/chain/weights.py b/fiber/chain/weights.py index c7bb9b9..0e3249e 100644 --- a/fiber/chain/weights.py +++ b/fiber/chain/weights.py @@ -1,76 +1,29 @@ -import time +import warnings from functools import wraps from typing import Any, Callable -from scalecodec import ScaleType +from bittensor_commit_reveal import get_encrypted_commit from substrateinterface import Keypair, SubstrateInterface from tenacity import retry, stop_after_attempt, wait_exponential from fiber import constants as fcst -from fiber.chain.chain_utils import format_error_message +from fiber.chain.chain_utils import format_error_message, query_substrate from fiber.chain.interface import get_substrate from fiber.logging_utils import get_logger logger = get_logger(__name__) -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=4), - reraise=True, -) -def _query_subtensor( - substrate: SubstrateInterface, - name: str, - block: int | None = None, - params: int | None = None, -) -> ScaleType: - try: - return substrate.query( - module="SubtensorModule", - storage_function=name, - params=params, # type: ignore - block_hash=(None if block is None else substrate.get_block_hash(block)), # type: ignore - ) - except Exception: - # Should prevent SSL errors - substrate = get_substrate(subtensor_address=substrate.url) - raise - - -def _get_hyperparameter( - substrate: SubstrateInterface, - param_name: str, - netuid: int, - block: int | None = None, -) -> list[int] | int | None: - subnet_exists = getattr( - _query_subtensor(substrate, "NetworksAdded", block, [netuid]), # type: ignore - "value", - False, - ) - if not subnet_exists: - return None - return getattr( - _query_subtensor(substrate, param_name, block, [netuid]), # type: ignore - "value", - None, - ) - - -def _blocks_since_last_update(substrate: SubstrateInterface, netuid: int, node_id: int) -> int | None: - current_block = substrate.get_block_number(None) # type: ignore - last_updated = _get_hyperparameter(substrate, "LastUpdate", netuid) - assert not isinstance(last_updated, int), "LastUpdate should be a list of ints" - if last_updated is None: - return None - return current_block - int(last_updated[node_id]) - +def _log_and_reraise(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.exception(f"Exception in {func.__name__}: {str(e)}") + raise -def _min_interval_to_set_weights(substrate: SubstrateInterface, netuid: int) -> int: - weights_set_rate_limit = _get_hyperparameter(substrate, "WeightsSetRateLimit", netuid) - assert isinstance(weights_set_rate_limit, int), "WeightsSetRateLimit should be an int" - return weights_set_rate_limit + return wrapper def _normalize_and_quantize_weights(node_ids: list[int], node_weights: list[float]) -> tuple[list[int], list[int]]: @@ -90,24 +43,33 @@ def _normalize_and_quantize_weights(node_ids: list[int], node_weights: list[floa return node_ids_formatted, node_weights_formatted -def _log_and_reraise(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as e: - logger.exception(f"Exception in {func.__name__}: {str(e)}") - raise +def blocks_since_last_update(substrate: SubstrateInterface, netuid: int, node_id: int) -> int: + substrate, current_block = query_substrate(substrate, "System", "Number", [], return_value=True) + substrate, last_updated_value = query_substrate(substrate, "SubtensorModule", "LastUpdate", [netuid], return_value=False) + updated: int = current_block - last_updated_value[node_id].value + return updated - return wrapper +def min_interval_to_set_weights(substrate: SubstrateInterface, netuid: int) -> int: + substrate, weights_set_rate_limit = query_substrate( + substrate, "SubtensorModule", "WeightsSetRateLimit", [netuid], return_value=True + ) + assert isinstance(weights_set_rate_limit, int), "WeightsSetRateLimit should be an int" + return weights_set_rate_limit -def _can_set_weights(substrate: SubstrateInterface, netuid: int, validator_node_id: int) -> bool: - blocks_since_update = _blocks_since_last_update(substrate, netuid, validator_node_id) - min_interval = _min_interval_to_set_weights(substrate, netuid) + +def can_set_weights(substrate: SubstrateInterface, netuid: int, validator_node_id: int) -> bool: + blocks_since_update = blocks_since_last_update(substrate, netuid, validator_node_id) + min_interval = min_interval_to_set_weights(substrate, netuid) if min_interval is None: return True - return blocks_since_update is not None and blocks_since_update > min_interval + + can_set_weights = blocks_since_update is not None and blocks_since_update >= min_interval + if not can_set_weights: + logger.error( + f"It is too soon to set weights! {blocks_since_update} blocks since last update, {min_interval} blocks required." + ) + return can_set_weights def _send_weights_to_chain( @@ -126,7 +88,7 @@ def _send_weights_to_chain( reraise=True, ) @_log_and_reraise - def _set_weights(): + def _send_weights(): with substrate as si: rpc_call = si.compose_call( call_module="SubtensorModule", @@ -155,48 +117,132 @@ def _set_weights(): return False, format_error_message(response.error_message) - return _set_weights() + return _send_weights() -def set_node_weights( +def _send_commit_reveal_weights_to_chain( + substrate: SubstrateInterface, + keypair: Keypair, + commit: bytes, + reveal_round: int, + netuid: int, + wait_for_inclusion: bool = False, + wait_for_finalization: bool = False, +) -> tuple[bool, str | None]: + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1.5, min=2, max=5), + reraise=True, + ) + @_log_and_reraise + def _send_commit_reveal_weights(): + call = substrate.compose_call( + call_module="SubtensorModule", + call_function="commit_crv3_weights", + call_params={ + "netuid": netuid, + "commit": commit, + "reveal_round": reveal_round, + }, + ) + extrinsic = substrate.create_signed_extrinsic( + call=call, + keypair=keypair, + ) + + response = substrate.submit_extrinsic( + extrinsic=extrinsic, + wait_for_inclusion=wait_for_inclusion, + wait_for_finalization=wait_for_finalization, + ) + + if not wait_for_finalization and not wait_for_inclusion: + return True, "Not waiting for finalization or inclusion." + + response.process_events() + if response.is_success: + return True, None + else: + return False, format_error_message(response.error_message) + + return _send_commit_reveal_weights() + + +def _set_weights_without_commit_reveal( substrate: SubstrateInterface, keypair: Keypair, node_ids: list[int], node_weights: list[float], netuid: int, - validator_node_id: int, version_key: int = 0, wait_for_inclusion: bool = False, wait_for_finalization: bool = False, - max_attempts: int = 1, ) -> bool: - node_ids_formatted, node_weights_formatted = _normalize_and_quantize_weights(node_ids, node_weights) + logger.info(f"Setting weights for subnet {netuid} with version key {version_key} - no commit reveal...") + success, error_message = _send_weights_to_chain( + substrate, + keypair, + node_ids, + node_weights, + netuid, + version_key, + wait_for_inclusion, + wait_for_finalization, + ) - # Fetch a new substrate object to reset the connection - substrate = get_substrate(subtensor_address=substrate.url) + if not wait_for_finalization and not wait_for_inclusion: + logger.info("Not waiting for finalization or inclusion to set weights. Returning immediately.") + return success - weights_can_be_set = False - for attempt in range(1, max_attempts + 1): - if not _can_set_weights(substrate, netuid, validator_node_id): - logger.info(logger.info(f"Skipping attempt {attempt}/{max_attempts}. Too soon to set weights. Will wait 30 secs...")) - time.sleep(30) - continue + if success: + if wait_for_finalization: + logger.info("✅ Successfully set weights and finalized") + elif wait_for_inclusion: + logger.info("✅ Successfully set weights and included") else: - weights_can_be_set = True - break + logger.info("✅ Successfully set weights") + else: + logger.error(f"❌ Failed to set weights: {error_message}") - if not weights_can_be_set: - logger.error("No attempt to set weightsmade. Perhaps it is too soon to set weights!") - return False + substrate.close() + return success - logger.info("Attempting to set weights...") - success, error_message = _send_weights_to_chain( + +def _set_weights_with_commit_reveal( + substrate: SubstrateInterface, + keypair: Keypair, + node_ids: list[int], + node_weights: list[float], + netuid: int, + version_key: int = 0, + wait_for_inclusion: bool = False, + wait_for_finalization: bool = False, +) -> bool: + substrate, current_block = query_substrate(substrate, "System", "Number", [], return_value=True) + + substrate, tempo = query_substrate(substrate, "SubtensorModule", "Tempo", [netuid], return_value=True) + substrate, subnet_reveal_period_epochs = query_substrate( + substrate, "SubtensorModule", "RevealPeriodEpochs", [netuid], return_value=True + ) + + # Encrypt `commit_hash` with t-lock and `get reveal_round` + commit_for_reveal, reveal_round = get_encrypted_commit( + uids=node_ids, + weights=node_weights, + version_key=version_key, + tempo=tempo, + current_block=current_block, + netuid=netuid, + subnet_reveal_period_epochs=subnet_reveal_period_epochs, + ) + + logger.info(f"Committing weights hash {commit_for_reveal.hex()} for subnet {netuid} with " f"reveal round {reveal_round}...") + success, error_message = _send_commit_reveal_weights_to_chain( substrate, keypair, - node_ids_formatted, - node_weights_formatted, + commit_for_reveal, + reveal_round, netuid, - version_key, wait_for_inclusion, wait_for_finalization, ) @@ -217,3 +263,67 @@ def set_node_weights( substrate.close() return success + + +def set_node_weights( + substrate: SubstrateInterface, + keypair: Keypair, + node_ids: list[int], + node_weights: list[float], + netuid: int, + validator_node_id: int, + version_key: int = 0, + wait_for_inclusion: bool = False, + wait_for_finalization: bool = False, + max_attempts: int | None = None, # NOTE: DEPRECATED +) -> bool: + if max_attempts is not None: + warnings.warn( + "Parameter 'max_attempts' is deprecated and will be removed in version 2.2.0", DeprecationWarning, stacklevel=2 + ) + node_ids_formatted, node_weights_formatted = _normalize_and_quantize_weights(node_ids, node_weights) + + # Fetch a new substrate object to reset the connection + substrate = get_substrate(subtensor_address=substrate.url) + + if not can_set_weights(substrate, netuid, validator_node_id): + return False + + # NOTE: Sadly this can't be an argument of the function, the hyperparam must be set on chain + # For it to function properly + substrate, commit_reveal_enabled = query_substrate( + substrate, + "SubtensorModule", + "CommitRevealWeightsEnabled", + [netuid], + return_value=True, + ) + + logger.info(f"Commit reveal enabled hyperparameter is set to {commit_reveal_enabled}") + + if commit_reveal_enabled is False: + return _set_weights_without_commit_reveal( + substrate, + keypair, + node_ids_formatted, + node_weights_formatted, + netuid, + version_key, + wait_for_inclusion, + wait_for_finalization, + ) + + elif commit_reveal_enabled is True: + return _set_weights_with_commit_reveal( + substrate, + keypair, + node_ids_formatted, + node_weights_formatted, + netuid, + version_key, + wait_for_inclusion, + wait_for_finalization, + ) + + else: + raise ValueError(f"Commit reveal enabled hyperparameter is set to {commit_reveal_enabled}, which is not a valid value") diff --git a/fiber/logging_utils.py b/fiber/logging_utils.py index 876e1b9..e98203e 100644 --- a/fiber/logging_utils.py +++ b/fiber/logging_utils.py @@ -53,5 +53,5 @@ def get_logger(name: str): console_handler.setFormatter(colored_formatter) logger.addHandler(console_handler) - logger.info(f"Logging mode is {logging.getLevelName(logger.getEffectiveLevel())}") + logger.debug(f"Logging mode is {logging.getLevelName(logger.getEffectiveLevel())}") return logger diff --git a/pyproject.toml b/pyproject.toml index d121d1f..8ee4078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fiber" -version = "2.0.0" +version = "2.1.0" description = "The ultra lightweight network for miner-validator communication" readme = "README.md" requires-python = ">=3.10" @@ -19,7 +19,8 @@ dependencies = [ "colorama>0.3.0,<=0.4.6", "python-dotenv==1.0.1", "pydantic>2,<=2.9.2", - "netaddr==1.3.0" + "netaddr==1.3.0", + "bittensor-commit-reveal==0.1.0" ] [project.optional-dependencies]