Skip to content

Commit

Permalink
Merge pull request #38 from sharpener6/main
Browse files Browse the repository at this point in the history
Fix serialized big data issue
  • Loading branch information
rafa-be authored Oct 24, 2024
2 parents ab333d0 + a41984d commit 4d7b85d
Show file tree
Hide file tree
Showing 29 changed files with 100 additions and 67 deletions.
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.10"
__version__ = "1.8.11"
7 changes: 4 additions & 3 deletions scaler/client/agent/future_manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import threading
from concurrent.futures import InvalidStateError, Future
from concurrent.futures import Future, InvalidStateError
from typing import Dict, Tuple

from scaler.client.agent.mixins import FutureManager
from scaler.client.future import ScalerFuture
from scaler.client.serializer.mixins import Serializer
from scaler.io.utility import concat_list_of_bytes
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import ObjectResponse, TaskResult
from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
Expand Down Expand Up @@ -106,9 +107,9 @@ def on_object_response(self, response: ObjectResponse):

try:
if status == TaskStatus.Success:
future.set_result(self._serializer.deserialize(object_bytes))
future.set_result(self._serializer.deserialize(concat_list_of_bytes(object_bytes)))

elif status == TaskStatus.Failed:
future.set_exception(deserialize_failure(object_bytes))
future.set_exception(deserialize_failure(concat_list_of_bytes(object_bytes)))
except InvalidStateError:
continue # future got canceled
7 changes: 2 additions & 5 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ def __initialize__(
self._stop_event = threading.Event()
self._context = zmq.Context()
self._connector = SyncConnector(
context=self._context,
socket_type=zmq.PAIR,
address=self._client_agent_address,
identity=self._identity,
context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
)

self._future_manager = ClientFutureManager(self._serializer)
Expand Down Expand Up @@ -309,7 +306,7 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference:
self.__assert_client_not_stopped()

cache = self._object_buffer.buffer_send_object(obj, name)
return ObjectReference(cache.object_name, cache.object_id, len(cache.object_bytes))
return ObjectReference(cache.object_name, cache.object_id, sum(map(len, cache.object_bytes)))

def disconnect(self):
"""
Expand Down
11 changes: 7 additions & 4 deletions scaler/client/object_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from scaler.client.serializer.mixins import Serializer
from scaler.io.sync_connector import SyncConnector
from scaler.io.utility import chunk_to_list_of_bytes
from scaler.protocol.python.common import ObjectContent
from scaler.protocol.python.message import ObjectInstruction
from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id
Expand All @@ -15,7 +16,7 @@
class ObjectCache:
object_id: bytes
object_name: bytes
object_bytes: bytes
object_bytes: List[bytes]


class ObjectBuffer:
Expand Down Expand Up @@ -83,18 +84,20 @@ def commit_delete_objects(self):
def __construct_serializer(self) -> ObjectCache:
serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL)
object_id = generate_serializer_object_id(self._identity)
return ObjectCache(object_id, b"serializer", serializer_bytes)
return ObjectCache(object_id, b"serializer", chunk_to_list_of_bytes(serializer_bytes))

def __construct_function(self, fn: Callable) -> ObjectCache:
function_bytes = self._serializer.serialize(fn)
object_id = generate_object_id(self._identity, function_bytes)
function_cache = ObjectCache(
object_id, getattr(fn, "__name__", f"<func {object_id.hex()[:6]}>").encode(), function_bytes
object_id,
getattr(fn, "__name__", f"<func {object_id.hex()[:6]}>").encode(),
chunk_to_list_of_bytes(function_bytes),
)
return function_cache

def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCache:
object_payload = self._serializer.serialize(obj)
object_id = generate_object_id(self._identity, object_payload)
name_bytes = name.encode() if name else f"<obj {object_id.hex()[-6:]}>".encode()
return ObjectCache(object_id, name_bytes, object_payload)
return ObjectCache(object_id, name_bytes, chunk_to_list_of_bytes(object_payload))
2 changes: 1 addition & 1 deletion scaler/cluster/combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scaler.io.config import (
DEFAULT_CLIENT_TIMEOUT_SECONDS,
DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
DEFAULT_HARD_PROCESSOR_SUSPEND,
DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
DEFAULT_IO_THREADS,
DEFAULT_LOAD_BALANCE_SECONDS,
Expand All @@ -18,7 +19,6 @@
DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
DEFAULT_WORKER_DEATH_TIMEOUT,
DEFAULT_WORKER_TIMEOUT_SECONDS,
DEFAULT_HARD_PROCESSOR_SUSPEND,
)
from scaler.utility.zmq_config import ZMQConfig

Expand Down
2 changes: 1 addition & 1 deletion scaler/cluster/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import multiprocessing
from asyncio import AbstractEventLoop, Task
from typing import Optional, Tuple, Any
from typing import Any, Optional, Tuple

from scaler.scheduler.config import SchedulerConfig
from scaler.scheduler.scheduler import Scheduler, scheduler_main
Expand Down
2 changes: 1 addition & 1 deletion scaler/entry_points/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from scaler.cluster.cluster import Cluster
from scaler.io.config import (
DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
DEFAULT_HARD_PROCESSOR_SUSPEND,
DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
DEFAULT_IO_THREADS,
DEFAULT_NUMBER_OF_WORKER,
DEFAULT_TASK_TIMEOUT_SECONDS,
DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
DEFAULT_WORKER_DEATH_TIMEOUT,
DEFAULT_HARD_PROCESSOR_SUSPEND,
)
from scaler.utility.event_loop import EventLoopType, register_event_loop
from scaler.utility.zmq_config import ZMQConfig
Expand Down
2 changes: 1 addition & 1 deletion scaler/entry_points/top.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import curses
import functools
from typing import List, Literal, Dict, Union
from typing import Dict, List, Literal, Union

from scaler.io.sync_subscriber import SyncSubscriber
from scaler.protocol.python.message import StateScheduler
Expand Down
2 changes: 1 addition & 1 deletion scaler/io/async_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import uuid
from collections import defaultdict
from typing import Awaitable, Callable, List, Optional, Dict
from typing import Awaitable, Callable, Dict, List, Optional

import zmq.asyncio
from zmq import Frame
Expand Down
5 changes: 4 additions & 1 deletion scaler/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# number of seconds for profiling
PROFILING_INTERVAL_SECONDS = 1

# cap'n proto only allow Data/Text/Blob size to be as big as 500MB
CAPNP_DATA_SIZE_LIMIT = 2**29 - 1

# message size limitation, max can be 2**64
MESSAGE_SIZE_LIMIT = 2**64 - 1
CAPNP_MESSAGE_SIZE_LIMIT = 2**64 - 1

# ==========================
# SCHEDULER SPECIFIC OPTIONS
Expand Down
18 changes: 15 additions & 3 deletions scaler/io/utility.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
from typing import Optional
from typing import List, Optional

from scaler.io.config import MESSAGE_SIZE_LIMIT
from scaler.io.config import CAPNP_DATA_SIZE_LIMIT, CAPNP_MESSAGE_SIZE_LIMIT
from scaler.protocol.capnp._python import _message # noqa
from scaler.protocol.python.message import PROTOCOL
from scaler.protocol.python.mixins import Message


def deserialize(data: bytes) -> Optional[Message]:
with _message.Message.from_bytes(data, traversal_limit_in_words=MESSAGE_SIZE_LIMIT) as payload:
with _message.Message.from_bytes(data, traversal_limit_in_words=CAPNP_MESSAGE_SIZE_LIMIT) as payload:
if not hasattr(payload, payload.which()):
logging.error(f"unknown message type: {payload.which()}")
return None
Expand All @@ -20,3 +20,15 @@ def deserialize(data: bytes) -> Optional[Message]:
def serialize(message: Message) -> bytes:
payload = _message.Message(**{PROTOCOL.inverse[type(message)]: message.get_message()})
return payload.to_bytes()


def chunk_to_list_of_bytes(data: bytes) -> List[bytes]:
# TODO: change to list of memoryview when capnp can support memoryview
return [data[i : i + CAPNP_DATA_SIZE_LIMIT] for i in range(0, len(data), CAPNP_DATA_SIZE_LIMIT)]


def concat_list_of_bytes(data: List[bytes]) -> bytes:
one_object_bytes = bytearray()
for chunk in data:
one_object_bytes.extend(chunk)
return one_object_bytes
1 change: 1 addition & 0 deletions scaler/protocol/capnp/_python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import capnp # noqa

import scaler.protocol.capnp.common_capnp as _common # noqa
import scaler.protocol.capnp.message_capnp as _message # noqa
import scaler.protocol.capnp.status_capnp as _status # noqa
2 changes: 1 addition & 1 deletion scaler/protocol/capnp/common.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ enum TaskStatus {
struct ObjectContent {
objectIds @0 :List(Data);
objectNames @1 :List(Data);
objectBytes @2 :List(Data);
objectBytes @2 :List(List(Data));
}
8 changes: 4 additions & 4 deletions scaler/protocol/python/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import enum
from typing import Tuple
from typing import List, Tuple

from scaler.protocol.capnp._python import _common # noqa
from scaler.protocol.python.mixins import Message
Expand All @@ -26,7 +26,7 @@ class TaskStatus(enum.Enum):
@dataclasses.dataclass
class ObjectContent(Message):
def __init__(self, msg):
self._msg = msg
super().__init__(msg)

@property
def object_ids(self) -> Tuple[bytes, ...]:
Expand All @@ -37,14 +37,14 @@ def object_names(self) -> Tuple[bytes, ...]:
return tuple(self._msg.objectNames)

@property
def object_bytes(self) -> Tuple[bytes, ...]:
def object_bytes(self) -> Tuple[List[bytes], ...]:
return tuple(self._msg.objectBytes)

@staticmethod
def new_msg(
object_ids: Tuple[bytes, ...],
object_names: Tuple[bytes, ...] = tuple(),
object_bytes: Tuple[bytes, ...] = tuple(),
object_bytes: Tuple[List[bytes], ...] = tuple(),
) -> "ObjectContent":
return ObjectContent(
_common.ObjectContent(
Expand Down
6 changes: 3 additions & 3 deletions scaler/protocol/python/message.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import dataclasses
import enum
import os
from typing import List, Set, Tuple, Optional, Type
from typing import List, Optional, Set, Tuple, Type

import bidict

from scaler.protocol.capnp._python import _message # noqa
from scaler.protocol.python.common import TaskStatus, ObjectContent
from scaler.protocol.python.common import ObjectContent, TaskStatus
from scaler.protocol.python.mixins import Message
from scaler.protocol.python.status import (
BinderStatus,
ClientManagerStatus,
ObjectManagerStatus,
Resource,
ProcessorStatus,
Resource,
TaskManagerStatus,
WorkerManagerStatus,
)
Expand Down
8 changes: 4 additions & 4 deletions scaler/scheduler/mixins.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import abc
from typing import Optional, Set
from typing import List, Optional, Set

from scaler.protocol.python.message import (
ClientDisconnect,
ClientHeartbeat,
DisconnectRequest,
GraphTask,
GraphTaskCancel,
ObjectInstruction,
ObjectRequest,
Task,
TaskCancel,
TaskResult,
WorkerHeartbeat,
ObjectInstruction,
)
from scaler.utility.mixins import Reporter

Expand All @@ -27,7 +27,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):
raise NotImplementedError()

@abc.abstractmethod
def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
raise NotImplementedError()

@abc.abstractmethod
Expand All @@ -47,7 +47,7 @@ def get_object_name(self, object_id: bytes) -> bytes:
raise NotImplementedError()

@abc.abstractmethod
def get_object_content(self, object_id: bytes) -> bytes:
def get_object_content(self, object_id: bytes) -> List[bytes]:
raise NotImplementedError()


Expand Down
17 changes: 9 additions & 8 deletions scaler/scheduler/object_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import logging
from asyncio import Queue
from typing import Optional, Set
from typing import List, Optional, Set

from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
Expand All @@ -19,7 +19,7 @@ class _ObjectCreation(ObjectUsage):
object_id: bytes
object_creator: bytes
object_name: bytes
object_bytes: bytes
object_bytes: List[bytes]

def get_object_key(self) -> bytes:
return self.object_id
Expand Down Expand Up @@ -71,7 +71,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):

logging.error(f"received unknown object request type {request=} from {source=!r}")

def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
creation = _ObjectCreation(object_id, object_user, object_name, object_bytes)
logging.debug(
f"add object cache "
Expand Down Expand Up @@ -102,15 +102,16 @@ def get_object_name(self, object_id: bytes) -> bytes:

return self._object_storage.get_object(object_id).object_name

def get_object_content(self, object_id: bytes) -> bytes:
def get_object_content(self, object_id: bytes) -> List[bytes]:
if not self.has_object(object_id):
return b""
return list()

return self._object_storage.get_object(object_id).object_bytes

def get_status(self) -> ObjectManagerStatus:
return ObjectManagerStatus.new_msg(
self._object_storage.object_count(), sum(len(v.object_bytes) for _, v in self._object_storage.items())
self._object_storage.object_count(),
sum(sum(map(len, v.object_bytes)) for _, v in self._object_storage.items()),
)

async def __process_get_request(self, source: bytes, request: ObjectRequest):
Expand Down Expand Up @@ -139,12 +140,12 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}")
return

for object_id, object_name, object_content in zip(
for object_id, object_name, object_bytes in zip(
instruction.object_content.object_ids,
instruction.object_content.object_names,
instruction.object_content.object_bytes,
):
self.on_add_object(instruction.object_user, object_id, object_name, object_content)
self.on_add_object(instruction.object_user, object_id, object_name, object_bytes)

def __finished_object_storage(self, creation: _ObjectCreation):
logging.debug(
Expand Down
2 changes: 1 addition & 1 deletion scaler/utility/graph/topological_sorter_graphblas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import graphlib
import itertools
from typing import Hashable, Iterable, List, Optional, Tuple, TypeVar, Generic, Mapping
from typing import Generic, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar

from bidict import bidict

Expand Down
1 change: 0 additions & 1 deletion scaler/utility/queues/async_priority_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from asyncio import Queue
from typing import Dict, List, Tuple, Union


PriorityType = Union[int, Tuple["PriorityType", ...]]


Expand Down
2 changes: 1 addition & 1 deletion scaler/worker/agent/heartbeat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import psutil

from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.message import WorkerHeartbeat, WorkerHeartbeatEcho, Resource
from scaler.protocol.python.message import Resource, WorkerHeartbeat, WorkerHeartbeatEcho
from scaler.protocol.python.status import ProcessorStatus
from scaler.utility.mixins import Looper
from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, TaskManager, TimeoutManager
Expand Down
Loading

0 comments on commit 4d7b85d

Please sign in to comment.