Skip to content

Commit

Permalink
fix: setting weights (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
grantdfoster authored Oct 16, 2024
1 parent d401d56 commit 23a811a
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions masa/utils/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
U16_MAX = 65535


# Uses in `bittensor.utils.weight_utils.process_weights_for_netuid`
@legacy_torch_api_compat
def normalize_max_weight(
x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1
Expand Down Expand Up @@ -154,7 +155,7 @@ def process_weights_for_netuid(
if use_torch()
else np.ones((metagraph.n), dtype=np.int64) / metagraph.n
)
logging.debug("final_weights", final_weights)
# logging.debug("final_weights", final_weights)
final_weights_count = (
torch.tensor(list(range(len(final_weights))))
if use_torch()
Expand All @@ -177,7 +178,7 @@ def process_weights_for_netuid(
else np.ones((metagraph.n), dtype=np.int64) * 1e-5
) # creating minimum even non-zero weights
weights[non_zero_weight_idx] += non_zero_weights
logging.debug("final_weights", *weights)
# logging.debug("final_weights", *weights)
normalized_weights = normalize_max_weight(x=weights, limit=max_weight_limit)
nw_arange = (
torch.tensor(list(range(len(normalized_weights))))
Expand All @@ -186,7 +187,7 @@ def process_weights_for_netuid(
)
return nw_arange, normalized_weights

logging.debug("non_zero_weights", *non_zero_weights)
# logging.debug("non_zero_weights", *non_zero_weights)

# Compute the exclude quantile and find the weights in the lowest quantile
max_exclude = max(0, len(non_zero_weights) - min_allowed_weights) / len(
Expand All @@ -205,13 +206,13 @@ def process_weights_for_netuid(
# Exclude all weights below the allowed quantile.
non_zero_weight_uids = non_zero_weight_uids[lowest_quantile <= non_zero_weights]
non_zero_weights = non_zero_weights[lowest_quantile <= non_zero_weights]
logging.debug("non_zero_weight_uids", *non_zero_weight_uids)
logging.debug("non_zero_weights", *non_zero_weights)
# logging.debug("non_zero_weight_uids", *non_zero_weight_uids)
# logging.debug("non_zero_weights", *non_zero_weights)

# Normalize weights and return.
normalized_weights = normalize_max_weight(
x=non_zero_weights, limit=max_weight_limit
)
logging.debug("final_weights", non_zero_weights)
# logging.debug("final_weights", normalized_weights)

return non_zero_weight_uids, non_zero_weights
return non_zero_weight_uids, normalized_weights

0 comments on commit 23a811a

Please sign in to comment.