Skip to content

Commit

Permalink
enable dofn state support
Browse files Browse the repository at this point in the history
Fixes #12
  • Loading branch information
iasoon committed Jun 16, 2022
1 parent f1b8fde commit 615e351
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
7 changes: 7 additions & 0 deletions ray_beam_runner/portability/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,10 @@ def test_data_stored_properly(self):
hc.assert_that(
all_data, hc.contains_exactly(*StateHandlerTest.SAMPLE_INPUT_DATA)
)

def test_fresh_key(self):
sh = RayStateManager()
with sh.process_instruction_id("anyinstruction"):
data, continuation_token = sh.get_raw(StateHandlerTest.SAMPLE_STATE_KEY)
hc.assert_that(continuation_token, hc.equal_to(None))
hc.assert_that(data, hc.equal_to(b""))
1 change: 0 additions & 1 deletion ray_beam_runner/portability/ray_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ def cross_product(elem, sides):
equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
)

@unittest.skip("State not yet supported")
def test_pardo_state_only(self):
index_state_spec = userstate.CombiningValueStateSpec("index", sum)
value_and_index_state_spec = userstate.ReadModifyWriteStateSpec(
Expand Down
65 changes: 48 additions & 17 deletions ray_beam_runner/portability/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,40 @@

import collections
import contextlib
from typing import Optional, Tuple, Iterator
from typing import Optional, Tuple, Iterator, TypeVar

import ray
from ray import ObjectRef
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker import sdk_worker

T = TypeVar("T")


class RayFuture(sdk_worker._Future[T]):
"""Wraps a ray ObjectRef in a beam sdk_worker._Future"""

def __init__(self, object_ref):
# type: (ObjectRef[T]) -> None
self._object_ref: ObjectRef[T] = object_ref

def wait(self, timeout=None):
# type: (Optional[float]) -> bool
try:
ray.get(self._object_ref, timeout=timeout)
#
return True
except ray.GetTimeoutError:
return False

def get(self, timeout=None):
# type: (Optional[float]) -> T
return ray.get(self._object_ref, timeout=timeout)

def set(self, _value):
# type: (T) -> sdk_worker._Future[T]
raise NotImplementedError()


@ray.remote
class _ActorStateManager:
Expand All @@ -42,14 +70,16 @@ def get_raw(
else:
continuation_token = 0

new_cont_token = continuation_token + 1
if len(self._data[(bundle_id, state_key)]) == new_cont_token:
return self._data[(bundle_id, state_key)][continuation_token], None
full_state = self._data[(bundle_id, state_key)]
if len(full_state) == continuation_token:
return b"", None

if continuation_token + 1 == len(full_state):
next_cont_token = None
else:
return (
self._data[(bundle_id, state_key)][continuation_token],
str(continuation_token + 1).encode("utf8"),
)
next_cont_token = str(continuation_token + 1).encode("utf8")

return full_state[continuation_token], next_cont_token

def append_raw(self, bundle_id: str, state_key: str, data: bytes):
self._data[(bundle_id, state_key)].append(data)
Expand Down Expand Up @@ -81,19 +111,20 @@ def get_raw(
)
)

def append_raw(
self, state_key: beam_fn_api_pb2.StateKey, data: bytes
) -> sdk_worker._Future:
def append_raw(self, state_key: beam_fn_api_pb2.StateKey, data: bytes) -> RayFuture:
assert self._instruction_id is not None
return self._state_actor.append_raw.remote(
self._instruction_id, RayStateManager._to_key(state_key), data
return RayFuture(
self._state_actor.append_raw.remote(
self._instruction_id, RayStateManager._to_key(state_key), data
)
)

def clear(self, state_key: beam_fn_api_pb2.StateKey) -> sdk_worker._Future:
# TODO(pabloem): Does the ray future work as a replacement of Beam _Future?
def clear(self, state_key: beam_fn_api_pb2.StateKey) -> RayFuture:
assert self._instruction_id is not None
return self._state_actor.clear.remote(
self._instruction_id, RayStateManager._to_key(state_key)
return RayFuture(
self._state_actor.clear.remote(
self._instruction_id, RayStateManager._to_key(state_key)
)
)

@contextlib.contextmanager
Expand Down

0 comments on commit 615e351

Please sign in to comment.