Get rid of memory copies when receiving and sending messages (#34)
* Get rid of memory copies when receiving and sending

Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>

* Reduce the number of threads and use a shared context

Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>

* Add test name print out

Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>

---------

Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>
sharpener6 authored Oct 16, 2024
1 parent 4cc802c commit f3b6cf4
Showing 28 changed files with 110 additions and 51 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -16,7 +16,7 @@
<a href="https://pypi.org/project/scaler">
<img alt="PyPI - Version" src="https://img.shields.io/pypi/v/scaler?colorA=0f1632&colorB=255be3">
</a>
<img src="https://api.securityscorecards.dev/projects/github.com/citi/scaler/badge">
<img src="https://api.securityscorecards.dev/projects/github.com/Citi/scaler/badge">
</p>
</div>

2 changes: 1 addition & 1 deletion scaler/about.py
@@ -1 +1 @@
__version__ = "1.8.8"
__version__ = "1.8.9"
6 changes: 1 addition & 5 deletions scaler/client/agent/client_agent.py
@@ -58,8 +58,6 @@ def __init__(

self._future_manager = future_manager

self._external_context = zmq.asyncio.Context()

self._connector_internal = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
name="client_agent_internal",
@@ -70,7 +68,7 @@
identity=None,
)
self._connector_external = AsyncConnector(
context=self._external_context,
context=zmq.asyncio.Context.shadow(self._context),
name="client_agent_external",
socket_type=zmq.DEALER,
address=self._scheduler_address,
@@ -183,8 +181,6 @@ async def __get_loops(self):
self._connector_external.destroy()
self._connector_internal.destroy()

self._external_context.destroy()

if exception is None:
return

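The change in this file replaces the agent's private `zmq.asyncio.Context` with a shadow of the context handed in by the client, so both ends share a single underlying libzmq context and its I/O threads. A minimal sketch of that pattern, with the inproc endpoint and PAIR sockets assumed purely for illustration:

```python
import zmq
import zmq.asyncio

# Sketch only: the endpoint name and socket roles are illustrative assumptions.
sync_context = zmq.Context()                              # owned by the calling thread
async_context = zmq.asyncio.Context.shadow(sync_context)  # asyncio view of the same context

# Both objects wrap the same underlying libzmq context, so sockets created
# from either one share a single pool of I/O threads and one inproc namespace.
client_side = sync_context.socket(zmq.PAIR)
client_side.bind("inproc://client_agent")

agent_side = async_context.socket(zmq.PAIR)
agent_side.connect("inproc://client_agent")

agent_side.close()
client_side.close()
sync_context.term()
```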
8 changes: 4 additions & 4 deletions scaler/client/client.py
@@ -89,9 +89,9 @@ def __initialize__(
self._heartbeat_interval_seconds = heartbeat_interval_seconds

self._stop_event = threading.Event()
self._internal_context = zmq.Context()
self._context = zmq.Context()
self._connector = SyncConnector(
context=self._internal_context,
context=self._context,
socket_type=zmq.PAIR,
address=self._client_agent_address,
identity=self._identity,
@@ -102,7 +102,7 @@
identity=self._identity,
client_agent_address=self._client_agent_address,
scheduler_address=ZMQConfig.from_string(address),
context=self._internal_context,
context=self._context,
future_manager=self._future_manager,
stop_event=self._stop_event,
timeout_seconds=self._timeout_seconds,
@@ -539,7 +539,7 @@ def __assert_client_not_stopped(self):

def __destroy(self):
self._agent.join()
self._internal_context.destroy(linger=1)
self._context.destroy(linger=1)

@staticmethod
def __get_parent_task_priority() -> Optional[int]:
15 changes: 8 additions & 7 deletions scaler/io/async_binder.py
@@ -5,6 +5,7 @@
from typing import Awaitable, Callable, List, Optional, Dict

import zmq.asyncio
from zmq import Frame

from scaler.io.utility import deserialize, serialize
from scaler.protocol.python.mixins import Message
@@ -14,14 +15,14 @@


class AsyncBinder(Looper, Reporter):
def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Optional[bytes] = None):
def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig, identity: Optional[bytes] = None):
self._address = address

if identity is None:
identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode()
self._identity = identity

self._context = zmq.asyncio.Context(io_threads=io_threads)
self._context = context
self._socket = self._context.socket(zmq.ROUTER)
self.__set_socket_options()
self._socket.bind(self._address.to_address())
@@ -38,18 +39,18 @@ def register(self, callback: Callable[[bytes, Message], Awaitable[None]]):
self._callback = callback

async def routine(self):
frames = await self._socket.recv_multipart()
frames: List[Frame] = await self._socket.recv_multipart(copy=False)
if not self.__is_valid_message(frames):
return

source, payload = frames
message: Optional[Message] = deserialize(payload)
message: Optional[Message] = deserialize(payload.bytes)
if message is None:
logging.error(f"received unknown message from {source!r}: {payload!r}")
logging.error(f"received unknown message from {source.bytes!r}: {payload!r}")
return

self.__count_received(message.__class__.__name__)
await self._callback(source, message)
await self._callback(source.bytes, message)

async def send(self, to: bytes, message: Message):
self.__count_sent(message.__class__.__name__)
@@ -63,7 +64,7 @@ def __set_socket_options(self):
self._socket.setsockopt(zmq.SNDHWM, 0)
self._socket.setsockopt(zmq.RCVHWM, 0)

def __is_valid_message(self, frames: List[bytes]) -> bool:
def __is_valid_message(self, frames: List[Frame]) -> bool:
if len(frames) < 2:
logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
return False
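For reference, the `copy=False` calls introduced above make pyzmq return `zmq.Frame` objects instead of `bytes`, so the payload buffer is only materialized through `.bytes` when it is actually needed. A self-contained sketch of the zero-copy receive/send pattern, with the ROUTER/DEALER pair, endpoint, and payload assumed purely for illustration:

```python
import zmq

# Sketch only: the endpoint and payload below are illustrative assumptions.
context = zmq.Context()

router = context.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5555")

dealer = context.socket(zmq.DEALER)
dealer.connect("tcp://127.0.0.1:5555")

# copy=False hands the buffer to libzmq without an intermediate copy.
dealer.send(b"payload", copy=False)

# With copy=False, recv_multipart() yields zmq.Frame objects; the underlying
# buffer is exposed lazily through the .bytes property.
source, payload = router.recv_multipart(copy=False)
print(source.bytes, payload.bytes)

context.destroy(linger=0)
```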
6 changes: 3 additions & 3 deletions scaler/io/async_connector.py
@@ -82,10 +82,10 @@ async def receive(self) -> Optional[Message]:
if self._socket.closed:
return None

payload = await self._socket.recv()
result: Optional[Message] = deserialize(payload)
payload = await self._socket.recv(copy=False)
result: Optional[Message] = deserialize(payload.bytes)
if result is None:
logging.error(f"received unknown message: {payload!r}")
logging.error(f"received unknown message: {payload.bytes!r}")
return None

return result
6 changes: 3 additions & 3 deletions scaler/io/sync_connector.py
@@ -47,13 +47,13 @@ def identity(self) -> bytes:

def send(self, message: Message):
with self._lock:
self._socket.send(serialize(message))
self._socket.send(serialize(message), copy=False)

def receive(self) -> Optional[Message]:
with self._lock:
payload = self._socket.recv()
payload = self._socket.recv(copy=False)

return self.__compose_message(payload)
return self.__compose_message(payload.bytes)

def __compose_message(self, payload: bytes) -> Optional[Message]:
result: Optional[Message] = deserialize(payload)
2 changes: 1 addition & 1 deletion scaler/io/sync_subscriber.py
@@ -69,7 +69,7 @@ def __initialize(self):

def __routine_polling(self):
try:
self.__routine_receive(self._socket.recv())
self.__routine_receive(self._socket.recv(copy=False).bytes)
except zmq.Again:
raise TimeoutError(f"Cannot connect to {self._address.to_address()} in {self._timeout_seconds} seconds")

5 changes: 3 additions & 2 deletions scaler/scheduler/scheduler.py
@@ -45,9 +45,10 @@ def __init__(self, config: SchedulerConfig):
)

logging.info(f"{self.__class__.__name__}: monitor address is {self._address_monitor.to_address()}")
self._binder = AsyncBinder(name="scheduler", address=config.address, io_threads=config.io_threads)
self._context = zmq.asyncio.Context(io_threads=config.io_threads)
self._binder = AsyncBinder(context=self._context, name="scheduler", address=config.address)
self._binder_monitor = AsyncConnector(
context=zmq.asyncio.Context(),
context=self._context,
name="scheduler_monitor",
socket_type=zmq.PUB,
address=self._address_monitor,
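The scheduler now creates one `zmq.asyncio.Context`, sized by `config.io_threads`, and passes it to both the binder and the monitor connector rather than letting each build its own context. A rough sketch of that wiring in plain pyzmq, where the thread count, socket types, and ports are illustrative assumptions:

```python
import zmq
import zmq.asyncio

# Sketch only: io_threads value, socket types, and ports are illustrative.
shared_context = zmq.asyncio.Context(io_threads=2)

# Both components draw their sockets from the same shared context.
binder_socket = shared_context.socket(zmq.ROUTER)   # task/message traffic
monitor_socket = shared_context.socket(zmq.PUB)     # status publishing

binder_socket.bind("tcp://127.0.0.1:6378")
monitor_socket.bind("tcp://127.0.0.1:6379")

# A single destroy() tears down every socket created from the shared context.
shared_context.destroy(linger=0)
```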
6 changes: 3 additions & 3 deletions scaler/worker/agent/processor_manager.py
@@ -6,6 +6,7 @@
from typing import Dict, List, Optional, Tuple

import tblib.pickling_support
import zmq.asyncio

# from scaler.utility.logging.utility import setup_logger
from scaler.io.async_binder import AsyncBinder
@@ -32,8 +33,8 @@
class VanillaProcessorManager(Looper, ProcessorManager):
def __init__(
self,
context: zmq.asyncio.Context,
event_loop: str,
io_threads: int,
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
hard_processor_suspend: bool,
@@ -43,7 +44,6 @@ def __init__(
tblib.pickling_support.install()

self._event_loop = event_loop
self._io_threads = io_threads

self._garbage_collect_interval_seconds = garbage_collect_interval_seconds
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
@@ -67,7 +67,7 @@ def __init__(
self._task_active_lock: asyncio.Lock = asyncio.Lock()

self._binder_internal: AsyncBinder = AsyncBinder(
name="processor_manager", address=self._address, io_threads=self._io_threads, identity=None
context=context, name="processor_manager", address=self._address, identity=None
)
self._binder_internal.register(self.__on_receive_internal)

6 changes: 4 additions & 2 deletions scaler/worker/worker.py
@@ -63,6 +63,7 @@ def __init__(
self._logging_paths = logging_paths
self._logging_level = logging_level

self._context: Optional[zmq.asyncio.Context] = None
self._connector_external: Optional[AsyncConnector] = None
self._task_manager: Optional[VanillaTaskManager] = None
self._heartbeat_manager: Optional[VanillaHeartbeatManager] = None
@@ -81,8 +82,9 @@ def __initialize(self):
setup_logger()
register_event_loop(self._event_loop)

self._context = zmq.asyncio.Context()
self._connector_external = AsyncConnector(
context=zmq.asyncio.Context(),
context=self._context,
name=self.name,
socket_type=zmq.DEALER,
address=self._address,
@@ -96,8 +98,8 @@
self._task_manager = VanillaTaskManager(task_timeout_seconds=self._task_timeout_seconds)
self._timeout_manager = VanillaTimeoutManager(death_timeout_seconds=self._death_timeout_seconds)
self._processor_manager = VanillaProcessorManager(
context=self._context,
event_loop=self._event_loop,
io_threads=self._io_threads,
garbage_collect_interval_seconds=self._garbage_collect_interval_seconds,
trim_memory_threshold_bytes=self._trim_memory_threshold_bytes,
hard_processor_suspend=self._hard_processor_suspend,
6 changes: 6 additions & 0 deletions tests/test_async_indexed_queue.py
@@ -1,10 +1,16 @@
import asyncio
import unittest

from scaler.utility.logging.utility import setup_logger
from scaler.utility.queues.async_indexed_queue import AsyncIndexedQueue
from tests.utility import logging_test_name


class TestAsyncIndexedQueue(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)

def test_async_indexed_queue(self):
async def async_test():
queue = AsyncIndexedQueue()
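The `logging_test_name` helper imported by these test modules is not visible in the truncated diff; the following is only a guess at what a helper matching the commit note 'Add test name print out' might look like, with the implementation entirely assumed:

```python
import logging
import unittest


def logging_test_name(test_case: unittest.TestCase) -> None:
    # Hypothetical sketch: log which test is about to run so interleaved
    # scheduler/worker output can be attributed to a specific test.
    logging.info(f"running test: {test_case.id()}")
```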
6 changes: 6 additions & 0 deletions tests/test_async_priority_queue.py
@@ -1,10 +1,16 @@
import asyncio
import unittest

from scaler.utility.logging.utility import setup_logger
from scaler.utility.queues.async_priority_queue import AsyncPriorityQueue
from tests.utility import logging_test_name


class TestAsyncPriorityQueue(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)

def test_async_priority_queue(self):
async def async_test():
queue = AsyncPriorityQueue()
6 changes: 6 additions & 0 deletions tests/test_async_sorted_priority_queue.py
@@ -1,10 +1,16 @@
import asyncio
import unittest

from scaler.utility.logging.utility import setup_logger
from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue
from tests.utility import logging_test_name


class TestSortedPriorityQueue(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)

def test_sorted_priority_queue(self):
async def async_test():
queue = AsyncSortedPriorityQueue()
7 changes: 6 additions & 1 deletion tests/test_balance.py
@@ -3,7 +3,8 @@
import unittest

from scaler import Client, Cluster, SchedulerClusterCombo
from tests.utility import get_available_tcp_port
from scaler.utility.logging.utility import setup_logger
from tests.utility import get_available_tcp_port, logging_test_name


def sleep_and_return_pid(sec: int):
@@ -12,6 +13,10 @@ def sleep_and_return_pid(sec: int):


class TestBalance(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)

def test_balance(self):
"""
Schedules a few long-lasting tasks to a single process cluster, then adds workers. We expect the remaining tasks
3 changes: 2 additions & 1 deletion tests/test_client.py
@@ -8,7 +8,7 @@
from scaler.utility.exceptions import ProcessorDiedError
from scaler.utility.logging.scoped_logger import ScopedLogger
from scaler.utility.logging.utility import setup_logger
from tests.utility import get_available_tcp_port
from tests.utility import get_available_tcp_port, logging_test_name


def noop(sec: int):
@@ -32,6 +32,7 @@ def raise_exception(foo: int):
class TestClient(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)
self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}"
self._workers = 3
self.cluster = SchedulerClusterCombo(address=self.address, n_workers=self._workers, event_loop="builtin")
18 changes: 10 additions & 8 deletions tests/test_death_timeout.py
@@ -12,14 +12,16 @@
)
from scaler.utility.logging.utility import setup_logger
from scaler.utility.zmq_config import ZMQConfig
from tests.utility import get_available_tcp_port
from tests.utility import get_available_tcp_port, logging_test_name


# This is a manual test because it can loop infinitely if it fails


class TestDeathTimeout(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)

def test_no_scheduler(self):
logging.info("test with no scheduler")
@@ -61,16 +63,16 @@ def test_shutdown(self):

def test_no_timeout_if_suspended(self):
"""
Client and scheduler shouldn't timeout a client if it is running inside a suspended processor.
Client and scheduler shouldn't time out a client if it is running inside a suspended processor.
"""

CLIENT_TIMEOUT_SECONDS = 3
client_timeout_seconds = 3

def parent(client: Client):
return client.submit(child).result()
def parent(c: Client):
return c.submit(child).result()

def child():
time.sleep(CLIENT_TIMEOUT_SECONDS + 1) # prevents the parent task to execute.
time.sleep(client_timeout_seconds + 1) # prevents the parent task to execute.
return "OK"

address = f"tcp://127.0.0.1:{get_available_tcp_port()}"
@@ -79,11 +81,11 @@ def child():
n_workers=1,
per_worker_queue_size=2,
event_loop="builtin",
client_timeout_seconds=CLIENT_TIMEOUT_SECONDS,
client_timeout_seconds=client_timeout_seconds,
)

try:
with Client(address, timeout_seconds=CLIENT_TIMEOUT_SECONDS) as client:
with Client(address, timeout_seconds=client_timeout_seconds) as client:
future = client.submit(parent, client)
self.assertEqual(future.result(), "OK")
finally: