Skip to content

Commit

Permalink
Reduce number of threads and use shared context
Browse files Browse the repository at this point in the history
Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>
  • Loading branch information
sharpener6 committed Oct 16, 2024
1 parent 940e70e commit b0ec4a3
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 18 deletions.
6 changes: 1 addition & 5 deletions scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def __init__(

self._future_manager = future_manager

self._external_context = zmq.asyncio.Context()

self._connector_internal = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
name="client_agent_internal",
Expand All @@ -70,7 +68,7 @@ def __init__(
identity=None,
)
self._connector_external = AsyncConnector(
context=self._external_context,
context=zmq.asyncio.Context.shadow(self._context),
name="client_agent_external",
socket_type=zmq.DEALER,
address=self._scheduler_address,
Expand Down Expand Up @@ -183,8 +181,6 @@ async def __get_loops(self):
self._connector_external.destroy()
self._connector_internal.destroy()

self._external_context.destroy()

if exception is None:
return

Expand Down
8 changes: 4 additions & 4 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def __initialize__(
self._heartbeat_interval_seconds = heartbeat_interval_seconds

self._stop_event = threading.Event()
self._internal_context = zmq.Context()
self._context = zmq.Context()
self._connector = SyncConnector(
context=self._internal_context,
context=self._context,
socket_type=zmq.PAIR,
address=self._client_agent_address,
identity=self._identity,
Expand All @@ -102,7 +102,7 @@ def __initialize__(
identity=self._identity,
client_agent_address=self._client_agent_address,
scheduler_address=ZMQConfig.from_string(address),
context=self._internal_context,
context=self._context,
future_manager=self._future_manager,
stop_event=self._stop_event,
timeout_seconds=self._timeout_seconds,
Expand Down Expand Up @@ -539,7 +539,7 @@ def __assert_client_not_stopped(self):

def __destroy(self):
self._agent.join()
self._internal_context.destroy(linger=1)
self._context.destroy(linger=1)

@staticmethod
def __get_parent_task_priority() -> Optional[int]:
Expand Down
4 changes: 2 additions & 2 deletions scaler/io/async_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@


class AsyncBinder(Looper, Reporter):
def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Optional[bytes] = None):
def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig, identity: Optional[bytes] = None):
self._address = address

if identity is None:
identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode()
self._identity = identity

self._context = zmq.asyncio.Context(io_threads=io_threads)
self._context = context
self._socket = self._context.socket(zmq.ROUTER)
self.__set_socket_options()
self._socket.bind(self._address.to_address())
Expand Down
5 changes: 3 additions & 2 deletions scaler/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ def __init__(self, config: SchedulerConfig):
)

logging.info(f"{self.__class__.__name__}: monitor address is {self._address_monitor.to_address()}")
self._binder = AsyncBinder(name="scheduler", address=config.address, io_threads=config.io_threads)
self._context = zmq.asyncio.Context(io_threads=config.io_threads)
self._binder = AsyncBinder(context=self._context, name="scheduler", address=config.address)
self._binder_monitor = AsyncConnector(
context=zmq.asyncio.Context(),
context=self._context,
name="scheduler_monitor",
socket_type=zmq.PUB,
address=self._address_monitor,
Expand Down
6 changes: 3 additions & 3 deletions scaler/worker/agent/processor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, List, Optional, Tuple

import tblib.pickling_support
import zmq.asyncio

# from scaler.utility.logging.utility import setup_logger
from scaler.io.async_binder import AsyncBinder
Expand All @@ -32,8 +33,8 @@
class VanillaProcessorManager(Looper, ProcessorManager):
def __init__(
self,
context: zmq.asyncio.Context,
event_loop: str,
io_threads: int,
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
hard_processor_suspend: bool,
Expand All @@ -43,7 +44,6 @@ def __init__(
tblib.pickling_support.install()

self._event_loop = event_loop
self._io_threads = io_threads

self._garbage_collect_interval_seconds = garbage_collect_interval_seconds
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
Expand All @@ -67,7 +67,7 @@ def __init__(
self._task_active_lock: asyncio.Lock = asyncio.Lock()

self._binder_internal: AsyncBinder = AsyncBinder(
name="processor_manager", address=self._address, io_threads=self._io_threads, identity=None
context=context, name="processor_manager", address=self._address, identity=None
)
self._binder_internal.register(self.__on_receive_internal)

Expand Down
6 changes: 4 additions & 2 deletions scaler/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self._logging_paths = logging_paths
self._logging_level = logging_level

self._context: Optional[zmq.asyncio.Context] = None
self._connector_external: Optional[AsyncConnector] = None
self._task_manager: Optional[VanillaTaskManager] = None
self._heartbeat_manager: Optional[VanillaHeartbeatManager] = None
Expand All @@ -81,8 +82,9 @@ def __initialize(self):
setup_logger()
register_event_loop(self._event_loop)

self._context = zmq.asyncio.Context()
self._connector_external = AsyncConnector(
context=zmq.asyncio.Context(),
context=self._context,
name=self.name,
socket_type=zmq.DEALER,
address=self._address,
Expand All @@ -96,8 +98,8 @@ def __initialize(self):
self._task_manager = VanillaTaskManager(task_timeout_seconds=self._task_timeout_seconds)
self._timeout_manager = VanillaTimeoutManager(death_timeout_seconds=self._death_timeout_seconds)
self._processor_manager = VanillaProcessorManager(
context=self._context,
event_loop=self._event_loop,
io_threads=self._io_threads,
garbage_collect_interval_seconds=self._garbage_collect_interval_seconds,
trim_memory_threshold_bytes=self._trim_memory_threshold_bytes,
hard_processor_suspend=self._hard_processor_suspend,
Expand Down

0 comments on commit b0ec4a3

Please sign in to comment.