Skip to content

Commit

Permalink
fix(strategy): properly match all regions to a node (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 authored Jun 4, 2023
1 parent b81a3a6 commit f25b84f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 25 deletions.
25 changes: 0 additions & 25 deletions mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class NodePool(Generic[ClientT]):
# Generics as expected, do not work in class variables.
# Don't fear, the public methods using this are typed well.
_nodes: ClassVar[dict[str, Node[Any]]] = {}
_node_regions: ClassVar[dict[VoiceRegion, set[Node[Any]]]] = {}
_node_shards: ClassVar[dict[int, set[Node[Any]]]] = {}
_client: ClientT | None = None

def __init__(
Expand Down Expand Up @@ -218,21 +216,6 @@ async def add_node(
.. versionadded:: 2.8
"""
# Add to dictionaries, creating a set or extending it if needed.
if node.regions:
for region in node.regions:
self._node_regions[region] = {
node,
*self._node_regions.get(region, set()),
}

if node.shard_ids:
for shard_id in node.shard_ids:
self._node_shards[shard_id] = {
node,
*self._node_shards.get(shard_id, set()),
}

_log.info("Created node, connecting it...", extra={"label": node.label})
await node.connect(player_cls=player_cls)

Expand All @@ -255,14 +238,6 @@ async def remove_node(
if isinstance(node, str):
node = self._nodes[node]

if node.regions:
for region in node.regions:
self._node_regions[region].remove(node)

if node.shard_ids:
for shard_id in node.shard_ids:
self._node_shards[shard_id].remove(node)

# Remove prematurely so it is not chosen.
del self._nodes[node.label]

Expand Down
12 changes: 12 additions & 0 deletions mafic/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import List, Optional

from .node import Node
from .region import VoiceRegion
from .type_variables import ClientT

StrategyCallable = Callable[
Expand Down Expand Up @@ -107,6 +108,17 @@ def location_strategy(
)
return nodes

try:
voice_region = VoiceRegion(voice_region)
except ValueError:
_log.error(
"Failed to match endpoint %s (match: %s) to a region, "
"defaulting to all nodes.",
endpoint,
voice_region,
)
return nodes

regional_nodes = list(
filter(
lambda node: node.regions is not None and voice_region in node.regions,
Expand Down

0 comments on commit f25b84f

Please sign in to comment.