Feature: Join PoC #833


Closed · wants to merge 14 commits
58 changes: 35 additions & 23 deletions quixstreams/dataframe/dataframe.py
@@ -48,6 +48,8 @@
from quixstreams.models.serializers import DeserializerType, SerializerType
from quixstreams.sinks import BaseSink
from quixstreams.state.base import State
from quixstreams.state.base.transaction import PartitionTransaction
from quixstreams.state.rocksdb.timestamped import TimestampedStore
from quixstreams.utils.printing import (
DEFAULT_COLUMN_NAME,
DEFAULT_LIVE,
@@ -278,11 +280,7 @@ def func(d: dict, state: State):
cast(ApplyCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload]
else:
stream = self.stream.add_apply(
@@ -387,11 +385,7 @@ def func(values: list, state: State):
cast(UpdateCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
return self._add_update(stateful_func, metadata=True)
else:
return self._add_update(
@@ -489,11 +483,7 @@ def func(d: dict, state: State):
cast(FilterCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
stream = self.stream.add_filter(stateful_func, metadata=True)
else:
stream = self.stream.add_filter( # type: ignore[call-overload]
@@ -1615,6 +1605,27 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
*self.topics, *other.topics, stream=merged_stream
)

def join(self, right: "StreamingDataFrame") -> "StreamingDataFrame":
# Proof of concept: enrich each left record with the latest right-side
# value for the same key at or before the record's timestamp.
# TODO: ensure copartitioning of left and right?
right.processing_context.state_manager.register_store(
stream_id=right.stream_id,
store_type=TimestampedStore,
changelog_config=self._topic_manager.derive_topic_config(right.topics),
)

# Left: look up the newest right value at or before this timestamp.
def left_func(value, key, timestamp, headers):
right_tx = _get_transaction(right)
right_value = right_tx.get_last(timestamp=timestamp, prefix=key)
return {**value, **(right_value or {})}

# Right: store each record under its timestamp for later lookups.
def right_func(value, key, timestamp, headers):
right_tx = _get_transaction(right)
right_tx.set(timestamp=timestamp, value=value, prefix=key)

left = self.apply(left_func, metadata=True)
# The right branch is state-only: filter(lambda value: False) drops its
# records so only enriched left records reach the merged output.
right = right.update(right_func, metadata=True).filter(lambda value: False)
return left.concat(right)

def ensure_topics_copartitioned(self):
partitions_counts = set(t.broker_config.num_partitions for t in self._topics)
if len(partitions_counts) > 1:
@@ -1804,19 +1815,20 @@ def wrapper(

def _as_stateful(
func: Callable[[Any, Any, int, Any, State], T],
processing_context: ProcessingContext,
stream_id: str,
sdf: StreamingDataFrame,
) -> Callable[[Any, Any, int, Any], T]:
@functools.wraps(func)
def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
ctx = message_context()
transaction = processing_context.checkpoint.get_store_transaction(
stream_id=stream_id,
partition=ctx.partition,
)
# Pass a State object with an interface limited to the key updates only
# and prefix all the state keys by the message key
state = transaction.as_state(prefix=key)
state = _get_transaction(sdf).as_state(prefix=key)
return func(value, key, timestamp, headers, state)

return wrapper


def _get_transaction(sdf: StreamingDataFrame) -> PartitionTransaction:
return sdf.processing_context.checkpoint.get_store_transaction(
stream_id=sdf.stream_id,
partition=message_context().partition,
)
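
For context, a minimal usage sketch of the new `join` API (the `Application` setup, topic names, and record fields are illustrative assumptions, not part of this diff):

from quixstreams import Application

app = Application(broker_address="localhost:9092")
orders = app.dataframe(app.topic("orders"))  # left stream
users = app.dataframe(app.topic("users"))  # right stream

# Each order is merged with the latest same-key user record whose timestamp
# is <= the order's timestamp; the user records themselves are filtered out
# of the merged output.
enriched = orders.join(users)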
4 changes: 3 additions & 1 deletion quixstreams/state/base/partition.py
@@ -25,6 +25,8 @@ class StorePartition(ABC):
the persistent storage).
"""

partition_transaction_class = PartitionTransaction

def __init__(
self,
dumps: DumpsFunc,
@@ -112,7 +114,7 @@ def begin(self) -> PartitionTransaction:

Using `PartitionTransaction` is a recommended way for accessing the data.
"""
return PartitionTransaction(
return self.partition_transaction_class(
partition=self,
dumps=self._dumps,
loads=self._loads,
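
The `partition_transaction_class` attribute makes the transaction type pluggable: a subclass can swap in its own transaction class without overriding `begin()`. A standalone sketch of the pattern (illustrative names only, not quixstreams code):

class BasePartition:
    transaction_class: type = dict

    def begin(self):
        # Instantiate whatever transaction type the subclass declares
        return self.transaction_class()

class CustomPartition(BasePartition):
    transaction_class = list  # the subclass swaps the type

assert isinstance(CustomPartition().begin(), list)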
17 changes: 17 additions & 0 deletions quixstreams/state/base/transaction.py
@@ -97,6 +97,23 @@ def get(
# UNDEFINED to signify that
return self._updated[cf_name][prefix].get(key, Marker.UNDEFINED)

def iter_items(
self, prefix: bytes, backwards: bool = False, cf_name: str = "default"
) -> list[tuple[bytes, bytes]]:
"""
Iterate over sorted, non-deleted items in the cache
for the given prefix and column family.
"""
deleted = self._deleted[cf_name]
return sorted(
(
(key, value)
for key, value in self._updated[cf_name][prefix].items()
if key not in deleted
),
reverse=backwards,
)

def set(self, key: bytes, value: bytes, prefix: bytes, cf_name: str = "default"):
"""
Set a value for the key.
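
The cache-side `iter_items` filters deletions and sorts in Python because the update cache is a plain in-memory mapping. A standalone illustration of its semantics (not quixstreams code; keys are made up):

updated = {b"user-1|0001": b"a", b"user-1|0003": b"c", b"user-1|0002": b"b"}
deleted = {b"user-1|0002"}

# Deleted keys are skipped; the result is sorted by key, newest-first
# when backwards=True (i.e. reverse=True).
items = sorted(
    ((k, v) for k, v in updated.items() if k not in deleted),
    reverse=True,
)
assert items == [(b"user-1|0003", b"c"), (b"user-1|0001", b"a")]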
17 changes: 13 additions & 4 deletions quixstreams/state/manager.py
@@ -15,6 +15,7 @@
from .memory import MemoryStore
from .recovery import ChangelogProducerFactory, RecoveryManager
from .rocksdb import RocksDBOptionsType, RocksDBStore
from .rocksdb.timestamped import TimestampedStore
from .rocksdb.windowed.store import WindowedRocksDBStore

__all__ = ("StateStoreManager", "DEFAULT_STATE_STORE_NAME", "StoreTypes")
@@ -24,7 +25,7 @@

DEFAULT_STATE_STORE_NAME = "default"

StoreTypes = Union[Type[RocksDBStore], Type[MemoryStore]]
StoreTypes = Union[Type[RocksDBStore], Type[MemoryStore], Type[TimestampedStore]]
SUPPORTED_STORES = [RocksDBStore, MemoryStore]


@@ -189,23 +190,31 @@ def register_store(

store_type = store_type or self.default_store_type
if store_type == RocksDBStore:
factory: Store = RocksDBStore(
store: Store = RocksDBStore(
name=store_name,
stream_id=stream_id,
base_dir=str(self._state_dir),
changelog_producer_factory=changelog_producer_factory,
options=self._rocksdb_options,
)
elif store_type == TimestampedStore:
store = TimestampedStore(
name=store_name,
stream_id=stream_id,
base_dir=str(self._state_dir),
changelog_producer_factory=changelog_producer_factory,
options=self._rocksdb_options,
)
elif store_type == MemoryStore:
factory = MemoryStore(
store = MemoryStore(
name=store_name,
stream_id=stream_id,
changelog_producer_factory=changelog_producer_factory,
)
else:
raise ValueError(f"invalid store type: {store_type}")

self._stores.setdefault(stream_id, {})[store_name] = factory
self._stores.setdefault(stream_id, {})[store_name] = store

def register_windowed_store(
self,
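
A hedged sketch of registering the new store type, mirroring what `join` does internally (the `state_manager` handle and the config values are assumptions):

state_manager.register_store(
    stream_id="right-stream-id",
    store_type=TimestampedStore,
    changelog_config=changelog_config,  # e.g. derived via TopicManager.derive_topic_config
)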
26 changes: 24 additions & 2 deletions quixstreams/state/rocksdb/partition.py
@@ -1,8 +1,8 @@
import logging
import time
from typing import Dict, List, Literal, Optional, Union, cast
from typing import Dict, Iterator, List, Literal, Optional, Union, cast

from rocksdict import AccessType, ColumnFamily, Rdict, WriteBatch
from rocksdict import AccessType, ColumnFamily, Rdict, ReadOptions, WriteBatch

from quixstreams.state.base import PartitionTransactionCache, StorePartition
from quixstreams.state.exceptions import ColumnFamilyDoesNotExist
@@ -139,6 +139,28 @@ def get(
# RDict accepts Any as the value type, but we only write bytes, so we should only get bytes back.
return cast(Union[bytes, Literal[Marker.UNDEFINED]], result)

def iter_items(
self,
lower_bound: bytes,
upper_bound: bytes,
backwards: bool = False,
cf_name: str = "default",
) -> Iterator[tuple[bytes, bytes]]:
cf = self.get_column_family(cf_name=cf_name)

# Set iterator bounds to reduce IO by limiting the range of keys fetched
read_opt = ReadOptions()
read_opt.set_iterate_lower_bound(lower_bound)
read_opt.set_iterate_upper_bound(upper_bound)

from_key = upper_bound if backwards else lower_bound

# RDict accepts Any as the value type, but we only write bytes, so we should only get bytes back.
return cast(
Iterator[tuple[bytes, bytes]],
cf.items(from_key=from_key, read_opt=read_opt, backwards=backwards),
)

def exists(self, key: bytes, cf_name: str = "default") -> bool:
"""
Check if a key is present in the DB.
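
The iterator bounds make point-in-time lookups cheap: to fetch the newest entry at or before a timestamp, iterate backwards from an exclusive upper bound. A sketch assuming keys are laid out as `prefix + big-endian timestamp` (the real layout lives in the transaction's `_serialize_key`; `partition` is an already-open store partition):

import struct

prefix = b"user-1"
timestamp_ms = 1_700_000_000_000

# Exclusive upper bound: prefix + (timestamp + 1)
upper = prefix + struct.pack(">q", timestamp_ms + 1)
items = partition.iter_items(
    lower_bound=prefix, upper_bound=upper, backwards=True
)
# The first item when walking backwards is the newest one in range
latest = next(iter(items), None)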
4 changes: 3 additions & 1 deletion quixstreams/state/rocksdb/store.py
@@ -23,6 +23,8 @@ class RocksDBStore(Store):
partitions' transactions.
"""

store_partition_class = RocksDBStorePartition

def __init__(
self,
name: str,
@@ -61,6 +63,6 @@ def create_new_partition(
self._changelog_producer_factory.get_partition_producer(partition)
)

return RocksDBStorePartition(
return self.store_partition_class(
path=path, options=self._options, changelog_producer=changelog_producer
)
127 changes: 127 additions & 0 deletions quixstreams/state/rocksdb/timestamped.py
@@ -0,0 +1,127 @@
from typing import Any, Optional

from quixstreams.state.base.transaction import (
PartitionTransaction,
PartitionTransactionStatus,
validate_transaction_status,
)
from quixstreams.state.serialization import serialize

from .partition import RocksDBStorePartition
from .store import RocksDBStore

__all__ = (
"TimestampedStore",
"TimestampedStorePartition",
"TimestampedPartitionTransaction",
)


class TimestampedPartitionTransaction(PartitionTransaction):
"""
A partition-specific transaction handler for the `TimestampedStore`.

Provides timestamp-aware methods for querying key-value pairs
based on a timestamp, alongside standard transaction operations.
It interacts with both an in-memory update cache and the persistent RocksDB store.
"""

# Override the type hint from the parent class (`PartitionTransaction`).
# This informs type checkers like mypy that in this specific subclass,
# `_partition` is a `TimestampedStorePartition` (defined below),
# which has methods like `.iter_items()` that the base type might lack.
# The string quotes are necessary for the forward reference.
_partition: "TimestampedStorePartition"

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def get_last(
self,
timestamp: int,
prefix: Any,
cf_name: str = "default",
) -> Optional[Any]:
"""Get the latest value for a prefix up to a given timestamp.

Searches both the transaction's update cache and the underlying RocksDB store
to find the value associated with the given `prefix` that has the highest
timestamp less than or equal to the provided `timestamp`.

The search prioritizes values from the update cache if their timestamps are
more recent than those found in the store.

:param timestamp: The upper bound timestamp (inclusive) in milliseconds.
:param prefix: The key prefix to search for.
:param cf_name: The column family name.
:return: The deserialized value if found, otherwise None.
"""
if not isinstance(prefix, bytes):
prefix = serialize(prefix, dumps=self._dumps)

# Add +1 because the storage `.iter_items()` is exclusive on the upper bound
key = self._serialize_key(timestamp + 1, prefix=prefix)
value: Optional[bytes] = None

cached = self._update_cache.iter_items(
prefix=prefix,
backwards=True,
cf_name=cf_name,
)
for cached_key, cached_value in cached:
if prefix < cached_key < key:
value = cached_value
break

stored = self._partition.iter_items(
lower_bound=prefix,
upper_bound=key,
backwards=True,
cf_name=cf_name,
)
for stored_key, stored_value in stored:
# `cached_key` is only bound if the cache loop above found a value;
# the `value is None` short-circuit guarantees it is defined here.
if value is None or cached_key < stored_key:
value = stored_value
# We only care about the first item found when iterating backwards
# from the upper bound, hence the break.
break

return self._deserialize_value(value) if value is not None else None

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def set(self, timestamp: int, value: Any, prefix: Any, cf_name: str = "default"):
"""Set a value associated with a prefix and timestamp.

This method acts as a proxy, passing the provided `timestamp` and `prefix`
to the parent `set` method. The parent method internally serializes these
into a combined key before storing the value in the update cache.

:param timestamp: Timestamp associated with the value in milliseconds.
:param value: The value to store.
:param prefix: The key prefix.
:param cf_name: Column family name.
"""
if not isinstance(prefix, bytes):
prefix = serialize(prefix, dumps=self._dumps)

super().set(timestamp, value, prefix, cf_name)


class TimestampedStorePartition(RocksDBStorePartition):
"""
Represents a single partition within a `TimestampedStore`.

This class is responsible for managing the state of one partition and creating
`TimestampedPartitionTransaction` instances to handle atomic operations for that partition.
"""

partition_transaction_class = TimestampedPartitionTransaction


class TimestampedStore(RocksDBStore):
"""
A RocksDB-backed state store implementation that manages key-value pairs
associated with timestamps.

Uses `TimestampedStorePartition` to manage individual partitions.
"""

store_partition_class = TimestampedStorePartition
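
Finally, a hedged end-to-end sketch of the new transaction semantics (paths, keys, and timestamps are made up; assumes the base `Store` partition helpers behave as elsewhere in the codebase):

store = TimestampedStore(name="join", stream_id="right", base_dir="/tmp/state")
store.assign_partition(0)
tx = store.start_partition_transaction(0)

tx.set(timestamp=1000, value={"name": "alice"}, prefix=b"user-1")
tx.set(timestamp=2000, value={"name": "bob"}, prefix=b"user-1")

# get_last returns the newest value at or before the given timestamp
assert tx.get_last(timestamp=1500, prefix=b"user-1") == {"name": "alice"}
assert tx.get_last(timestamp=2500, prefix=b"user-1") == {"name": "bob"}
assert tx.get_last(timestamp=500, prefix=b"user-1") is None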