Skip to content

Commit

Permalink
Fix/add option decode for all (#4)
Browse files Browse the repository at this point in the history
* add decode_option for all classes

* bump ver

* ruff

* add tests for new option<subnetinfo> decoding

* fix readme

* bump python req in test extra
  • Loading branch information
camfairchild authored Oct 15, 2024
1 parent 7c1aa79 commit 47cb074
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 73 deletions.
22 changes: 11 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ name = "bt_decode"
crate-type = ["cdylib"]

[dependencies.pyo3]
version = "0.22.2"
version = "0.22.3"
features = ["extension-module"]

[dependencies.custom_derive]
Expand Down
28 changes: 14 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ hex_bytes_result = sub.query_runtime_api(
method="get_delegates",
params=[ ]
)
# Decode scale-encoded DelegateInfo
# Decode scale-encoded Vec<DelegateInfo>
delegates_info: List[DelegateInfo] = DelegateInfo.decode_vec(
bytes.fromhex(
hex_bytes_result
Expand All @@ -38,7 +38,7 @@ hex_bytes_result = sub.query_runtime_api(
method="get_delegated",
params=[list( validator_key.public_key )]
)
# Decode scale-encoded (DelegateInfo, take)
# Decode scale-encoded Vec<(DelegateInfo, take)>
delegated_info: List[Tuple[DelegateInfo, int]] = DelegateInfo.decode_delegated(
bytes.fromhex(
hex_bytes_result
Expand Down Expand Up @@ -82,8 +82,8 @@ hex_bytes_result = sub.query_runtime_api(
method="get_neurons",
params=[NETUID]
)
# Decode scale-encoded NeuronInfo
neurons: List[NeuronInfo] = NeuronInfo.decode(
# Decode scale-encoded Vec<NeuronInfo>
neurons: List[NeuronInfo] = NeuronInfo.decode_vec(
bytes.fromhex(
hex_bytes_result
))
Expand Down Expand Up @@ -126,8 +126,8 @@ hex_bytes_result = sub.query_runtime_api(
method="get_neurons_lite",
params=[NETUID]
)
# Decode scale-encoded NeuronInfoLite
neurons_lite: List[NeuronInfoLite] = NeuronInfoLite.decode(
# Decode scale-encoded Vec<NeuronInfoLite>
neurons_lite: List[NeuronInfoLite] = NeuronInfoLite.decode_vec(
bytes.fromhex(
hex_bytes_result
))
Expand Down Expand Up @@ -178,7 +178,7 @@ hex_bytes_result = sub.query_runtime_api(
method="get_stake_info_for_coldkeys",
params=[encoded_coldkeys]
)
# Decode scale-encoded (AccountId, StakeInfo)
# Decode scale-encoded Vec<(AccountId, StakeInfo)>
stake_info: List[Tuple[bytes, List["StakeInfo"]]] = StakeInfo.decode_vec_tuple_vec(
bytes.fromhex(
hex_bytes_result
Expand All @@ -199,8 +199,8 @@ hex_bytes_result = sub.query_runtime_api(
method="get_subnet_info",
params=[NETUID]
)
# Decode scale-encoded SubnetInfo
subnet_info: SubnetInfo = SubnetInfo.decode(
# Decode scale-encoded Option<SubnetInfo>
subnet_info: SubnetInfo = SubnetInfo.decode_option(
bytes.fromhex(
hex_bytes_result
))
Expand All @@ -219,15 +219,15 @@ hex_bytes_result = sub.query_runtime_api(
method="get_subnets_info",
params=[ ]
)
# Decode scale-encoded Optional[SubnetInfo]
subnets_info: List[Optional[SubnetInfo]] = SubnetInfo.decode_vec(
# Decode scale-encoded Vec<Option<SubnetInfo>>
subnets_info: List[Optional[SubnetInfo]] = SubnetInfo.decode_vec_option(
bytes.fromhex(
hex_bytes_result
))
```

### SubnetHyperparameters
#### get_subnet_info
#### get_subnet_hyperparams
```python
import bittensor
from bt_decode import SubnetHyperparameters
Expand All @@ -241,8 +241,8 @@ hex_bytes_result = sub.query_runtime_api(
method="get_subnet_hyperparams",
params=[NETUID]
)
# Decode scale-encoded SubnetHyperparameters
subnet_hyper_params: SubnetHyperparameters = SubnetHyperparameters.decode(
# Decode scale-encoded Option<SubnetHyperparameters>
subnet_hyper_params: Optional[SubnetHyperparameters] = SubnetHyperparameters.decode_option(
bytes.fromhex(
hex_bytes_result
))
Expand Down
69 changes: 68 additions & 1 deletion bt_decode.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class AxonInfo:
def decode(encoded: bytes) -> "AxonInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["AxonInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["AxonInfo"]:
pass

Expand All @@ -40,9 +43,50 @@ class PrometheusInfo:
def decode(encoded: bytes) -> "PrometheusInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["PrometheusInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["PrometheusInfo"]:
pass

class NeuronInfo:
hotkey: bytes
coldkey: bytes
uid: int
netuid: int
active: bool
axon_info: AxonInfo
prometheus_info: PrometheusInfo
stake: List[
Tuple[bytes, int]
] # map of coldkey to stake on this neuron/hotkey (includes delegations)
rank: int
emission: int
incentive: int
consensus: int
trust: int
validator_trust: int
dividends: int
last_update: int
validator_permit: bool
weights: List[
Tuple[int, int] # Vec of (uid, weight)
]
bonds: List[
Tuple[int, int] # Vec of (uid, bond)
]
pruning_score: int

@staticmethod
def decode(encoded: bytes) -> "NeuronInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["NeuronInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["NeuronInfo"]:
pass

class NeuronInfoLite:
hotkey: bytes
coldkey: bytes
Expand Down Expand Up @@ -70,6 +114,9 @@ class NeuronInfoLite:
def decode(encoded: bytes) -> "NeuronInfoLite":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["NeuronInfoLite"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["NeuronInfoLite"]:
pass

Expand All @@ -84,6 +131,9 @@ class SubnetIdentity:
def decode(encoded: bytes) -> "SubnetIdentity":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["SubnetIdentity"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["SubnetIdentity"]:
pass

Expand Down Expand Up @@ -111,6 +161,9 @@ class SubnetInfo:
def decode(encoded: bytes) -> "SubnetInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["SubnetInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["SubnetInfo"]:
pass
@staticmethod
Expand Down Expand Up @@ -142,6 +195,9 @@ class SubnetInfoV2:
def decode(encoded: bytes) -> "SubnetInfoV2":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["SubnetInfoV2"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["SubnetInfoV2"]:
pass
@staticmethod
Expand Down Expand Up @@ -181,6 +237,9 @@ class SubnetHyperparameters:
def decode(encoded: bytes) -> "SubnetHyperparameters":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["SubnetHyperparameters"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["SubnetHyperparameters"]:
pass

Expand All @@ -193,6 +252,9 @@ class StakeInfo:
def decode(encoded: bytes) -> "StakeInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["StakeInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["StakeInfo"]:
pass
@staticmethod
Expand All @@ -206,13 +268,18 @@ class DelegateInfo:
owner_ss58: bytes
registrations: List[int] # Vec of netuid this delegate is registered on
validator_permits: List[int] # Vec of netuid this delegate has validator permit on
return_per_1000: int # Delegators current daily return per 1000 TAO staked minus take fee
return_per_1000: (
int # Delegators current daily return per 1000 TAO staked minus take fee
)
total_daily_return: int

@staticmethod
def decode(encoded: bytes) -> "DelegateInfo":
pass
@staticmethod
def decode_option(encoded: bytes) -> Optional["DelegateInfo"]:
pass
@staticmethod
def decode_vec(encoded: bytes) -> List["DelegateInfo"]:
pass
@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion libs/custom-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "custom_derive"
version = "0.1.0"
version = "0.2.0"
edition = "2021"

[lib]
Expand Down
15 changes: 12 additions & 3 deletions libs/custom-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ fn pydecode_impl(attr: TokenStream2, tokens: TokenStream2) -> Result<TokenStream
#[pyo3(name = "decode")]
#[staticmethod]
fn py_decode(encoded: &[u8]) -> Self {
let decoded = #struct_name::decode(&mut &encoded[..])
.expect(&format!("Failed to decode {}", #struct_name_str));
decoded
#struct_name::decode(&mut &encoded[..])
.expect(&format!("Failed to decode {}", #struct_name_str))
}
});

Expand All @@ -68,6 +67,16 @@ fn pydecode_impl(attr: TokenStream2, tokens: TokenStream2) -> Result<TokenStream
}
});

// Add the py_decode_option method
item_impl.items.push(parse_quote! {
#[pyo3(name = "decode_option")]
#[staticmethod]
fn py_decode_option(encoded: &[u8]) -> Option<Self> {
Option::<#struct_name>::decode(&mut &encoded[..])
.expect(&format!("Failed to decode Option<{}>", #struct_name_str))
}
});

Ok(quote!(#item_impl))
}

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "bt-decode"
version = "0.2.0a"
version = "0.2.1a"
description = "A wrapper around the scale-codec crate for fast scale-decoding of Bittensor data structures."
readme = "README.md"
license = {file = "LICENSE"}
Expand Down Expand Up @@ -41,7 +41,7 @@ build-backend = "maturin"
[project.optional-dependencies]
dev = ["black==23.7.0","maturin", "ruff==0.4.7"]
test = [
"bittensor==7.3.1",
"bittensor==8.2.0",
"pytest==7.2.0",
"pytest-asyncio==0.23.7",
"pytest-mock==3.12.0",
Expand Down
3 changes: 1 addition & 2 deletions src/dyndecoder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use scale_info::{form::PortableForm, PortableRegistry, Type, TypeDef};
use scale_info::{
PortableType, TypeDefArray, TypeDefCompact, TypeDefPrimitive, TypeDefSequence,
TypeDefTuple,
PortableType, TypeDefArray, TypeDefCompact, TypeDefPrimitive, TypeDefSequence, TypeDefTuple,
};
use std::any::TypeId;
use std::collections::HashMap;
Expand Down
2 changes: 0 additions & 2 deletions tests/test_decode_by_type_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ class TestDecodeByPlainTypeString:

@classmethod
def setup_class(cls) -> None:

with open(TEST_TYPES_JSON, "r") as f:
types_json_str = f.read()

Expand All @@ -233,7 +232,6 @@ class TestDecodeByScaleInfoTypeString:

@classmethod
def setup_class(cls) -> None:

with open(TEST_TYPES_JSON, "r") as f:
types_json_str = f.read()

Expand Down
Loading

0 comments on commit 47cb074

Please sign in to comment.