From a0578342052bcf77c3b2bd1487c10effbb593490 Mon Sep 17 00:00:00 2001 From: Daniel M Date: Fri, 19 Jan 2024 13:03:23 -0500 Subject: [PATCH] fix:xread blocking (#275) fix #274 --- fakeredis/commands_mixins/streams_mixin.py | 27 +++++++++++----------- test/test_mixins/test_streams_commands.py | 26 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/fakeredis/commands_mixins/streams_mixin.py b/fakeredis/commands_mixins/streams_mixin.py index e14d8116..4d37900a 100644 --- a/fakeredis/commands_mixins/streams_mixin.py +++ b/fakeredis/commands_mixins/streams_mixin.py @@ -104,15 +104,19 @@ def xrevrange(self, key, _min, _max, *args): (count,), _ = extract_args(args, ("+count",)) return self._xrange(key.value, _max, _min, True, count) - def _xread(self, stream_start_id_list: List, count: int, first_pass: bool): + def _xread(self, stream_start_id_list: List[Tuple[bytes, StreamRangeTest]], count: int, blocking: bool, + first_pass: bool): max_inf = StreamRangeTest.decode(b"+") res = list() - for item, start_id in stream_start_id_list: + for stream_name, start_id in stream_start_id_list: + item = CommandItem(stream_name, self._db, item=self._db.get(stream_name), default=None) stream_results = self._xrange(item.value, start_id, max_inf, False, count) if first_pass and (count is None): return None if len(stream_results) > 0: res.append([item.key, stream_results]) + if blocking and count and len(res) == 0: + return None return res def _xreadgroup( @@ -135,6 +139,8 @@ def _xreadgroup( @staticmethod def _parse_start_id(key: CommandItem, s: bytes) -> StreamRangeTest: if s == b"$": + if key.value is None: + return StreamRangeTest.decode(b"0-0") return StreamRangeTest.decode(key.value.last_item_key(), exclusive=True) return StreamRangeTest.decode(s, exclusive=True) @@ -146,23 +152,16 @@ def xread(self, *args): left_args = left_args[1:] num_streams = int(len(left_args) / 2) - stream_start_id_list = list() + stream_start_id_list: List[Tuple[bytes, StreamRangeTest]] = list() # (name, start_id) for i in range(num_streams): - item = CommandItem( - left_args[i], self._db, item=self._db.get(left_args[i]), default=None - ) + item = CommandItem(left_args[i], self._db, item=self._db.get(left_args[i]), default=None) start_id = self._parse_start_id(item, left_args[i + num_streams]) - stream_start_id_list.append( - ( - item, - start_id, - ) - ) + stream_start_id_list.append((left_args[i], start_id)) if timeout is None: - return self._xread(stream_start_id_list, count, False) + return self._xread(stream_start_id_list, count, blocking=False, first_pass=False) else: return self._blocking( - timeout / 1000.0, functools.partial(self._xread, stream_start_id_list, count) + timeout / 1000.0, functools.partial(self._xread, stream_start_id_list, count, True) ) @command(name="XREADGROUP", fixed=(bytes, bytes, bytes), repeat=(bytes,)) diff --git a/test/test_mixins/test_streams_commands.py b/test/test_mixins/test_streams_commands.py index b74b333f..cb768577 100644 --- a/test/test_mixins/test_streams_commands.py +++ b/test/test_mixins/test_streams_commands.py @@ -1,3 +1,4 @@ +import threading import time from typing import List @@ -744,3 +745,28 @@ def test_xclaim(r: redis.Redis): stream, group, consumer1, min_idle_time=0, message_ids=(message_id,), justid=True, ) == [message_id, ] + + +def test_xread_blocking(create_redis): + # thread with xread block 0 should hang + # putting data in the stream should unblock it + event = threading.Event() + event.clear() + + def thread_func(): + while not event.is_set(): + time.sleep(0.1) + r = create_redis(db=1) + r.xadd("stream", {"x": "1"}) + time.sleep(0.1) + + t = threading.Thread(target=thread_func) + t.start() + r1 = create_redis(db=1) + event.set() + result = r1.xread({"stream": "$"}, block=0, count=1) + event.clear() + t.join() + assert result[0][0] == b"stream" + assert result[0][1][0][1] == {b'x': b'1'} + pass