diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index 6bf6507..7fb3294 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -58,8 +58,6 @@ def __init__( self._future_manager = future_manager - self._external_context = zmq.asyncio.Context() - self._connector_internal = AsyncConnector( context=zmq.asyncio.Context.shadow(self._context), name="client_agent_internal", @@ -70,7 +68,7 @@ def __init__( identity=None, ) self._connector_external = AsyncConnector( - context=self._external_context, + context=zmq.asyncio.Context.shadow(self._context), name="client_agent_external", socket_type=zmq.DEALER, address=self._scheduler_address, @@ -183,8 +181,6 @@ async def __get_loops(self): self._connector_external.destroy() self._connector_internal.destroy() - self._external_context.destroy() - if exception is None: return diff --git a/scaler/client/client.py b/scaler/client/client.py index 100a558..8fc6071 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -89,9 +89,9 @@ def __initialize__( self._heartbeat_interval_seconds = heartbeat_interval_seconds self._stop_event = threading.Event() - self._internal_context = zmq.Context() + self._context = zmq.Context() self._connector = SyncConnector( - context=self._internal_context, + context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity, @@ -102,7 +102,7 @@ def __initialize__( identity=self._identity, client_agent_address=self._client_agent_address, scheduler_address=ZMQConfig.from_string(address), - context=self._internal_context, + context=self._context, future_manager=self._future_manager, stop_event=self._stop_event, timeout_seconds=self._timeout_seconds, @@ -539,7 +539,7 @@ def __assert_client_not_stopped(self): def __destroy(self): self._agent.join() - self._internal_context.destroy(linger=1) + self._context.destroy(linger=1) @staticmethod def __get_parent_task_priority() -> Optional[int]: diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py index 4132c5f..80e1b93 100644 --- a/scaler/io/async_binder.py +++ b/scaler/io/async_binder.py @@ -15,14 +15,14 @@ class AsyncBinder(Looper, Reporter): - def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Optional[bytes] = None): + def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig, identity: Optional[bytes] = None): self._address = address if identity is None: identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode() self._identity = identity - self._context = zmq.asyncio.Context(io_threads=io_threads) + self._context = context self._socket = self._context.socket(zmq.ROUTER) self.__set_socket_options() self._socket.bind(self._address.to_address()) diff --git a/scaler/scheduler/scheduler.py b/scaler/scheduler/scheduler.py index e0fe274..2d378d1 100644 --- a/scaler/scheduler/scheduler.py +++ b/scaler/scheduler/scheduler.py @@ -45,9 +45,10 @@ def __init__(self, config: SchedulerConfig): ) logging.info(f"{self.__class__.__name__}: monitor address is {self._address_monitor.to_address()}") - self._binder = AsyncBinder(name="scheduler", address=config.address, io_threads=config.io_threads) + self._context = zmq.asyncio.Context(io_threads=config.io_threads) + self._binder = AsyncBinder(context=self._context, name="scheduler", address=config.address) self._binder_monitor = AsyncConnector( - context=zmq.asyncio.Context(), + context=self._context, name="scheduler_monitor", socket_type=zmq.PUB, address=self._address_monitor, diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index 5653df0..3d467aa 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple import tblib.pickling_support +import zmq.asyncio # from scaler.utility.logging.utility import setup_logger from scaler.io.async_binder import AsyncBinder @@ -32,8 +33,8 @@ class VanillaProcessorManager(Looper, ProcessorManager): def __init__( self, + context: zmq.asyncio.Context, event_loop: str, - io_threads: int, garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, hard_processor_suspend: bool, @@ -43,7 +44,6 @@ def __init__( tblib.pickling_support.install() self._event_loop = event_loop - self._io_threads = io_threads self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes @@ -67,7 +67,7 @@ def __init__( self._task_active_lock: asyncio.Lock = asyncio.Lock() self._binder_internal: AsyncBinder = AsyncBinder( - name="processor_manager", address=self._address, io_threads=self._io_threads, identity=None + context=context, name="processor_manager", address=self._address, identity=None ) self._binder_internal.register(self.__on_receive_internal) diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index e9ec4a7..c594536 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -63,6 +63,7 @@ def __init__( self._logging_paths = logging_paths self._logging_level = logging_level + self._context: Optional[zmq.asyncio.Context] = None self._connector_external: Optional[AsyncConnector] = None self._task_manager: Optional[VanillaTaskManager] = None self._heartbeat_manager: Optional[VanillaHeartbeatManager] = None @@ -81,8 +82,9 @@ def __initialize(self): setup_logger() register_event_loop(self._event_loop) + self._context = zmq.asyncio.Context() self._connector_external = AsyncConnector( - context=zmq.asyncio.Context(), + context=self._context, name=self.name, socket_type=zmq.DEALER, address=self._address, @@ -96,8 +98,8 @@ def __initialize(self): self._task_manager = VanillaTaskManager(task_timeout_seconds=self._task_timeout_seconds) self._timeout_manager = VanillaTimeoutManager(death_timeout_seconds=self._death_timeout_seconds) self._processor_manager = VanillaProcessorManager( + context=self._context, event_loop=self._event_loop, - io_threads=self._io_threads, garbage_collect_interval_seconds=self._garbage_collect_interval_seconds, trim_memory_threshold_bytes=self._trim_memory_threshold_bytes, hard_processor_suspend=self._hard_processor_suspend,