Skip to content

Commit

Permalink
rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
lynnagara committed Dec 18, 2024
1 parent 7cfbeb5 commit 1b733e5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 38 deletions.
61 changes: 39 additions & 22 deletions src/sentry/utils/arroyo_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import time
from collections import deque
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Deque, Generic, TypeVar
from dataclasses import dataclass
from typing import Any, Deque, Generic, TypeVar

from arroyo.processing.strategies import MessageRejected
from arroyo.processing.strategies.abstract import ProcessingStrategy
Expand Down Expand Up @@ -96,6 +97,12 @@ def terminate(self) -> None:
self.next_step.terminate()


@dataclass
class PartitionOffsets:
last_started: int
last_finished: int | None


class MessageBuffer(Generic[TResult]):
"""
Keeps track of the in-flight offsets for all routes, and buffers messages for
Expand All @@ -104,41 +111,51 @@ class MessageBuffer(Generic[TResult]):
"""

def __init__(self, routes: Sequence[str]) -> None:
# Maintains the last committable offsets for each route and partition
self.committable_offsets: Mapping[str, MutableMapping[Partition, int]] = {
r: {} for r in routes
}

# the low watermarks?
self.committable_offsets: Mapping[str, MutableMapping[Partition, int]] = {
# Keeps track of the in-flight offset ranges in each route
self.committable_offsets: Mapping[str, MutableMapping[Partition, PartitionOffsets]] = {
r: {} for r in routes
}

# Keeps track of all completed messages together with the route on which it was sent
self.messages: Deque[tuple[str, Message[TResult]]] = deque()

def add(self, message: Message[TResult], routing_key: str) -> None:
self.messages.append((routing_key, message))

def start(self, message: Message[Any], routing_key: str) -> None:
for partition, committable_offset in message.committable.items():
if self.committable_offsets[routing_key].get(partition):
if committable_offset > self.committable_offsets[routing_key][partition]:
self.committable_offsets[routing_key][partition] = committable_offset
self.committable_offsets[routing_key][partition].last_started = committable_offset
else:
self.committable_offsets[routing_key][partition] = committable_offset
self.committable_offsets[routing_key][partition] = PartitionOffsets(
committable_offset, None
)

def submit(self, message: Message[TResult], routing_key: str) -> None:
self.messages.append((routing_key, message))

for partition, committable_offset in message.committable.items():
self.committable_offsets[routing_key][partition].last_finished = committable_offset

def poll(self) -> Message[TResult] | None:
if not self.messages:
return None

(route, message) = self.messages[0]

# Ensure the message isn't returned if it's not completed yet
for partition, committable_offset in message.committable.items():
partition_offset = self.committable_offsets[route].get(partition)

if partition_offset is None or committable_offset > partition_offset:
return None
(message_route, message) = self.messages[0]

# Ensure the message isn't returned if its offsets are above those finished on all routes
for route, route_offsets in self.committable_offsets.items():
if message_route == route:
continue

for partition, committable_offset in message.committable.items():
if route_offsets.get(partition):
partition_offsets = route_offsets[partition]
if partition_offsets.last_finished is None:
return None
else:
offset_range = range(
partition_offsets.last_started, partition_offsets.last_finished
)
if committable_offset in offset_range:
return None

self.messages.popleft()
return message
Expand Down
29 changes: 13 additions & 16 deletions tests/sentry/utils/test_arroyo_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,32 @@ def test_message_buffer() -> None:
buffer: MessageBuffer[int] = MessageBuffer(["route_a", "route_b"])

# Add two messages to each route
buffer.add(messages[0], "route_a")
buffer.add(messages[1], "route_b")
buffer.add(messages[2], "route_a")
buffer.add(messages[3], "route_b")
buffer.start(messages[0], "route_a")
buffer.start(messages[1], "route_b")
buffer.start(messages[2], "route_a")
buffer.start(messages[3], "route_b")

# All messages are in-flight, poll returns nothing
assert buffer.poll() is None
assert len(buffer) == 4
assert len(buffer) == 0

# The first message was completed, now it can be polled
buffer.remove(messages[0], "route_a")
buffer.submit(messages[0], "route_a")
assert len(buffer) == 1
msg = buffer.poll()
assert msg is not None
assert isinstance(msg.value, BrokerValue)
assert len(buffer) == 0
assert msg.value.offset == 0
assert buffer.poll() is None
assert len(buffer) == 3

# Still waiting for route_b
buffer.remove(messages[2], "route_a")
assert buffer.poll() is None

# All done, now we can poll the last 3 messages
buffer.remove(messages[3], "route_b")
assert buffer.poll() is not None
# Second message is done
buffer.submit(messages[2], "route_a")
assert buffer.poll() is not None

# Third message was filtered, but the last message is done
buffer.submit(messages[3], "route_b")
assert buffer.poll() is not None
assert buffer.poll() is None
assert len(buffer) == 0


def test_router() -> None:
Expand Down

0 comments on commit 1b733e5

Please sign in to comment.