diff --git a/scaler/about.py b/scaler/about.py
index d3d41ce..d48d2bf 100644
--- a/scaler/about.py
+++ b/scaler/about.py
@@ -1 +1 @@
-__version__ = "1.8.10"
+__version__ = "1.8.11"
diff --git a/scaler/client/agent/future_manager.py b/scaler/client/agent/future_manager.py
index 154fcd5..d9733ec 100644
--- a/scaler/client/agent/future_manager.py
+++ b/scaler/client/agent/future_manager.py
@@ -1,11 +1,12 @@
 import logging
 import threading
-from concurrent.futures import InvalidStateError, Future
+from concurrent.futures import Future, InvalidStateError
 from typing import Dict, Tuple

 from scaler.client.agent.mixins import FutureManager
 from scaler.client.future import ScalerFuture
 from scaler.client.serializer.mixins import Serializer
+from scaler.io.utility import concat_list_of_bytes
 from scaler.protocol.python.common import TaskStatus
 from scaler.protocol.python.message import ObjectResponse, TaskResult
 from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
@@ -106,9 +107,9 @@ def on_object_response(self, response: ObjectResponse):
             try:
                 if status == TaskStatus.Success:
-                    future.set_result(self._serializer.deserialize(object_bytes))
+                    future.set_result(self._serializer.deserialize(concat_list_of_bytes(object_bytes)))
                 elif status == TaskStatus.Failed:
-                    future.set_exception(deserialize_failure(object_bytes))
+                    future.set_exception(deserialize_failure(concat_list_of_bytes(object_bytes)))
             except InvalidStateError:
                 continue  # future got canceled
diff --git a/scaler/client/client.py b/scaler/client/client.py
index 8fc6071..e5c9899 100644
--- a/scaler/client/client.py
+++ b/scaler/client/client.py
@@ -91,10 +91,7 @@ def __initialize__(
         self._stop_event = threading.Event()
         self._context = zmq.Context()
         self._connector = SyncConnector(
-            context=self._context,
-            socket_type=zmq.PAIR,
-            address=self._client_agent_address,
-            identity=self._identity,
+            context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
         )

         self._future_manager = ClientFutureManager(self._serializer)
@@ -309,7 +306,7 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference:
         self.__assert_client_not_stopped()

         cache = self._object_buffer.buffer_send_object(obj, name)
-        return ObjectReference(cache.object_name, cache.object_id, len(cache.object_bytes))
+        return ObjectReference(cache.object_name, cache.object_id, sum(map(len, cache.object_bytes)))

     def disconnect(self):
         """
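The client-side contract these two changes share: send_object now reports the payload size as the sum of the chunk lengths, and the future manager joins the same chunks back together before deserializing. A minimal round-trip sketch (plain Python, using the chunk_to_list_of_bytes/concat_list_of_bytes helpers added to scaler/io/utility.py later in this patch):

    from scaler.io.utility import chunk_to_list_of_bytes, concat_list_of_bytes

    payload = b"x" * (2**29 + 300)  # larger than one cap'n proto Data blob, as in the new test below
    chunks = chunk_to_list_of_bytes(payload)

    assert sum(map(len, chunks)) == len(payload)    # the size ObjectReference now reports
    assert concat_list_of_bytes(chunks) == payload  # what the future manager hands to the serializer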
b"serializer", chunk_to_list_of_bytes(serializer_bytes)) def __construct_function(self, fn: Callable) -> ObjectCache: function_bytes = self._serializer.serialize(fn) object_id = generate_object_id(self._identity, function_bytes) function_cache = ObjectCache( - object_id, getattr(fn, "__name__", f"").encode(), function_bytes + object_id, + getattr(fn, "__name__", f"").encode(), + chunk_to_list_of_bytes(function_bytes), ) return function_cache @@ -97,4 +100,4 @@ def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCach object_payload = self._serializer.serialize(obj) object_id = generate_object_id(self._identity, object_payload) name_bytes = name.encode() if name else f"".encode() - return ObjectCache(object_id, name_bytes, object_payload) + return ObjectCache(object_id, name_bytes, chunk_to_list_of_bytes(object_payload)) diff --git a/scaler/cluster/combo.py b/scaler/cluster/combo.py index d7af8ec..858722b 100644 --- a/scaler/cluster/combo.py +++ b/scaler/cluster/combo.py @@ -7,6 +7,7 @@ from scaler.io.config import ( DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS, + DEFAULT_HARD_PROCESSOR_SUSPEND, DEFAULT_HEARTBEAT_INTERVAL_SECONDS, DEFAULT_IO_THREADS, DEFAULT_LOAD_BALANCE_SECONDS, @@ -18,7 +19,6 @@ DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, DEFAULT_WORKER_DEATH_TIMEOUT, DEFAULT_WORKER_TIMEOUT_SECONDS, - DEFAULT_HARD_PROCESSOR_SUSPEND, ) from scaler.utility.zmq_config import ZMQConfig diff --git a/scaler/cluster/scheduler.py b/scaler/cluster/scheduler.py index 347721a..b770160 100644 --- a/scaler/cluster/scheduler.py +++ b/scaler/cluster/scheduler.py @@ -1,7 +1,7 @@ import asyncio import multiprocessing from asyncio import AbstractEventLoop, Task -from typing import Optional, Tuple, Any +from typing import Any, Optional, Tuple from scaler.scheduler.config import SchedulerConfig from scaler.scheduler.scheduler import Scheduler, scheduler_main diff --git a/scaler/entry_points/cluster.py b/scaler/entry_points/cluster.py index 581465a..8814003 100644 --- a/scaler/entry_points/cluster.py +++ b/scaler/entry_points/cluster.py @@ -4,13 +4,13 @@ from scaler.cluster.cluster import Cluster from scaler.io.config import ( DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS, + DEFAULT_HARD_PROCESSOR_SUSPEND, DEFAULT_HEARTBEAT_INTERVAL_SECONDS, DEFAULT_IO_THREADS, DEFAULT_NUMBER_OF_WORKER, DEFAULT_TASK_TIMEOUT_SECONDS, DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, DEFAULT_WORKER_DEATH_TIMEOUT, - DEFAULT_HARD_PROCESSOR_SUSPEND, ) from scaler.utility.event_loop import EventLoopType, register_event_loop from scaler.utility.zmq_config import ZMQConfig diff --git a/scaler/entry_points/top.py b/scaler/entry_points/top.py index c8c598e..f47970a 100644 --- a/scaler/entry_points/top.py +++ b/scaler/entry_points/top.py @@ -1,7 +1,7 @@ import argparse import curses import functools -from typing import List, Literal, Dict, Union +from typing import Dict, List, Literal, Union from scaler.io.sync_subscriber import SyncSubscriber from scaler.protocol.python.message import StateScheduler diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py index 80e1b93..bb2e192 100644 --- a/scaler/io/async_binder.py +++ b/scaler/io/async_binder.py @@ -2,7 +2,7 @@ import os import uuid from collections import defaultdict -from typing import Awaitable, Callable, List, Optional, Dict +from typing import Awaitable, Callable, Dict, List, Optional import zmq.asyncio from zmq import Frame diff --git a/scaler/io/config.py b/scaler/io/config.py index 3934444..7e22715 100644 --- a/scaler/io/config.py +++ 
diff --git a/scaler/io/config.py b/scaler/io/config.py
index 3934444..7e22715 100644
--- a/scaler/io/config.py
+++ b/scaler/io/config.py
@@ -12,8 +12,11 @@
 # number of seconds for profiling
 PROFILING_INTERVAL_SECONDS = 1

+# cap'n proto only allows a single Data/Text/Blob value to be at most 2**29 - 1 bytes (just under 512MB)
+CAPNP_DATA_SIZE_LIMIT = 2**29 - 1
+
 # message size limitation, max can be 2**64
-MESSAGE_SIZE_LIMIT = 2**64 - 1
+CAPNP_MESSAGE_SIZE_LIMIT = 2**64 - 1

 # ==========================
 # SCHEDULER SPECIFIC OPTIONS
diff --git a/scaler/io/utility.py b/scaler/io/utility.py
index 6465671..861841f 100644
--- a/scaler/io/utility.py
+++ b/scaler/io/utility.py
@@ -1,14 +1,14 @@
 import logging
-from typing import Optional
+from typing import List, Optional

-from scaler.io.config import MESSAGE_SIZE_LIMIT
+from scaler.io.config import CAPNP_DATA_SIZE_LIMIT, CAPNP_MESSAGE_SIZE_LIMIT
 from scaler.protocol.capnp._python import _message  # noqa
 from scaler.protocol.python.message import PROTOCOL
 from scaler.protocol.python.mixins import Message


 def deserialize(data: bytes) -> Optional[Message]:
-    with _message.Message.from_bytes(data, traversal_limit_in_words=MESSAGE_SIZE_LIMIT) as payload:
+    with _message.Message.from_bytes(data, traversal_limit_in_words=CAPNP_MESSAGE_SIZE_LIMIT) as payload:
         if not hasattr(payload, payload.which()):
             logging.error(f"unknown message type: {payload.which()}")
             return None
@@ -20,3 +20,15 @@ def deserialize(data: bytes) -> Optional[Message]:
 def serialize(message: Message) -> bytes:
     payload = _message.Message(**{PROTOCOL.inverse[type(message)]: message.get_message()})
     return payload.to_bytes()
+
+
+def chunk_to_list_of_bytes(data: bytes) -> List[bytes]:
+    # TODO: change to list of memoryview when capnp can support memoryview
+    return [data[i : i + CAPNP_DATA_SIZE_LIMIT] for i in range(0, len(data), CAPNP_DATA_SIZE_LIMIT)]
+
+
+def concat_list_of_bytes(data: List[bytes]) -> bytes:
+    one_object_bytes = bytearray()
+    for chunk in data:
+        one_object_bytes.extend(chunk)
+    return bytes(one_object_bytes)
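The two helpers above split a payload on CAPNP_DATA_SIZE_LIMIT boundaries and join it back losslessly; note that an empty payload chunks to an empty list. The same slicing logic with a toy limit, so the boundary behaviour is easy to see (the small limit here is illustrative only; the real code always uses CAPNP_DATA_SIZE_LIMIT):

    def chunk(data: bytes, limit: int):
        # identical slicing to chunk_to_list_of_bytes, shrunk for illustration
        return [data[i : i + limit] for i in range(0, len(data), limit)]

    assert chunk(b"abcdef", 4) == [b"abcd", b"ef"]     # last chunk may be short
    assert chunk(b"", 4) == []                         # empty payload -> no chunks
    assert b"".join(chunk(b"abcdef", 4)) == b"abcdef"  # lossless round trip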
diff --git a/scaler/protocol/capnp/_python.py b/scaler/protocol/capnp/_python.py
index 93d725f..6d77cae 100644
--- a/scaler/protocol/capnp/_python.py
+++ b/scaler/protocol/capnp/_python.py
@@ -1,4 +1,5 @@
 import capnp  # noqa
+
 import scaler.protocol.capnp.common_capnp as _common  # noqa
 import scaler.protocol.capnp.message_capnp as _message  # noqa
 import scaler.protocol.capnp.status_capnp as _status  # noqa
diff --git a/scaler/protocol/capnp/common.capnp b/scaler/protocol/capnp/common.capnp
index 3f08c37..628f626 100644
--- a/scaler/protocol/capnp/common.capnp
+++ b/scaler/protocol/capnp/common.capnp
@@ -18,5 +18,5 @@ enum TaskStatus {
 struct ObjectContent {
     objectIds @0 :List(Data);
     objectNames @1 :List(Data);
-    objectBytes @2 :List(Data);
+    objectBytes @2 :List(List(Data));
 }
diff --git a/scaler/protocol/python/common.py b/scaler/protocol/python/common.py
index 1e94e9c..e254ddf 100644
--- a/scaler/protocol/python/common.py
+++ b/scaler/protocol/python/common.py
@@ -1,6 +1,6 @@
 import dataclasses
 import enum
-from typing import Tuple
+from typing import List, Tuple

 from scaler.protocol.capnp._python import _common  # noqa
 from scaler.protocol.python.mixins import Message
@@ -26,7 +26,7 @@ class TaskStatus(enum.Enum):
 @dataclasses.dataclass
 class ObjectContent(Message):
     def __init__(self, msg):
-        self._msg = msg
+        super().__init__(msg)

     @property
     def object_ids(self) -> Tuple[bytes, ...]:
@@ -37,14 +37,14 @@ def object_names(self) -> Tuple[bytes, ...]:
         return tuple(self._msg.objectNames)

     @property
-    def object_bytes(self) -> Tuple[bytes, ...]:
+    def object_bytes(self) -> Tuple[List[bytes], ...]:
         return tuple(self._msg.objectBytes)

     @staticmethod
     def new_msg(
         object_ids: Tuple[bytes, ...],
         object_names: Tuple[bytes, ...] = tuple(),
-        object_bytes: Tuple[bytes, ...] = tuple(),
+        object_bytes: Tuple[List[bytes], ...] = tuple(),
     ) -> "ObjectContent":
         return ObjectContent(
             _common.ObjectContent(
diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py
index 43d55b9..e38fb80 100644
--- a/scaler/protocol/python/message.py
+++ b/scaler/protocol/python/message.py
@@ -1,19 +1,19 @@
 import dataclasses
 import enum
 import os
-from typing import List, Set, Tuple, Optional, Type
+from typing import List, Optional, Set, Tuple, Type

 import bidict

 from scaler.protocol.capnp._python import _message  # noqa
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.mixins import Message
 from scaler.protocol.python.status import (
     BinderStatus,
     ClientManagerStatus,
     ObjectManagerStatus,
-    Resource,
     ProcessorStatus,
+    Resource,
     TaskManagerStatus,
     WorkerManagerStatus,
 )
diff --git a/scaler/scheduler/mixins.py b/scaler/scheduler/mixins.py
index bb6af88..5169e63 100644
--- a/scaler/scheduler/mixins.py
+++ b/scaler/scheduler/mixins.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Optional, Set
+from typing import List, Optional, Set

 from scaler.protocol.python.message import (
     ClientDisconnect,
@@ -7,12 +7,12 @@
     DisconnectRequest,
     GraphTask,
     GraphTaskCancel,
+    ObjectInstruction,
     ObjectRequest,
     Task,
     TaskCancel,
     TaskResult,
     WorkerHeartbeat,
-    ObjectInstruction,
 )
 from scaler.utility.mixins import Reporter
@@ -27,7 +27,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):
         raise NotImplementedError()

     @abc.abstractmethod
-    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
+    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
         raise NotImplementedError()

     @abc.abstractmethod
@@ -47,7 +47,7 @@ def get_object_name(self, object_id: bytes) -> bytes:
         raise NotImplementedError()

     @abc.abstractmethod
-    def get_object_content(self, object_id: bytes) -> bytes:
+    def get_object_content(self, object_id: bytes) -> List[bytes]:
         raise NotImplementedError()
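With objectBytes now a List(List(Data)), every entry in ObjectContent.object_bytes is itself a list of chunks, so callers wrap each payload in chunk_to_list_of_bytes even when it would fit in a single blob. A construction sketch mirroring the updated callers in this patch:

    from scaler.io.utility import chunk_to_list_of_bytes
    from scaler.protocol.python.common import ObjectContent

    content = ObjectContent.new_msg(
        object_ids=(b"object_1",),
        object_names=(b"name_1",),
        object_bytes=(chunk_to_list_of_bytes(b"payload"),),  # one chunk list per object
    )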
diff --git a/scaler/scheduler/object_manager.py b/scaler/scheduler/object_manager.py
index 4953e22..0907e28 100644
--- a/scaler/scheduler/object_manager.py
+++ b/scaler/scheduler/object_manager.py
@@ -1,7 +1,7 @@
 import dataclasses
 import logging
 from asyncio import Queue
-from typing import Optional, Set
+from typing import List, Optional, Set

 from scaler.io.async_binder import AsyncBinder
 from scaler.io.async_connector import AsyncConnector
@@ -19,7 +19,7 @@ class _ObjectCreation(ObjectUsage):
     object_id: bytes
     object_creator: bytes
     object_name: bytes
-    object_bytes: bytes
+    object_bytes: List[bytes]

     def get_object_key(self) -> bytes:
         return self.object_id
@@ -71,7 +71,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):

         logging.error(f"received unknown object request type {request=} from {source=!r}")

-    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
+    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
         creation = _ObjectCreation(object_id, object_user, object_name, object_bytes)
         logging.debug(
             f"add object cache "
@@ -102,15 +102,16 @@ def get_object_name(self, object_id: bytes) -> bytes:
         return self._object_storage.get_object(object_id).object_name

-    def get_object_content(self, object_id: bytes) -> bytes:
+    def get_object_content(self, object_id: bytes) -> List[bytes]:
         if not self.has_object(object_id):
-            return b""
+            return list()

         return self._object_storage.get_object(object_id).object_bytes

     def get_status(self) -> ObjectManagerStatus:
         return ObjectManagerStatus.new_msg(
-            self._object_storage.object_count(), sum(len(v.object_bytes) for _, v in self._object_storage.items())
+            self._object_storage.object_count(),
+            sum(sum(map(len, v.object_bytes)) for _, v in self._object_storage.items()),
         )

     async def __process_get_request(self, source: bytes, request: ObjectRequest):
@@ -139,12 +140,12 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
             logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}")
             return

-        for object_id, object_name, object_content in zip(
+        for object_id, object_name, object_bytes in zip(
             instruction.object_content.object_ids,
             instruction.object_content.object_names,
             instruction.object_content.object_bytes,
         ):
-            self.on_add_object(instruction.object_user, object_id, object_name, object_content)
+            self.on_add_object(instruction.object_user, object_id, object_name, object_bytes)

     def __finished_object_storage(self, creation: _ObjectCreation):
         logging.debug(
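get_status now has to account for sizes at two levels: chunk lengths within each object, then a total across objects. A toy check of that nested sum (hypothetical stored chunks, not real scheduler state):

    stored = {b"id_1": [b"aa", b"b"], b"id_2": [b"cccc"]}  # 2 + 1 + 4 bytes
    assert sum(sum(map(len, chunks)) for chunks in stored.values()) == 7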
diff --git a/scaler/utility/graph/topological_sorter_graphblas.py b/scaler/utility/graph/topological_sorter_graphblas.py
index 717608f..60e39a8 100644
--- a/scaler/utility/graph/topological_sorter_graphblas.py
+++ b/scaler/utility/graph/topological_sorter_graphblas.py
@@ -1,7 +1,7 @@
 import collections
 import graphlib
 import itertools
-from typing import Hashable, Iterable, List, Optional, Tuple, TypeVar, Generic, Mapping
+from typing import Generic, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar

 from bidict import bidict
diff --git a/scaler/utility/queues/async_priority_queue.py b/scaler/utility/queues/async_priority_queue.py
index f02bf57..6f58ed4 100644
--- a/scaler/utility/queues/async_priority_queue.py
+++ b/scaler/utility/queues/async_priority_queue.py
@@ -2,7 +2,6 @@
 from asyncio import Queue
 from typing import Dict, List, Tuple, Union

-
 PriorityType = Union[int, Tuple["PriorityType", ...]]
diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py
index 04823f3..98b51be 100644
--- a/scaler/worker/agent/heartbeat_manager.py
+++ b/scaler/worker/agent/heartbeat_manager.py
@@ -4,7 +4,7 @@
 import psutil

 from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import WorkerHeartbeat, WorkerHeartbeatEcho, Resource
+from scaler.protocol.python.message import Resource, WorkerHeartbeat, WorkerHeartbeatEcho
 from scaler.protocol.python.status import ProcessorStatus
 from scaler.utility.mixins import Looper
 from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, TaskManager, TimeoutManager
diff --git a/scaler/worker/agent/processor/object_cache.py b/scaler/worker/agent/processor/object_cache.py
index ee8e42e..4cbbb8a 100644
--- a/scaler/worker/agent/processor/object_cache.py
+++ b/scaler/worker/agent/processor/object_cache.py
@@ -5,13 +5,14 @@
 import platform
 import threading
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import cloudpickle
 import psutil

 from scaler.client.serializer.mixins import Serializer
 from scaler.io.config import CLEANUP_INTERVAL_SECONDS
+from scaler.io.utility import concat_list_of_bytes
 from scaler.protocol.python.common import ObjectContent
 from scaler.protocol.python.message import Task
 from scaler.utility.exceptions import DeserializeObjectError
@@ -51,8 +52,8 @@ def add_serializer(self, client: bytes, serializer: Serializer):
     def serialize(self, client: bytes, obj: Any) -> bytes:
         return self.get_serializer(client).serialize(obj)

-    def deserialize(self, client: bytes, payload: bytes) -> Any:
-        return self.get_serializer(client).deserialize(payload)
+    def deserialize(self, client: bytes, payload: List[bytes]) -> Any:
+        return self.get_serializer(client).deserialize(concat_list_of_bytes(payload))

     def add_objects(self, object_content: ObjectContent, task: Task):
         zipped = list(zip(object_content.object_ids, object_content.object_names, object_content.object_bytes))
@@ -60,7 +61,7 @@
         others = filter(lambda o: not is_object_id_serializer(o[0]), zipped)

         for object_id, object_name, object_bytes in serializers:
-            self.add_serializer(object_id, cloudpickle.loads(object_bytes))
+            self.add_serializer(object_id, cloudpickle.loads(concat_list_of_bytes(object_bytes)))

         for object_id, object_name, object_bytes in others:
             try:
diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py
index 95aad7a..b331bee 100644
--- a/scaler/worker/agent/processor/processor.py
+++ b/scaler/worker/agent/processor/processor.py
@@ -13,7 +13,8 @@

 from scaler.io.config import DUMMY_CLIENT
 from scaler.io.sync_connector import SyncConnector
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.io.utility import chunk_to_list_of_bytes
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.message import (
     ObjectInstruction,
     ObjectRequest,
@@ -263,7 +264,9 @@ def __send_result(self, source: bytes, task_id: bytes, status: TaskStatus, resul
                 ObjectInstruction.ObjectInstructionType.Create,
                 source,
                 ObjectContent.new_msg(
-                    (result_object_id,), (f"".encode(),), (result_bytes,)
+                    (result_object_id,),
+                    (f"".encode(),),
+                    (chunk_to_list_of_bytes(result_bytes),),
                 ),
             )
         )
diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py
index 97a829d..48ea1c2 100644
--- a/scaler/worker/agent/processor_holder.py
+++ b/scaler/worker/agent/processor_holder.py
@@ -10,7 +10,7 @@
 from scaler.io.config import DEFAULT_PROCESSOR_KILL_DELAY_SECONDS
 from scaler.protocol.python.message import Task
 from scaler.utility.zmq_config import ZMQConfig
-from scaler.worker.agent.processor.processor import Processor, SUSPEND_SIGNAL
+from scaler.worker.agent.processor.processor import SUSPEND_SIGNAL, Processor


 class ProcessorHolder:
diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py
index 3d467aa..10e9fc8 100644
--- a/scaler/worker/agent/processor_manager.py
+++ b/scaler/worker/agent/processor_manager.py
@@ -11,7 +11,8 @@
 # from scaler.utility.logging.utility import setup_logger
 from scaler.io.async_binder import AsyncBinder
 from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.io.utility import chunk_to_list_of_bytes
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.message import (
     ObjectInstruction,
     ObjectRequest,
@@ -26,7 +27,7 @@
 from scaler.utility.mixins import Looper
 from scaler.utility.object_utility import generate_object_id, serialize_failure
 from scaler.utility.zmq_config import ZMQConfig, ZMQType
-from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, ProfilingManager, TaskManager, ObjectTracker
+from scaler.worker.agent.mixins import HeartbeatManager, ObjectTracker, ProcessorManager, ProfilingManager, TaskManager
 from scaler.worker.agent.processor_holder import ProcessorHolder
@@ -148,7 +149,7 @@ async def on_failing_task(self, process_status: str):
         profile_result = self.__end_task(self._current_holder)

-        result_object_bytes = serialize_failure(ProcessorDiedError(f"{process_status=}"))
+        result_object_bytes = chunk_to_list_of_bytes(serialize_failure(ProcessorDiedError(f"{process_status=}")))
         result_object_id = generate_object_id(source, uuid.uuid4().bytes)

         await self._connector_external.send(
diff --git a/tests/test_async_sorted_priority_queue.py b/tests/test_async_sorted_priority_queue.py
index e1cb4d0..d124f21 100644
--- a/tests/test_async_sorted_priority_queue.py
+++ b/tests/test_async_sorted_priority_queue.py
@@ -2,7 +2,8 @@
 import unittest

 from scaler.utility.logging.utility import setup_logger
-from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue
+from scaler.utility.queues.async_sorted_priority_queue import \
+    AsyncSortedPriorityQueue
 from tests.utility import logging_test_name
diff --git a/tests/test_client.py b/tests/test_client.py
index fadd817..02e75ac 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -115,6 +115,18 @@ def test_heavy_function(self):
         expected = [task * size for task in tasks]
         self.assertEqual(results, expected)

+    def test_very_large_payload(self):
+        def func(data: bytes):
+            return data
+
+        with Client(self.address) as client:
+            payload = os.urandom(2**29 + 300)  # 512MB + 300B
+            future = client.submit(func, payload)
+
+            result = future.result()
+
+            self.assertTrue(payload == result)
+
     def test_sleep(self):
         with Client(self.address) as client:
             time.sleep(5)
diff --git a/tests/test_death_timeout.py b/tests/test_death_timeout.py
index 58273d9..698cb4d 100644
--- a/tests/test_death_timeout.py
+++ b/tests/test_death_timeout.py
@@ -3,18 +3,14 @@
 import unittest

 from scaler import Client, Cluster, SchedulerClusterCombo
-from scaler.io.config import (
-    DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
-    DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
-    DEFAULT_IO_THREADS,
-    DEFAULT_TASK_TIMEOUT_SECONDS,
-    DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
-)
+from scaler.io.config import (DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
+                              DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
+                              DEFAULT_IO_THREADS, DEFAULT_TASK_TIMEOUT_SECONDS,
+                              DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES)
 from scaler.utility.logging.utility import setup_logger
 from scaler.utility.zmq_config import ZMQConfig
 from tests.utility import get_available_tcp_port, logging_test_name

-
 # This is a manual test because it can loop infinitely if it fails
diff --git a/tests/test_object_usage.py b/tests/test_object_usage.py
index a062e9d..94a84d9 100644
--- a/tests/test_object_usage.py
+++ b/tests/test_object_usage.py
@@ -1,7 +1,8 @@
 import dataclasses
 import unittest

-from scaler.scheduler.object_usage.object_tracker import ObjectTracker, ObjectUsage
+from scaler.scheduler.object_usage.object_tracker import (ObjectTracker,
+                                                          ObjectUsage)
 from scaler.utility.logging.utility import setup_logger
 from tests.utility import logging_test_name
diff --git a/tests/test_worker_object_tracker.py b/tests/test_worker_object_tracker.py
index ef18eed..34bb2cc 100644
--- a/tests/test_worker_object_tracker.py
+++ b/tests/test_worker_object_tracker.py
@@ -1,7 +1,8 @@
 import unittest

 from scaler.protocol.python.common import ObjectContent
-from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse
+from scaler.protocol.python.message import (ObjectInstruction, ObjectRequest,
+                                            ObjectResponse)
 from scaler.utility.logging.utility import setup_logger
 from scaler.worker.agent.object_tracker import VanillaObjectTracker
 from tests.utility import logging_test_name
@@ -56,7 +57,7 @@ def test_object_tracker(self) -> None:
                 ObjectContent.new_msg(
                     (b"object_1", b"object_2", b"object_3"),
                     (b"name_1", b"name_2", b"name_3"),
-                    (b"content_1", b"content_2", b"content_3"),
+                    ([b"content_1"], [b"content_2"], [b"content_3"]),
                 ),
             )
         )