diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index 599717c4..c0ef4f55 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -8,6 +8,8 @@ import redis from redis.connection import DefaultParser +from fakeredis.model import XStream +from fakeredis.model import ZSet from . import _msgs as msgs from ._command_args_parsing import extract_args from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash @@ -21,8 +23,6 @@ QUEUED, decode_command_bytes, ) -from ._stream import XStream -from ._zset import ZSet def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]: diff --git a/fakeredis/commands_mixins/generic_mixin.py b/fakeredis/commands_mixins/generic_mixin.py index 7cfb9ed6..b0b761bb 100644 --- a/fakeredis/commands_mixins/generic_mixin.py +++ b/fakeredis/commands_mixins/generic_mixin.py @@ -17,7 +17,7 @@ Hash, ) from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database, SimpleString -from fakeredis._zset import ZSet +from fakeredis.model import ZSet class GenericCommandsMixin: diff --git a/fakeredis/commands_mixins/geo_mixin.py b/fakeredis/commands_mixins/geo_mixin.py index 3fcfebad..52471b2f 100644 --- a/fakeredis/commands_mixins/geo_mixin.py +++ b/fakeredis/commands_mixins/geo_mixin.py @@ -6,9 +6,9 @@ from fakeredis._command_args_parsing import extract_args from fakeredis._commands import command, Key, Float, CommandItem from fakeredis._helpers import SimpleError, Database -from fakeredis._zset import ZSet -from fakeredis.geo import geohash -from fakeredis.geo.haversine import distance +from fakeredis.model import ZSet +from fakeredis.geo import distance, geo_encode, geo_decode + UNIT_TO_M = {"km": 0.001, "mi": 0.000621371, "ft": 3.28084, "m": 1} @@ -71,7 +71,7 @@ def _find_near( """ results = list() for name, _hash in zset.items(): - p_lat, p_long, _, _ = geohash.decode(_hash) + p_lat, p_long, _, _ = geo_decode(_hash) dist = distance((p_lat, p_long), (lat, long)) * conv if dist < radius: results.append(GeoResult(name, p_long, p_lat, _hash, dist)) @@ -120,7 +120,7 @@ def geoadd(self, key: CommandItem, *args: bytes) -> int: data[i + 2], ) if (name in zset and not xx) or (name not in zset and not nx): - if zset.add(name, geohash.encode(lat, long, 10)): + if zset.add(name, geo_encode(lat, long, 10)): changed_items += 1 if changed_items: key.updated() @@ -137,7 +137,7 @@ def geohash(self, key: CommandItem, *members: bytes) -> List[bytes]: @command(name="GEOPOS", fixed=(Key(ZSet), bytes), repeat=(bytes,)) def geopos(self, key: CommandItem, *members: bytes) -> List[Optional[List[bytes]]]: gospositions = map( - lambda x: geohash.decode(x) if x is not None else x, + lambda x: geo_decode(x) if x is not None else x, map(key.value.get, members), ) res = [ @@ -158,7 +158,7 @@ def geodist(self, key: CommandItem, m1: bytes, m2: bytes, *args: bytes) -> Optio geohashes = [key.value.get(m1), key.value.get(m2)] if any(elem is None for elem in geohashes): return None - geo_locs = [geohash.decode(x) for x in geohashes] + geo_locs = [geo_decode(x) for x in geohashes] res = distance((geo_locs[0][0], geo_locs[0][1]), (geo_locs[1][0], geo_locs[1][1])) unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1 return res * unit @@ -236,13 +236,13 @@ def georadius(self, key: CommandItem, long: float, lat: float, radius: float, *a @command(name="GEORADIUSBYMEMBER", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) def georadiusbymember(self, key: CommandItem, member_name: bytes, radius: float, *args: bytes): member_score = key.value.get(member_name) - lat, long, _, _ = geohash.decode(member_score) + lat, long, _, _ = geo_decode(member_score) return self.georadius(key, long, lat, radius, *args) @command(name="GEORADIUSBYMEMBER_RO", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) def georadiusbymember_ro(self, key: CommandItem, member_name: bytes, radius: float, *args: float) -> List[Any]: member_score = key.value.get(member_name) - lat, long, _, _ = geohash.decode(member_score) + lat, long, _, _ = geo_decode(member_score) return self.georadius_ro(key, long, lat, radius, *args) @command(name="GEOSEARCH", fixed=(Key(ZSet),), repeat=(bytes,)) diff --git a/fakeredis/commands_mixins/sortedset_mixin.py b/fakeredis/commands_mixins/sortedset_mixin.py index 6c4c8036..6cb8da13 100644 --- a/fakeredis/commands_mixins/sortedset_mixin.py +++ b/fakeredis/commands_mixins/sortedset_mixin.py @@ -26,7 +26,7 @@ null_terminate, Database, ) -from fakeredis._zset import ZSet +from fakeredis.model import ZSet SORTED_SET_METHODS = { "ZUNIONSTORE": lambda s1, s2: s1 | s2, diff --git a/fakeredis/commands_mixins/streams_mixin.py b/fakeredis/commands_mixins/streams_mixin.py index 081c2074..ff3a34d2 100644 --- a/fakeredis/commands_mixins/streams_mixin.py +++ b/fakeredis/commands_mixins/streams_mixin.py @@ -5,7 +5,7 @@ from fakeredis._command_args_parsing import extract_args from fakeredis._commands import Key, command, CommandItem, Int from fakeredis._helpers import SimpleError, casematch, OK, current_time, Database, SimpleString -from fakeredis._stream import XStream, StreamRangeTest, StreamGroup, StreamEntryKey +from fakeredis.model import XStream, StreamRangeTest, StreamGroup, StreamEntryKey class StreamsCommandsMixin: diff --git a/fakeredis/geo/__init__.py b/fakeredis/geo/__init__.py index e69de29b..745564c7 100644 --- a/fakeredis/geo/__init__.py +++ b/fakeredis/geo/__init__.py @@ -0,0 +1,8 @@ +from .geohash import geo_encode, geo_decode +from .haversine import distance + +__all__ = [ + "geo_encode", + "geo_decode", + "distance", +] diff --git a/fakeredis/geo/geohash.py b/fakeredis/geo/geohash.py index c7a0d2be..a480ea01 100644 --- a/fakeredis/geo/geohash.py +++ b/fakeredis/geo/geohash.py @@ -7,13 +7,11 @@ decodemap = {base32[i]: i for i in range(len(base32))} -def decode(geohash: str) -> Tuple[float, float, float, float]: +def geo_decode(geohash: str) -> Tuple[float, float, float, float]: """ - Decode the geohash to its exact values, including the error - margins of the result. Returns four float values: latitude, - longitude, the plus/minus error for latitude (as a positive - number) and the plus/minus error for longitude (as a positive - number). + Decode the geohash to its exact values, including the error margins of the result. Returns four float values: + latitude, longitude, the plus/minus error for latitude (as a positive number) and the plus/minus error for longitude + (as a positive number). """ lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) lat_err, lon_err = 90.0, 180.0 @@ -51,10 +49,10 @@ def decode(geohash: str) -> Tuple[float, float, float, float]: return lat, lon, lat_err, lon_err -def encode(latitude: float, longitude: float, precision: int = 12) -> str: +def geo_encode(latitude: float, longitude: float, precision: int = 12) -> str: """ - Encode a position given in float arguments latitude, longitude to - a geohash which will have the character count precision. + Encode a position given in float arguments latitude, longitude to a geohash which will have the character count + precision. """ lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) geohash, bits = [], [16, 8, 4, 2, 1] # type: ignore diff --git a/fakeredis/geo/haversine.py b/fakeredis/geo/haversine.py index 6e5aed6b..b2096271 100644 --- a/fakeredis/geo/haversine.py +++ b/fakeredis/geo/haversine.py @@ -2,23 +2,6 @@ from typing import Tuple -# class GeoMember: -# def __init__(self, name: bytes, lat: float, long: float): -# self.name = name -# self.long = long -# self.lat = lat -# -# @staticmethod -# def from_bytes_tuple(t: Tuple[bytes, bytes, bytes]) -> 'GeoMember': -# long = Float.decode(t[0]) -# lat = Float.decode(t[1]) -# name = t[2] -# return GeoMember(name, lat, long) -# -# def geohash(self): -# return geohash.encode(self.lat, self.long) - - def distance(origin: Tuple[float, float], destination: Tuple[float, float]) -> float: """Calculate the Haversine distance in meters.""" radius = 6372797.560856 # Earth's quatratic mean radius for WGS-84 diff --git a/fakeredis/model/__init__.py b/fakeredis/model/__init__.py new file mode 100644 index 00000000..d7e1d45b --- /dev/null +++ b/fakeredis/model/__init__.py @@ -0,0 +1,16 @@ +from ._stream import XStream, StreamEntryKey, StreamGroup, StreamRangeTest +from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS +from ._topk import HeavyKeeper +from ._zset import ZSet + +__all__ = [ + "XStream", + "StreamRangeTest", + "StreamGroup", + "StreamEntryKey", + "ZSet", + "TimeSeries", + "TimeSeriesRule", + "AGGREGATORS", + "HeavyKeeper", +] diff --git a/fakeredis/_stream.py b/fakeredis/model/_stream.py similarity index 100% rename from fakeredis/_stream.py rename to fakeredis/model/_stream.py diff --git a/fakeredis/stack/_timeseries_model.py b/fakeredis/model/_timeseries_model.py similarity index 100% rename from fakeredis/stack/_timeseries_model.py rename to fakeredis/model/_timeseries_model.py diff --git a/fakeredis/model/_topk.py b/fakeredis/model/_topk.py new file mode 100644 index 00000000..024e9df7 --- /dev/null +++ b/fakeredis/model/_topk.py @@ -0,0 +1,103 @@ +import heapq +import random +import time +from typing import List, Optional, Tuple + + +class Bucket(object): + def __init__(self, counter: int, fingerprint: int): + self.counter = counter + self.fingerprint = fingerprint + + def add(self, fingerprint: int, incr: int, decay: float) -> int: + if self.fingerprint == fingerprint: + self.counter += incr + return self.counter + elif self._decay(decay): + self.counter += incr + self.fingerprint = fingerprint + return self.counter + return 0 + + def count(self, fingerprint: int) -> int: + if self.fingerprint == fingerprint: + return self.counter + return 0 + + def _decay(self, decay: float) -> bool: + if self.counter > 0: + probability = decay**self.counter + if probability >= 1 or random.random() < probability: + self.counter -= 1 + return self.counter == 0 + + +class HashArray(object): + def __init__(self, width: int, decay: float): + self.width = width + self.decay = decay + self.array = [Bucket(0, 0) for _ in range(width)] + self._seed = random.getrandbits(32) + + def count(self, item: bytes) -> int: + return self.get_bucket(item).count(self._hash(item)) + + def add(self, item: bytes, incr: int) -> int: + bucket = self.get_bucket(item) + return bucket.add(self._hash(item), incr, self.decay) + + def get_bucket(self, item: bytes) -> Bucket: + return self.array[self._hash(item) % self.width] + + def _hash(self, item: bytes) -> int: + return hash(item) ^ self._seed + + +class HeavyKeeper(object): + is_topk_initialized = False + + def __init__(self, k: int, width: int = 1024, depth: int = 5, decay: float = 0.9) -> None: + if not HeavyKeeper.is_topk_initialized: + random.seed(time.time()) + self.k = k + self.width = width + self.depth = depth + self.decay = decay + self.hash_arrays = [HashArray(width, decay) for _ in range(depth)] + self.min_heap: List[Tuple[int, bytes]] = list() + + def _index(self, val: bytes) -> int: + for ind, item in enumerate(self.min_heap): + if item[1] == val: + return ind + return -1 + + def add(self, item: bytes, incr: int) -> Optional[bytes]: + max_count = 0 + for i in range(self.depth): + count = self.hash_arrays[i].add(item, incr) + max_count = max(max_count, count) + if len(self.min_heap) < self.k: + heapq.heappush(self.min_heap, (max_count, item)) + return None + ind = self._index(item) + if ind >= 0: + self.min_heap[ind] = (max_count, item) + heapq.heapify(self.min_heap) + return None + if max_count > self.min_heap[0][0]: + expelled = heapq.heapreplace(self.min_heap, (max_count, item)) + return expelled[1] + return None + + def count(self, item: bytes) -> int: + ind = self._index(item) + if ind > 0: + return self.min_heap[ind][0] + return max([ha.count(item) for ha in self.hash_arrays]) + + def list(self, k: Optional[int] = None) -> List[Tuple[int, bytes]]: + sorted_list = sorted(self.min_heap, key=lambda x: x[0], reverse=True) + if k is None: + return sorted_list + return sorted_list[:k] diff --git a/fakeredis/_zset.py b/fakeredis/model/_zset.py similarity index 100% rename from fakeredis/_zset.py rename to fakeredis/model/_zset.py diff --git a/fakeredis/stack/_json_mixin.py b/fakeredis/stack/_json_mixin.py index cf41ff1c..6d4b38f6 100644 --- a/fakeredis/stack/_json_mixin.py +++ b/fakeredis/stack/_json_mixin.py @@ -14,7 +14,7 @@ from fakeredis._command_args_parsing import extract_args from fakeredis._commands import Key, command, delete_keys, CommandItem, Int, Float from fakeredis._helpers import SimpleString -from fakeredis._zset import ZSet +from fakeredis.model import ZSet JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] diff --git a/fakeredis/stack/_timeseries_mixin.py b/fakeredis/stack/_timeseries_mixin.py index 1444f803..1911db5c 100644 --- a/fakeredis/stack/_timeseries_mixin.py +++ b/fakeredis/stack/_timeseries_mixin.py @@ -5,7 +5,7 @@ from fakeredis._command_args_parsing import extract_args from fakeredis._commands import command, Key, CommandItem, Int, Float, Timestamp from fakeredis._helpers import Database, SimpleString, OK, SimpleError, casematch -from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS +from fakeredis.model import TimeSeries, TimeSeriesRule, AGGREGATORS class TimeSeriesCommandsMixin: # TimeSeries commands diff --git a/fakeredis/stack/_topk_mixin.py b/fakeredis/stack/_topk_mixin.py index f0ab44d5..17b9deb5 100644 --- a/fakeredis/stack/_topk_mixin.py +++ b/fakeredis/stack/_topk_mixin.py @@ -1,113 +1,10 @@ -"""Command mixin for emulating `redis-py`'s top-k functionality.""" - -import heapq -import random -import time from typing import Any, List, Optional, Tuple from fakeredis import _msgs as msgs from fakeredis._command_args_parsing import extract_args from fakeredis._commands import Key, Int, Float, command, CommandItem from fakeredis._helpers import OK, SimpleError, SimpleString - - -class Bucket(object): - def __init__(self, counter: int, fingerprint: int): - self.counter = counter - self.fingerprint = fingerprint - - def add(self, fingerprint: int, incr: int, decay: float) -> int: - if self.fingerprint == fingerprint: - self.counter += incr - return self.counter - elif self._decay(decay): - self.counter += incr - self.fingerprint = fingerprint - return self.counter - return 0 - - def count(self, fingerprint: int) -> int: - if self.fingerprint == fingerprint: - return self.counter - return 0 - - def _decay(self, decay: float) -> bool: - if self.counter > 0: - probability = decay**self.counter - if probability >= 1 or random.random() < probability: - self.counter -= 1 - return self.counter == 0 - - -class HashArray(object): - def __init__(self, width: int, decay: float): - self.width = width - self.decay = decay - self.array = [Bucket(0, 0) for _ in range(width)] - self._seed = random.getrandbits(32) - - def count(self, item: bytes) -> int: - return self.get_bucket(item).count(self._hash(item)) - - def add(self, item: bytes, incr: int) -> int: - bucket = self.get_bucket(item) - return bucket.add(self._hash(item), incr, self.decay) - - def get_bucket(self, item: bytes) -> Bucket: - return self.array[self._hash(item) % self.width] - - def _hash(self, item: bytes) -> int: - return hash(item) ^ self._seed - - -class HeavyKeeper(object): - is_topk_initialized = False - - def __init__(self, k: int, width: int = 1024, depth: int = 5, decay: float = 0.9) -> None: - if not HeavyKeeper.is_topk_initialized: - random.seed(time.time()) - self.k = k - self.width = width - self.depth = depth - self.decay = decay - self.hash_arrays = [HashArray(width, decay) for _ in range(depth)] - self.min_heap: List[Tuple[int, bytes]] = list() - - def _index(self, val: bytes) -> int: - for ind, item in enumerate(self.min_heap): - if item[1] == val: - return ind - return -1 - - def add(self, item: bytes, incr: int) -> Optional[bytes]: - max_count = 0 - for i in range(self.depth): - count = self.hash_arrays[i].add(item, incr) - max_count = max(max_count, count) - if len(self.min_heap) < self.k: - heapq.heappush(self.min_heap, (max_count, item)) - return None - ind = self._index(item) - if ind >= 0: - self.min_heap[ind] = (max_count, item) - heapq.heapify(self.min_heap) - return None - if max_count > self.min_heap[0][0]: - expelled = heapq.heapreplace(self.min_heap, (max_count, item)) - return expelled[1] - return None - - def count(self, item: bytes) -> int: - ind = self._index(item) - if ind > 0: - return self.min_heap[ind][0] - return max([ha.count(item) for ha in self.hash_arrays]) - - def list(self, k: Optional[int] = None) -> List[Tuple[int, bytes]]: - sorted_list = sorted(self.min_heap, key=lambda x: x[0], reverse=True) - if k is None: - return sorted_list - return sorted_list[:k] +from fakeredis.model import HeavyKeeper class TopkCommandsMixin: diff --git a/test/test_internals/test_xstream.py b/test/test_internals/test_xstream.py index 0af6524f..56ac1434 100644 --- a/test/test_internals/test_xstream.py +++ b/test/test_internals/test_xstream.py @@ -1,6 +1,6 @@ import pytest -from fakeredis._stream import XStream, StreamRangeTest +from fakeredis.model import XStream, StreamRangeTest @pytest.mark.fake