From f3b6cf44c389248c96cce745d1fa76613923f24d Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Wed, 16 Oct 2024 10:31:04 -0400 Subject: [PATCH] Get rid of copy memory when receive and send message (#34) * Get rid of copy memory when receive and send Signed-off-by: Sharpner6 <1sc2l4qi@duck.com> * Reduce number of threads and use shared context Signed-off-by: Sharpner6 <1sc2l4qi@duck.com> * Add test name print out Signed-off-by: Sharpner6 <1sc2l4qi@duck.com> --------- Signed-off-by: Sharpner6 <1sc2l4qi@duck.com> --- README.md | 2 +- scaler/about.py | 2 +- scaler/client/agent/client_agent.py | 6 +----- scaler/client/client.py | 8 ++++---- scaler/io/async_binder.py | 15 ++++++++------- scaler/io/async_connector.py | 6 +++--- scaler/io/sync_connector.py | 6 +++--- scaler/io/sync_subscriber.py | 2 +- scaler/scheduler/scheduler.py | 5 +++-- scaler/worker/agent/processor_manager.py | 6 +++--- scaler/worker/worker.py | 6 ++++-- tests/test_async_indexed_queue.py | 6 ++++++ tests/test_async_priority_queue.py | 6 ++++++ tests/test_async_sorted_priority_queue.py | 6 ++++++ tests/test_balance.py | 7 ++++++- tests/test_client.py | 3 ++- tests/test_death_timeout.py | 18 ++++++++++-------- tests/test_future.py | 3 ++- tests/test_graph.py | 3 ++- tests/test_indexed_queue.py | 6 ++++++ tests/test_nested_task.py | 2 ++ tests/test_object_usage.py | 5 ++++- tests/test_profiling.py | 5 ++++- tests/test_protected.py | 8 ++++++-- tests/test_serializer.py | 3 ++- tests/test_ui.py | 4 ++-- tests/test_worker_object_tracker.py | 6 ++++++ tests/utility.py | 6 ++++++ 28 files changed, 110 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index ad96070..4dc1b8f 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ PyPI - Version - +

diff --git a/scaler/about.py b/scaler/about.py index cd12fb0..490fbd4 100644 --- a/scaler/about.py +++ b/scaler/about.py @@ -1 +1 @@ -__version__ = "1.8.8" +__version__ = "1.8.9" diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index 6bf6507..7fb3294 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -58,8 +58,6 @@ def __init__( self._future_manager = future_manager - self._external_context = zmq.asyncio.Context() - self._connector_internal = AsyncConnector( context=zmq.asyncio.Context.shadow(self._context), name="client_agent_internal", @@ -70,7 +68,7 @@ def __init__( identity=None, ) self._connector_external = AsyncConnector( - context=self._external_context, + context=zmq.asyncio.Context.shadow(self._context), name="client_agent_external", socket_type=zmq.DEALER, address=self._scheduler_address, @@ -183,8 +181,6 @@ async def __get_loops(self): self._connector_external.destroy() self._connector_internal.destroy() - self._external_context.destroy() - if exception is None: return diff --git a/scaler/client/client.py b/scaler/client/client.py index 100a558..8fc6071 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -89,9 +89,9 @@ def __initialize__( self._heartbeat_interval_seconds = heartbeat_interval_seconds self._stop_event = threading.Event() - self._internal_context = zmq.Context() + self._context = zmq.Context() self._connector = SyncConnector( - context=self._internal_context, + context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity, @@ -102,7 +102,7 @@ def __initialize__( identity=self._identity, client_agent_address=self._client_agent_address, scheduler_address=ZMQConfig.from_string(address), - context=self._internal_context, + context=self._context, future_manager=self._future_manager, stop_event=self._stop_event, timeout_seconds=self._timeout_seconds, @@ -539,7 +539,7 @@ def __assert_client_not_stopped(self): def __destroy(self): self._agent.join() - self._internal_context.destroy(linger=1) + self._context.destroy(linger=1) @staticmethod def __get_parent_task_priority() -> Optional[int]: diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py index 68d5e1a..80e1b93 100644 --- a/scaler/io/async_binder.py +++ b/scaler/io/async_binder.py @@ -5,6 +5,7 @@ from typing import Awaitable, Callable, List, Optional, Dict import zmq.asyncio +from zmq import Frame from scaler.io.utility import deserialize, serialize from scaler.protocol.python.mixins import Message @@ -14,14 +15,14 @@ class AsyncBinder(Looper, Reporter): - def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Optional[bytes] = None): + def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig, identity: Optional[bytes] = None): self._address = address if identity is None: identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode() self._identity = identity - self._context = zmq.asyncio.Context(io_threads=io_threads) + self._context = context self._socket = self._context.socket(zmq.ROUTER) self.__set_socket_options() self._socket.bind(self._address.to_address()) @@ -38,18 +39,18 @@ def register(self, callback: Callable[[bytes, Message], Awaitable[None]]): self._callback = callback async def routine(self): - frames = await self._socket.recv_multipart() + frames: List[Frame] = await self._socket.recv_multipart(copy=False) if not self.__is_valid_message(frames): return source, payload = frames - message: Optional[Message] = deserialize(payload) + message: Optional[Message] = deserialize(payload.bytes) if message is None: - logging.error(f"received unknown message from {source!r}: {payload!r}") + logging.error(f"received unknown message from {source.bytes!r}: {payload!r}") return self.__count_received(message.__class__.__name__) - await self._callback(source, message) + await self._callback(source.bytes, message) async def send(self, to: bytes, message: Message): self.__count_sent(message.__class__.__name__) @@ -63,7 +64,7 @@ def __set_socket_options(self): self._socket.setsockopt(zmq.SNDHWM, 0) self._socket.setsockopt(zmq.RCVHWM, 0) - def __is_valid_message(self, frames: List[bytes]) -> bool: + def __is_valid_message(self, frames: List[Frame]) -> bool: if len(frames) < 2: logging.error(f"{self.__get_prefix()} received unexpected frames {frames}") return False diff --git a/scaler/io/async_connector.py b/scaler/io/async_connector.py index 86be509..a9a02c4 100644 --- a/scaler/io/async_connector.py +++ b/scaler/io/async_connector.py @@ -82,10 +82,10 @@ async def receive(self) -> Optional[Message]: if self._socket.closed: return None - payload = await self._socket.recv() - result: Optional[Message] = deserialize(payload) + payload = await self._socket.recv(copy=False) + result: Optional[Message] = deserialize(payload.bytes) if result is None: - logging.error(f"received unknown message: {payload!r}") + logging.error(f"received unknown message: {payload.bytes!r}") return None return result diff --git a/scaler/io/sync_connector.py b/scaler/io/sync_connector.py index a45b6a5..b3c9cb8 100644 --- a/scaler/io/sync_connector.py +++ b/scaler/io/sync_connector.py @@ -47,13 +47,13 @@ def identity(self) -> bytes: def send(self, message: Message): with self._lock: - self._socket.send(serialize(message)) + self._socket.send(serialize(message), copy=False) def receive(self) -> Optional[Message]: with self._lock: - payload = self._socket.recv() + payload = self._socket.recv(copy=False) - return self.__compose_message(payload) + return self.__compose_message(payload.bytes) def __compose_message(self, payload: bytes) -> Optional[Message]: result: Optional[Message] = deserialize(payload) diff --git a/scaler/io/sync_subscriber.py b/scaler/io/sync_subscriber.py index 5423f2c..17bc5b7 100644 --- a/scaler/io/sync_subscriber.py +++ b/scaler/io/sync_subscriber.py @@ -69,7 +69,7 @@ def __initialize(self): def __routine_polling(self): try: - self.__routine_receive(self._socket.recv()) + self.__routine_receive(self._socket.recv(copy=False).bytes) except zmq.Again: raise TimeoutError(f"Cannot connect to {self._address.to_address()} in {self._timeout_seconds} seconds") diff --git a/scaler/scheduler/scheduler.py b/scaler/scheduler/scheduler.py index e0fe274..2d378d1 100644 --- a/scaler/scheduler/scheduler.py +++ b/scaler/scheduler/scheduler.py @@ -45,9 +45,10 @@ def __init__(self, config: SchedulerConfig): ) logging.info(f"{self.__class__.__name__}: monitor address is {self._address_monitor.to_address()}") - self._binder = AsyncBinder(name="scheduler", address=config.address, io_threads=config.io_threads) + self._context = zmq.asyncio.Context(io_threads=config.io_threads) + self._binder = AsyncBinder(context=self._context, name="scheduler", address=config.address) self._binder_monitor = AsyncConnector( - context=zmq.asyncio.Context(), + context=self._context, name="scheduler_monitor", socket_type=zmq.PUB, address=self._address_monitor, diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index 5653df0..3d467aa 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple import tblib.pickling_support +import zmq.asyncio # from scaler.utility.logging.utility import setup_logger from scaler.io.async_binder import AsyncBinder @@ -32,8 +33,8 @@ class VanillaProcessorManager(Looper, ProcessorManager): def __init__( self, + context: zmq.asyncio.Context, event_loop: str, - io_threads: int, garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, hard_processor_suspend: bool, @@ -43,7 +44,6 @@ def __init__( tblib.pickling_support.install() self._event_loop = event_loop - self._io_threads = io_threads self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes @@ -67,7 +67,7 @@ def __init__( self._task_active_lock: asyncio.Lock = asyncio.Lock() self._binder_internal: AsyncBinder = AsyncBinder( - name="processor_manager", address=self._address, io_threads=self._io_threads, identity=None + context=context, name="processor_manager", address=self._address, identity=None ) self._binder_internal.register(self.__on_receive_internal) diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index e9ec4a7..c594536 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -63,6 +63,7 @@ def __init__( self._logging_paths = logging_paths self._logging_level = logging_level + self._context: Optional[zmq.asyncio.Context] = None self._connector_external: Optional[AsyncConnector] = None self._task_manager: Optional[VanillaTaskManager] = None self._heartbeat_manager: Optional[VanillaHeartbeatManager] = None @@ -81,8 +82,9 @@ def __initialize(self): setup_logger() register_event_loop(self._event_loop) + self._context = zmq.asyncio.Context() self._connector_external = AsyncConnector( - context=zmq.asyncio.Context(), + context=self._context, name=self.name, socket_type=zmq.DEALER, address=self._address, @@ -96,8 +98,8 @@ def __initialize(self): self._task_manager = VanillaTaskManager(task_timeout_seconds=self._task_timeout_seconds) self._timeout_manager = VanillaTimeoutManager(death_timeout_seconds=self._death_timeout_seconds) self._processor_manager = VanillaProcessorManager( + context=self._context, event_loop=self._event_loop, - io_threads=self._io_threads, garbage_collect_interval_seconds=self._garbage_collect_interval_seconds, trim_memory_threshold_bytes=self._trim_memory_threshold_bytes, hard_processor_suspend=self._hard_processor_suspend, diff --git a/tests/test_async_indexed_queue.py b/tests/test_async_indexed_queue.py index 63ab638..2f53b2d 100644 --- a/tests/test_async_indexed_queue.py +++ b/tests/test_async_indexed_queue.py @@ -1,10 +1,16 @@ import asyncio import unittest +from scaler.utility.logging.utility import setup_logger from scaler.utility.queues.async_indexed_queue import AsyncIndexedQueue +from tests.utility import logging_test_name class TestAsyncIndexedQueue(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_async_indexed_queue(self): async def async_test(): queue = AsyncIndexedQueue() diff --git a/tests/test_async_priority_queue.py b/tests/test_async_priority_queue.py index 42aa5cb..7431b2e 100644 --- a/tests/test_async_priority_queue.py +++ b/tests/test_async_priority_queue.py @@ -1,10 +1,16 @@ import asyncio import unittest +from scaler.utility.logging.utility import setup_logger from scaler.utility.queues.async_priority_queue import AsyncPriorityQueue +from tests.utility import logging_test_name class TestAsyncPriorityQueue(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_async_priority_queue(self): async def async_test(): queue = AsyncPriorityQueue() diff --git a/tests/test_async_sorted_priority_queue.py b/tests/test_async_sorted_priority_queue.py index 34c42ba..e1cb4d0 100644 --- a/tests/test_async_sorted_priority_queue.py +++ b/tests/test_async_sorted_priority_queue.py @@ -1,10 +1,16 @@ import asyncio import unittest +from scaler.utility.logging.utility import setup_logger from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue +from tests.utility import logging_test_name class TestSortedPriorityQueue(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_sorted_priority_queue(self): async def async_test(): queue = AsyncSortedPriorityQueue() diff --git a/tests/test_balance.py b/tests/test_balance.py index 339c8e0..ae7eff9 100644 --- a/tests/test_balance.py +++ b/tests/test_balance.py @@ -3,7 +3,8 @@ import unittest from scaler import Client, Cluster, SchedulerClusterCombo -from tests.utility import get_available_tcp_port +from scaler.utility.logging.utility import setup_logger +from tests.utility import get_available_tcp_port, logging_test_name def sleep_and_return_pid(sec: int): @@ -12,6 +13,10 @@ def sleep_and_return_pid(sec: int): class TestBalance(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_balance(self): """ Schedules a few long-lasting tasks to a single process cluster, then adds workers. We expect the remaining tasks diff --git a/tests/test_client.py b/tests/test_client.py index 678a744..fadd817 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from scaler.utility.exceptions import ProcessorDiedError from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name def noop(sec: int): @@ -32,6 +32,7 @@ def raise_exception(foo: int): class TestClient(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self._workers = 3 self.cluster = SchedulerClusterCombo(address=self.address, n_workers=self._workers, event_loop="builtin") diff --git a/tests/test_death_timeout.py b/tests/test_death_timeout.py index 39cdcb3..58273d9 100644 --- a/tests/test_death_timeout.py +++ b/tests/test_death_timeout.py @@ -12,7 +12,8 @@ ) from scaler.utility.logging.utility import setup_logger from scaler.utility.zmq_config import ZMQConfig -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name + # This is a manual test because it can loop infinitely if it fails @@ -20,6 +21,7 @@ class TestDeathTimeout(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) def test_no_scheduler(self): logging.info("test with no scheduler") @@ -61,16 +63,16 @@ def test_shutdown(self): def test_no_timeout_if_suspended(self): """ - Client and scheduler shouldn't timeout a client if it is running inside a suspended processor. + Client and scheduler shouldn't time out a client if it is running inside a suspended processor. """ - CLIENT_TIMEOUT_SECONDS = 3 + client_timeout_seconds = 3 - def parent(client: Client): - return client.submit(child).result() + def parent(c: Client): + return c.submit(child).result() def child(): - time.sleep(CLIENT_TIMEOUT_SECONDS + 1) # prevents the parent task to execute. + time.sleep(client_timeout_seconds + 1) # prevents the parent task to execute. return "OK" address = f"tcp://127.0.0.1:{get_available_tcp_port()}" @@ -79,11 +81,11 @@ def child(): n_workers=1, per_worker_queue_size=2, event_loop="builtin", - client_timeout_seconds=CLIENT_TIMEOUT_SECONDS, + client_timeout_seconds=client_timeout_seconds, ) try: - with Client(address, timeout_seconds=CLIENT_TIMEOUT_SECONDS) as client: + with Client(address, timeout_seconds=client_timeout_seconds) as client: future = client.submit(parent, client) self.assertEqual(future.result(), "OK") finally: diff --git a/tests/test_future.py b/tests/test_future.py index a6f6666..011f115 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -6,7 +6,7 @@ from scaler import Client, SchedulerClusterCombo from scaler.utility.logging.utility import setup_logger -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name def noop_sleep(sec: int): @@ -16,6 +16,7 @@ def noop_sleep(sec: int): class TestFuture(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self._workers = 3 self.cluster = SchedulerClusterCombo(address=self.address, n_workers=self._workers, event_loop="builtin") diff --git a/tests/test_graph.py b/tests/test_graph.py index 3d22b41..8b50969 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -6,7 +6,7 @@ from scaler.utility.graph.optimization import cull_graph from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name def inc(i): @@ -24,6 +24,7 @@ def minus(a, b): class TestGraph(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self.cluster = SchedulerClusterCombo(address=self.address, n_workers=3, event_loop="builtin") diff --git a/tests/test_indexed_queue.py b/tests/test_indexed_queue.py index 146b228..0460fb5 100644 --- a/tests/test_indexed_queue.py +++ b/tests/test_indexed_queue.py @@ -1,9 +1,15 @@ import unittest +from scaler.utility.logging.utility import setup_logger from scaler.utility.queues.indexed_queue import IndexedQueue +from tests.utility import logging_test_name class TestIndexedQueue(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_indexed_queue(self): queue = IndexedQueue() queue.put(1) diff --git a/tests/test_nested_task.py b/tests/test_nested_task.py index c40fb7c..e675027 100644 --- a/tests/test_nested_task.py +++ b/tests/test_nested_task.py @@ -2,6 +2,7 @@ from scaler import Client, SchedulerClusterCombo from scaler.utility.logging.utility import setup_logger +from tests.utility import logging_test_name N_TASKS = 30 N_WORKERS = 3 @@ -11,6 +12,7 @@ class TestNestedTask(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = "tcp://127.0.0.1:23456" self.cluster = SchedulerClusterCombo(address=self.address, n_workers=N_WORKERS, event_loop="builtin") diff --git a/tests/test_object_usage.py b/tests/test_object_usage.py index ca05b5d..a062e9d 100644 --- a/tests/test_object_usage.py +++ b/tests/test_object_usage.py @@ -3,6 +3,7 @@ from scaler.scheduler.object_usage.object_tracker import ObjectTracker, ObjectUsage from scaler.utility.logging.utility import setup_logger +from tests.utility import logging_test_name @dataclasses.dataclass @@ -19,9 +20,11 @@ def sample_ready(obj: Sample): class TestObjectUsage(unittest.TestCase): - def test_object_usage(self): + def setUp(self) -> None: setup_logger() + logging_test_name(self) + def test_object_usage(self): object_usage: ObjectTracker[str, Sample] = ObjectTracker("sample", sample_ready) object_usage.add_object(Sample("a", "value1")) diff --git a/tests/test_profiling.py b/tests/test_profiling.py index f12adc8..62fe9a0 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -2,7 +2,8 @@ import unittest from scaler import Client, SchedulerClusterCombo -from tests.utility import get_available_tcp_port +from scaler.utility.logging.utility import setup_logger +from tests.utility import get_available_tcp_port, logging_test_name def dummy(n: int): @@ -20,6 +21,8 @@ def busy_dummy(n: int): class TestProfiling(unittest.TestCase): def setUp(self): + setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self.cluster = SchedulerClusterCombo( address=self.address, n_workers=2, per_worker_queue_size=2, event_loop="builtin" diff --git a/tests/test_protected.py b/tests/test_protected.py index 941f86b..3f785ce 100644 --- a/tests/test_protected.py +++ b/tests/test_protected.py @@ -2,11 +2,15 @@ import unittest from scaler import Client, SchedulerClusterCombo - -from tests.utility import get_available_tcp_port +from scaler.utility.logging.utility import setup_logger +from tests.utility import get_available_tcp_port, logging_test_name class TestProtected(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_protected_true(self) -> None: address = f"tcp://127.0.0.1:{get_available_tcp_port()}" cluster = SchedulerClusterCombo( diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 1b26e12..d9ce9e4 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -9,7 +9,7 @@ from scaler import Client, SchedulerClusterCombo, Serializer from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name def noop(sec: int): @@ -52,6 +52,7 @@ def trim_message_internal(message: Any) -> str: class TestSerializer(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self._workers = 3 self.cluster = SchedulerClusterCombo(address=self.address, n_workers=self._workers, event_loop="builtin") diff --git a/tests/test_ui.py b/tests/test_ui.py index c1456b9..233f98c 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -5,8 +5,7 @@ from scaler import Client, SchedulerClusterCombo from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger - -from tests.utility import get_available_tcp_port +from tests.utility import get_available_tcp_port, logging_test_name def noop(sec: int): @@ -30,6 +29,7 @@ def noop_memory(length: int): class TestUI(unittest.TestCase): def setUp(self) -> None: setup_logger() + logging_test_name(self) self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}" self._workers = 10 self.cluster = SchedulerClusterCombo(address=self.address, n_workers=self._workers, event_loop="builtin") diff --git a/tests/test_worker_object_tracker.py b/tests/test_worker_object_tracker.py index 72e6509..ef18eed 100644 --- a/tests/test_worker_object_tracker.py +++ b/tests/test_worker_object_tracker.py @@ -2,10 +2,16 @@ from scaler.protocol.python.common import ObjectContent from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse +from scaler.utility.logging.utility import setup_logger from scaler.worker.agent.object_tracker import VanillaObjectTracker +from tests.utility import logging_test_name class TestWorkerObjectTracker(unittest.TestCase): + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + def test_object_tracker(self) -> None: tracker = VanillaObjectTracker() diff --git a/tests/utility.py b/tests/utility.py index 74d6d2f..37d0650 100644 --- a/tests/utility.py +++ b/tests/utility.py @@ -1,7 +1,13 @@ +import logging import socket +import unittest def get_available_tcp_port(hostname: str = "127.0.0.1") -> int: with socket.socket(socket.AddressFamily.AF_INET, socket.SocketKind.SOCK_STREAM) as sock: sock.bind((hostname, 0)) return sock.getsockname()[1] + + +def logging_test_name(obj: unittest.TestCase): + logging.info(f"{obj.__class__.__name__}:{obj._testMethodName} ==============================================")