Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-opentensor committed Jan 9, 2025
1 parent ba1a132 commit 2ce909f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 52 deletions.
20 changes: 12 additions & 8 deletions bittensor/core/async_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,9 @@ async def get_balance(
"""
block_hash = await self.determine_block_hash(block, block_hash, reuse_block)
balance = await self.substrate.query(
"System",
"Account",
[address],
module="System",
storage_function="Account",
params=[address],
block_hash=block_hash,
reuse_block_hash=reuse_block,
)
Expand Down Expand Up @@ -1786,11 +1786,13 @@ async def get_total_stake_for_coldkey(
Returns:
Balance of the stake held on the coldkey.
"""
block_hash = await self.determine_block_hash(block, block_hash, reuse_block)
block_hash = await self.determine_block_hash(
block=block, block_hash=block_hash, reuse_block=reuse_block
)
result = await self.substrate.query(
"SubtensorModule",
"TotalColdkeyStake",
[ss58_address],
module="SubtensorModule",
storage_function="TotalColdkeyStake",
params=[ss58_address],
block_hash=block_hash,
reuse_block_hash=reuse_block,
)
Expand Down Expand Up @@ -1857,7 +1859,9 @@ async def get_total_stake_for_hotkey(
Returns:
Balance of the stake held on the hotkey.
"""
block_hash = await self.determine_block_hash(block, block_hash, reuse_block)
block_hash = await self.determine_block_hash(
block=block, block_hash=block_hash, reuse_block=reuse_block
)
result = await self.substrate.query(
module="SubtensorModule",
storage_function="TotalHotkeyStake",
Expand Down
91 changes: 47 additions & 44 deletions tests/unit_tests/test_async_subtensor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
from bittensor_wallet import Wallet

from bittensor import AsyncSubtensor
from bittensor.core import async_subtensor
from bittensor.core.chain_data import proposal_vote_data
from bittensor.core.subtensor import AsyncSubtensor


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -579,21 +579,14 @@ async def test_get_balance(subtensor, mocker):
fake_block_hash = None
reuse_block = True

expected_balance = async_subtensor.Balance(1000)

mocked_determine_block_hash = mocker.AsyncMock()
mocker.patch.object(
async_subtensor.AsyncSubtensor,
"determine_block_hash",
mocked_determine_block_hash,
)

mocked_get_balances = mocker.AsyncMock(
return_value={fake_address: expected_balance}
)
mocker.patch.object(
async_subtensor.AsyncSubtensor, "get_balances", mocked_get_balances
)
mocked_balance = mocker.patch.object(async_subtensor, "Balance")

# Call
result = await subtensor.get_balance(
Expand All @@ -603,12 +596,17 @@ async def test_get_balance(subtensor, mocker):
mocked_determine_block_hash.assert_awaited_once_with(
fake_block, fake_block_hash, reuse_block
)
mocked_get_balances.assert_awaited_once_with(
*[fake_address],
subtensor.substrate.query.assert_awaited_once_with(
module="System",
storage_function="Account",
params=[fake_address],
block_hash=mocked_determine_block_hash.return_value,
reuse_block=reuse_block,
reuse_block_hash=reuse_block,
)
mocked_balance.assert_called_once_with(
subtensor.substrate.query.return_value.__getitem__.return_value.__getitem__.return_value
)
assert result == expected_balance
assert result == mocked_balance.return_value


@pytest.mark.parametrize("balance", [100, 100.1])
Expand Down Expand Up @@ -689,63 +687,68 @@ async def test_get_transfer_with_exception(subtensor, mocker):
async def test_get_total_stake_for_coldkey(subtensor, mocker):
"""Tests get_total_stake_for_coldkey method."""
# Preps
fake_addresses = ("a1", "a2")
fake_addresses = "a1"
fake_block_hash = None

mocked_substrate_create_storage_key = mocker.AsyncMock()
subtensor.substrate.create_storage_key = mocked_substrate_create_storage_key

mocked_batch_0_call = mocker.Mock(
params=[
0,
]
)
mocked_batch_1_call = 0
mocked_substrate_query_multi = mocker.AsyncMock(
return_value=[
(mocked_batch_0_call, mocked_batch_1_call),
]
mocked_determine_block_hash = mocker.AsyncMock()
mocker.patch.object(
async_subtensor.AsyncSubtensor,
"determine_block_hash",
mocked_determine_block_hash,
)

subtensor.substrate.query_multi = mocked_substrate_query_multi
mocked_balance_from_rao = mocker.patch.object(async_subtensor.Balance, "from_rao")

# Call
result = await subtensor.get_total_stake_for_coldkey(
*fake_addresses, block_hash=fake_block_hash
fake_addresses, block_hash=fake_block_hash
)

assert mocked_substrate_create_storage_key.call_count == len(fake_addresses)
mocked_substrate_query_multi.assert_called_once()
assert result == {0: async_subtensor.Balance(mocked_batch_1_call)}
mocked_determine_block_hash.assert_awaited_once_with(
block=None, block_hash=None, reuse_block=False
)
subtensor.substrate.query.assert_awaited_once_with(
module="SubtensorModule",
storage_function="TotalColdkeyStake",
params=[fake_addresses],
block_hash=mocked_determine_block_hash.return_value,
reuse_block_hash=False,
)
assert result == mocked_balance_from_rao.return_value


@pytest.mark.asyncio
async def test_get_total_stake_for_hotkey(subtensor, mocker):
"""Tests get_total_stake_for_hotkey method."""
# Preps
fake_addresses = ("a1", "a2")
fake_addresses = "a1"
fake_block_hash = None
reuse_block = True

mocked_substrate_query_multiple = mocker.AsyncMock(return_value={0: 1})
mocked_determine_block_hash = mocker.AsyncMock()
mocker.patch.object(
async_subtensor.AsyncSubtensor,
"determine_block_hash",
mocked_determine_block_hash,
)

subtensor.substrate.query_multiple = mocked_substrate_query_multiple
mocked_balance_from_rao = mocker.patch.object(async_subtensor.Balance, "from_rao")

# Call
result = await subtensor.get_total_stake_for_hotkey(
*fake_addresses, block_hash=fake_block_hash, reuse_block=reuse_block
fake_addresses, block_hash=fake_block_hash
)

# Assertions
mocked_substrate_query_multiple.assert_called_once_with(
params=list(fake_addresses),
mocked_determine_block_hash.assert_awaited_once_with(
block=None, block_hash=None, reuse_block=False
)
subtensor.substrate.query.assert_awaited_once_with(
module="SubtensorModule",
storage_function="TotalHotkeyStake",
block_hash=fake_block_hash,
reuse_block_hash=reuse_block,
params=[fake_addresses],
block_hash=mocked_determine_block_hash.return_value,
reuse_block_hash=False,
)
mocked_substrate_query_multiple.assert_called_once()
assert result == {0: async_subtensor.Balance(1)}
assert result == mocked_balance_from_rao.return_value


@pytest.mark.parametrize(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def test_methods_comparable(mocker):
"encode_params",
"get_hyperparameter",
"sign_and_send_extrinsic",
"get_total_stake_for_coldkeys",
"get_total_stake_for_hotkeys",
]
subtensor_methods = [
m
Expand Down

0 comments on commit 2ce909f

Please sign in to comment.