From 47cb0743f61af2e000e7f587d3e859f2e26a36b6 Mon Sep 17 00:00:00 2001 From: Cameron Fairchild Date: Tue, 15 Oct 2024 14:32:11 -0400 Subject: [PATCH] Fix/add option decode for all (#4) * add decode_option for all classes * bump ver * ruff * add tests for new option decoding * fix readme * bump python req in test extra --- Cargo.lock | 22 ++++----- Cargo.toml | 2 +- README.md | 28 ++++++------ bt_decode.pyi | 69 ++++++++++++++++++++++++++++- libs/custom-derive/Cargo.toml | 2 +- libs/custom-derive/src/lib.rs | 15 +++++-- pyproject.toml | 4 +- src/dyndecoder.rs | 3 +- tests/test_decode_by_type_string.py | 2 - tests/test_decode_delegate_info.py | 24 +++++----- tests/test_decode_neurons.py | 24 +++++----- tests/test_decode_stake_info.py | 24 +++++----- tests/test_decode_subnet_info.py | 18 ++++++++ 13 files changed, 164 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9aca254..941d107 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,7 +43,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "custom_derive" -version = "0.1.0" +version = "0.2.0" dependencies = [ "proc-macro2", "quote", @@ -218,9 +218,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "00e89ce2565d6044ca31a3eb79a334c3a79a841120a98f64eea9f579564cb691" dependencies = [ "cfg-if", "indoc", @@ -236,9 +236,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "d8afbaf3abd7325e08f35ffb8deb5892046fcb2608b703db6a583a5ba4cea01e" dependencies = [ "once_cell", "target-lexicon", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "ec15a5ba277339d04763f4c23d85987a5b08cbb494860be141e6a10a8eb88022" dependencies = [ "libc", "pyo3-build-config", @@ -256,9 +256,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "15e0f01b5364bcfbb686a52fc4181d412b708a68ed20c330db9fc8d2c2bf5a43" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -268,9 +268,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "a09b550200e1e5ed9176976d0060cbc2ea82dc8515da07885e7b8153a85caacb" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index a29118e..f3a21dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ name = "bt_decode" crate-type = ["cdylib"] [dependencies.pyo3] -version = "0.22.2" +version = "0.22.3" features = ["extension-module"] [dependencies.custom_derive] diff --git a/README.md b/README.md index da7db29..06b8ba2 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ hex_bytes_result = sub.query_runtime_api( method="get_delegates", params=[ ] ) -# Decode scale-encoded DelegateInfo +# Decode scale-encoded Vec delegates_info: List[DelegateInfo] = DelegateInfo.decode_vec( bytes.fromhex( hex_bytes_result @@ -38,7 +38,7 @@ hex_bytes_result = sub.query_runtime_api( method="get_delegated", params=[list( validator_key.public_key )] ) -# Decode scale-encoded (DelegateInfo, take) +# Decode scale-encoded Vec<(DelegateInfo, take)> delegated_info: List[Tuple[DelegateInfo, int]] = DelegateInfo.decode_delegated( bytes.fromhex( hex_bytes_result @@ -82,8 +82,8 @@ hex_bytes_result = sub.query_runtime_api( method="get_neurons", params=[NETUID] ) -# Decode scale-encoded NeuronInfo -neurons: List[NeuronInfo] = NeuronInfo.decode( +# Decode scale-encoded Vec +neurons: List[NeuronInfo] = NeuronInfo.decode_vec( bytes.fromhex( hex_bytes_result )) @@ -126,8 +126,8 @@ hex_bytes_result = sub.query_runtime_api( method="get_neurons_lite", params=[NETUID] ) -# Decode scale-encoded NeuronInfoLite -neurons_lite: List[NeuronInfoLite] = NeuronInfoLite.decode( +# Decode scale-encoded Vec +neurons_lite: List[NeuronInfoLite] = NeuronInfoLite.decode_vec( bytes.fromhex( hex_bytes_result )) @@ -178,7 +178,7 @@ hex_bytes_result = sub.query_runtime_api( method="get_stake_info_for_coldkeys", params=[encoded_coldkeys] ) -# Decode scale-encoded (AccountId, StakeInfo) +# Decode scale-encoded Vec<(AccountId, StakeInfo)> stake_info: List[Tuple[bytes, List["StakeInfo"]]] = StakeInfo.decode_vec_tuple_vec( bytes.fromhex( hex_bytes_result @@ -199,8 +199,8 @@ hex_bytes_result = sub.query_runtime_api( method="get_subnet_info", params=[NETUID] ) -# Decode scale-encoded SubnetInfo -subnet_info: SubnetInfo = SubnetInfo.decode( +# Decode scale-encoded Option +subnet_info: SubnetInfo = SubnetInfo.decode_option( bytes.fromhex( hex_bytes_result )) @@ -219,15 +219,15 @@ hex_bytes_result = sub.query_runtime_api( method="get_subnets_info", params=[ ] ) -# Decode scale-encoded Optional[SubnetInfo] -subnets_info: List[Optional[SubnetInfo]] = SubnetInfo.decode_vec( +# Decode scale-encoded Vec> +subnets_info: List[Optional[SubnetInfo]] = SubnetInfo.decode_vec_option( bytes.fromhex( hex_bytes_result )) ``` ### SubnetHyperparameters -#### get_subnet_info +#### get_subnet_hyperparams ```python import bittensor from bt_decode import SubnetHyperparameters @@ -241,8 +241,8 @@ hex_bytes_result = sub.query_runtime_api( method="get_subnet_hyperparams", params=[NETUID] ) -# Decode scale-encoded SubnetHyperparameters -subnet_hyper_params: SubnetHyperparameters = SubnetHyperparameters.decode( +# Decode scale-encoded Option +subnet_hyper_params: Optional[SubnetHyperparameters] = SubnetHyperparameters.decode_option( bytes.fromhex( hex_bytes_result )) diff --git a/bt_decode.pyi b/bt_decode.pyi index 05353db..ab9d728 100644 --- a/bt_decode.pyi +++ b/bt_decode.pyi @@ -22,6 +22,9 @@ class AxonInfo: def decode(encoded: bytes) -> "AxonInfo": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["AxonInfo"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["AxonInfo"]: pass @@ -40,9 +43,50 @@ class PrometheusInfo: def decode(encoded: bytes) -> "PrometheusInfo": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["PrometheusInfo"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["PrometheusInfo"]: pass +class NeuronInfo: + hotkey: bytes + coldkey: bytes + uid: int + netuid: int + active: bool + axon_info: AxonInfo + prometheus_info: PrometheusInfo + stake: List[ + Tuple[bytes, int] + ] # map of coldkey to stake on this neuron/hotkey (includes delegations) + rank: int + emission: int + incentive: int + consensus: int + trust: int + validator_trust: int + dividends: int + last_update: int + validator_permit: bool + weights: List[ + Tuple[int, int] # Vec of (uid, weight) + ] + bonds: List[ + Tuple[int, int] # Vec of (uid, bond) + ] + pruning_score: int + + @staticmethod + def decode(encoded: bytes) -> "NeuronInfo": + pass + @staticmethod + def decode_option(encoded: bytes) -> Optional["NeuronInfo"]: + pass + @staticmethod + def decode_vec(encoded: bytes) -> List["NeuronInfo"]: + pass + class NeuronInfoLite: hotkey: bytes coldkey: bytes @@ -70,6 +114,9 @@ class NeuronInfoLite: def decode(encoded: bytes) -> "NeuronInfoLite": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["NeuronInfoLite"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["NeuronInfoLite"]: pass @@ -84,6 +131,9 @@ class SubnetIdentity: def decode(encoded: bytes) -> "SubnetIdentity": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["SubnetIdentity"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["SubnetIdentity"]: pass @@ -111,6 +161,9 @@ class SubnetInfo: def decode(encoded: bytes) -> "SubnetInfo": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["SubnetInfo"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["SubnetInfo"]: pass @staticmethod @@ -142,6 +195,9 @@ class SubnetInfoV2: def decode(encoded: bytes) -> "SubnetInfoV2": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["SubnetInfoV2"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["SubnetInfoV2"]: pass @staticmethod @@ -181,6 +237,9 @@ class SubnetHyperparameters: def decode(encoded: bytes) -> "SubnetHyperparameters": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["SubnetHyperparameters"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["SubnetHyperparameters"]: pass @@ -193,6 +252,9 @@ class StakeInfo: def decode(encoded: bytes) -> "StakeInfo": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["StakeInfo"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["StakeInfo"]: pass @staticmethod @@ -206,13 +268,18 @@ class DelegateInfo: owner_ss58: bytes registrations: List[int] # Vec of netuid this delegate is registered on validator_permits: List[int] # Vec of netuid this delegate has validator permit on - return_per_1000: int # Delegators current daily return per 1000 TAO staked minus take fee + return_per_1000: ( + int # Delegators current daily return per 1000 TAO staked minus take fee + ) total_daily_return: int @staticmethod def decode(encoded: bytes) -> "DelegateInfo": pass @staticmethod + def decode_option(encoded: bytes) -> Optional["DelegateInfo"]: + pass + @staticmethod def decode_vec(encoded: bytes) -> List["DelegateInfo"]: pass @staticmethod diff --git a/libs/custom-derive/Cargo.toml b/libs/custom-derive/Cargo.toml index 83e314f..2e2fd1c 100644 --- a/libs/custom-derive/Cargo.toml +++ b/libs/custom-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "custom_derive" -version = "0.1.0" +version = "0.2.0" edition = "2021" [lib] diff --git a/libs/custom-derive/src/lib.rs b/libs/custom-derive/src/lib.rs index 4ed04d9..5da645a 100644 --- a/libs/custom-derive/src/lib.rs +++ b/libs/custom-derive/src/lib.rs @@ -52,9 +52,8 @@ fn pydecode_impl(attr: TokenStream2, tokens: TokenStream2) -> Result Self { - let decoded = #struct_name::decode(&mut &encoded[..]) - .expect(&format!("Failed to decode {}", #struct_name_str)); - decoded + #struct_name::decode(&mut &encoded[..]) + .expect(&format!("Failed to decode {}", #struct_name_str)) } }); @@ -68,6 +67,16 @@ fn pydecode_impl(attr: TokenStream2, tokens: TokenStream2) -> Result Option { + Option::<#struct_name>::decode(&mut &encoded[..]) + .expect(&format!("Failed to decode Option<{}>", #struct_name_str)) + } + }); + Ok(quote!(#item_impl)) } diff --git a/pyproject.toml b/pyproject.toml index 239a522..b8d7f48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "bt-decode" -version = "0.2.0a" +version = "0.2.1a" description = "A wrapper around the scale-codec crate for fast scale-decoding of Bittensor data structures." readme = "README.md" license = {file = "LICENSE"} @@ -41,7 +41,7 @@ build-backend = "maturin" [project.optional-dependencies] dev = ["black==23.7.0","maturin", "ruff==0.4.7"] test = [ - "bittensor==7.3.1", + "bittensor==8.2.0", "pytest==7.2.0", "pytest-asyncio==0.23.7", "pytest-mock==3.12.0", diff --git a/src/dyndecoder.rs b/src/dyndecoder.rs index a8a22f6..b17e27b 100644 --- a/src/dyndecoder.rs +++ b/src/dyndecoder.rs @@ -1,7 +1,6 @@ use scale_info::{form::PortableForm, PortableRegistry, Type, TypeDef}; use scale_info::{ - PortableType, TypeDefArray, TypeDefCompact, TypeDefPrimitive, TypeDefSequence, - TypeDefTuple, + PortableType, TypeDefArray, TypeDefCompact, TypeDefPrimitive, TypeDefSequence, TypeDefTuple, }; use std::any::TypeId; use std::collections::HashMap; diff --git a/tests/test_decode_by_type_string.py b/tests/test_decode_by_type_string.py index e5563e3..2529c49 100644 --- a/tests/test_decode_by_type_string.py +++ b/tests/test_decode_by_type_string.py @@ -209,7 +209,6 @@ class TestDecodeByPlainTypeString: @classmethod def setup_class(cls) -> None: - with open(TEST_TYPES_JSON, "r") as f: types_json_str = f.read() @@ -233,7 +232,6 @@ class TestDecodeByScaleInfoTypeString: @classmethod def setup_class(cls) -> None: - with open(TEST_TYPES_JSON, "r") as f: types_json_str = f.read() diff --git a/tests/test_decode_delegate_info.py b/tests/test_decode_delegate_info.py index c40ed5d..526c830 100644 --- a/tests/test_decode_delegate_info.py +++ b/tests/test_decode_delegate_info.py @@ -46,10 +46,10 @@ def test_decode_delegated_no_errors(self): ) def test_decode_delegated_matches_python_impl(self): - delegate_info_list: List[ - Tuple[bt_decode.DelegateInfo, int] - ] = bt_decode.DelegateInfo.decode_delegated( - TEST_DELEGATE_INFO_HEX["delegated normal"]() + delegate_info_list: List[Tuple[bt_decode.DelegateInfo, int]] = ( + bt_decode.DelegateInfo.decode_delegated( + TEST_DELEGATE_INFO_HEX["delegated normal"]() + ) ) delegate_info_py_list = bittensor.DelegateInfo.delegated_list_from_vec_u8( @@ -96,14 +96,14 @@ def test_decode_vec_no_errors(self): _ = bt_decode.DelegateInfo.decode_vec(TEST_DELEGATE_INFO_HEX["vec normal"]()) def test_decode_vec_matches_python_impl(self): - delegates_info: List[ - bt_decode.DelegateInfo - ] = bt_decode.DelegateInfo.decode_vec(TEST_DELEGATE_INFO_HEX["vec normal"]()) - - delegates_info_py: List[ - bittensor.DelegateInfo - ] = bittensor.DelegateInfo.list_from_vec_u8( - list(TEST_DELEGATE_INFO_HEX["vec normal"]()) + delegates_info: List[bt_decode.DelegateInfo] = ( + bt_decode.DelegateInfo.decode_vec(TEST_DELEGATE_INFO_HEX["vec normal"]()) + ) + + delegates_info_py: List[bittensor.DelegateInfo] = ( + bittensor.DelegateInfo.list_from_vec_u8( + list(TEST_DELEGATE_INFO_HEX["vec normal"]()) + ) ) for delegate_info, delegate_info_py in zip(delegates_info, delegates_info_py): diff --git a/tests/test_decode_neurons.py b/tests/test_decode_neurons.py index 17b6236..50e88d5 100644 --- a/tests/test_decode_neurons.py +++ b/tests/test_decode_neurons.py @@ -106,16 +106,16 @@ def test_decode_vec_no_errors(self): ) def test_decode_vec_matches_python_impl(self): - neurons_info: List[ - bt_decode.NeuronInfoLite - ] = bt_decode.NeuronInfoLite.decode_vec( - TEST_NEURON_INFO_LITE_HEX["vec normal"]() + neurons_info: List[bt_decode.NeuronInfoLite] = ( + bt_decode.NeuronInfoLite.decode_vec( + TEST_NEURON_INFO_LITE_HEX["vec normal"]() + ) ) - neurons_info_py: List[ - bittensor.NeuronInfoLite - ] = bittensor.NeuronInfoLite.list_from_vec_u8( - list(TEST_NEURON_INFO_LITE_HEX["vec normal"]()) + neurons_info_py: List[bittensor.NeuronInfoLite] = ( + bittensor.NeuronInfoLite.list_from_vec_u8( + list(TEST_NEURON_INFO_LITE_HEX["vec normal"]()) + ) ) for neuron_info, neuron_info_py in zip(neurons_info, neurons_info_py): @@ -198,10 +198,10 @@ def test_decode_vec_matches_python_impl(self): TEST_NEURON_INFO_HEX["vec normal"]() ) - neurons_info_py: List[ - bittensor.NeuronInfo - ] = bittensor.NeuronInfo.list_from_vec_u8( - list(TEST_NEURON_INFO_HEX["vec normal"]()) + neurons_info_py: List[bittensor.NeuronInfo] = ( + bittensor.NeuronInfo.list_from_vec_u8( + list(TEST_NEURON_INFO_HEX["vec normal"]()) + ) ) for neuron_info, neuron_info_py in zip(neurons_info, neurons_info_py): diff --git a/tests/test_decode_stake_info.py b/tests/test_decode_stake_info.py index 67895fe..c5a6206 100644 --- a/tests/test_decode_stake_info.py +++ b/tests/test_decode_stake_info.py @@ -48,10 +48,10 @@ def test_decode_vec_matches_python_impl(self): bytes.fromhex(TEST_STAKE_INFO_HEX["vec normal"]) ) - stake_info_py_list: List[ - bittensor.StakeInfo - ] = bittensor.StakeInfo.list_from_vec_u8( - list(bytes.fromhex(TEST_STAKE_INFO_HEX["vec normal"])) + stake_info_py_list: List[bittensor.StakeInfo] = ( + bittensor.StakeInfo.list_from_vec_u8( + list(bytes.fromhex(TEST_STAKE_INFO_HEX["vec normal"])) + ) ) for stake_info, stake_info_py in zip(stake_info_list, stake_info_py_list): @@ -87,17 +87,17 @@ def test_decode_vec_matches_python_impl(self): self.assertGreater(attr_count, 0, "No attributes found") def test_decode_vec_vec_matches_python_impl(self): - stake_info_list: List[ - Tuple[bytes, List[bt_decode.StakeInfo]] - ] = bt_decode.StakeInfo.decode_vec_tuple_vec( - bytes.fromhex(TEST_STAKE_INFO_HEX["vec vec normal"]) + stake_info_list: List[Tuple[bytes, List[bt_decode.StakeInfo]]] = ( + bt_decode.StakeInfo.decode_vec_tuple_vec( + bytes.fromhex(TEST_STAKE_INFO_HEX["vec vec normal"]) + ) ) # Poor method name, should be dict_of_list_from_vec_u8 - stake_info_py_dict: Dict[ - str, List[bittensor.StakeInfo] - ] = bittensor.StakeInfo.list_of_tuple_from_vec_u8( - list(bytes.fromhex(TEST_STAKE_INFO_HEX["vec vec normal"])) + stake_info_py_dict: Dict[str, List[bittensor.StakeInfo]] = ( + bittensor.StakeInfo.list_of_tuple_from_vec_u8( + list(bytes.fromhex(TEST_STAKE_INFO_HEX["vec vec normal"])) + ) ) for stake_info_tuple, (coldkey, stake_info_py_list) in zip( diff --git a/tests/test_decode_subnet_info.py b/tests/test_decode_subnet_info.py index a222065..9024ca1 100644 --- a/tests/test_decode_subnet_info.py +++ b/tests/test_decode_subnet_info.py @@ -14,6 +14,8 @@ TEST_SUBNET_INFO_HEX = { "normal": "0828feff010013ffffffffffffffff214e010104feff0300c8010401040d03a1050000c28ff4070398b6d54370c07a546ab0bab5ca9847eb5890ada1bda127633e607097ad4517dd2ca0f010", + "option normal": "010828feff010013ffffffffffffffff214e010104feff0300c8010401040d03a1050000c28ff4070398b6d54370c07a546ab0bab5ca9847eb5890ada1bda127633e607097ad4517dd2ca0f010", + "option none": "00", "vec option normal": lambda: get_file_bytes("tests/subnets_info.hex"), } @@ -82,6 +84,22 @@ def test_decode_matches_python_impl(self): self.assertGreater(attr_count, 0, "No attributes found") + def test_decode_option_no_errors(self): + _ = bt_decode.SubnetInfo.decode_option( + bytes.fromhex(TEST_SUBNET_INFO_HEX["option normal"]) + ) + + def test_decode_option_handles_none_some(self): + should_be_some = bt_decode.SubnetInfo.decode_option( + bytes.fromhex(TEST_SUBNET_INFO_HEX["option normal"]) + ) + self.assertIsNotNone(should_be_some) + + should_be_none = bt_decode.SubnetInfo.decode_option( + bytes.fromhex(TEST_SUBNET_INFO_HEX["option none"]) + ) + self.assertIsNone(should_be_none) + def test_decode_vec_option_no_errors(self): _ = bt_decode.SubnetInfo.decode_vec_option( TEST_SUBNET_INFO_HEX["vec option normal"]()