Commit
refactor models (#339)
cunla authored Oct 10, 2024
1 parent fe71bd8 commit a00e698
Showing 17 changed files with 152 additions and 147 deletions.
4 changes: 2 additions & 2 deletions fakeredis/_basefakesocket.py
@@ -8,6 +8,8 @@
import redis
from redis.connection import DefaultParser

from fakeredis.model import XStream
from fakeredis.model import ZSet
from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash
@@ -21,8 +23,6 @@
QUEUED,
decode_command_bytes,
)
from ._stream import XStream
from ._zset import ZSet


def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
2 changes: 1 addition & 1 deletion fakeredis/commands_mixins/generic_mixin.py
@@ -17,7 +17,7 @@
Hash,
)
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database, SimpleString
from fakeredis._zset import ZSet
from fakeredis.model import ZSet


class GenericCommandsMixin:
18 changes: 9 additions & 9 deletions fakeredis/commands_mixins/geo_mixin.py
@@ -6,9 +6,9 @@
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Float, CommandItem
from fakeredis._helpers import SimpleError, Database
from fakeredis._zset import ZSet
from fakeredis.geo import geohash
from fakeredis.geo.haversine import distance
from fakeredis.model import ZSet
from fakeredis.geo import distance, geo_encode, geo_decode


UNIT_TO_M = {"km": 0.001, "mi": 0.000621371, "ft": 3.28084, "m": 1}

@@ -71,7 +71,7 @@ def _find_near(
"""
results = list()
for name, _hash in zset.items():
p_lat, p_long, _, _ = geohash.decode(_hash)
p_lat, p_long, _, _ = geo_decode(_hash)
dist = distance((p_lat, p_long), (lat, long)) * conv
if dist < radius:
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
@@ -120,7 +120,7 @@ def geoadd(self, key: CommandItem, *args: bytes) -> int:
data[i + 2],
)
if (name in zset and not xx) or (name not in zset and not nx):
if zset.add(name, geohash.encode(lat, long, 10)):
if zset.add(name, geo_encode(lat, long, 10)):
changed_items += 1
if changed_items:
key.updated()
@@ -137,7 +137,7 @@ def geohash(self, key: CommandItem, *members: bytes) -> List[bytes]:
@command(name="GEOPOS", fixed=(Key(ZSet), bytes), repeat=(bytes,))
def geopos(self, key: CommandItem, *members: bytes) -> List[Optional[List[bytes]]]:
gospositions = map(
lambda x: geohash.decode(x) if x is not None else x,
lambda x: geo_decode(x) if x is not None else x,
map(key.value.get, members),
)
res = [
@@ -158,7 +158,7 @@ def geodist(self, key: CommandItem, m1: bytes, m2: bytes, *args: bytes) -> Optio
geohashes = [key.value.get(m1), key.value.get(m2)]
if any(elem is None for elem in geohashes):
return None
geo_locs = [geohash.decode(x) for x in geohashes]
geo_locs = [geo_decode(x) for x in geohashes]
res = distance((geo_locs[0][0], geo_locs[0][1]), (geo_locs[1][0], geo_locs[1][1]))
unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1
return res * unit
@@ -236,13 +236,13 @@ def georadius(self, key: CommandItem, long: float, lat: float, radius: float, *a
@command(name="GEORADIUSBYMEMBER", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
def georadiusbymember(self, key: CommandItem, member_name: bytes, radius: float, *args: bytes):
member_score = key.value.get(member_name)
lat, long, _, _ = geohash.decode(member_score)
lat, long, _, _ = geo_decode(member_score)
return self.georadius(key, long, lat, radius, *args)

@command(name="GEORADIUSBYMEMBER_RO", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
def georadiusbymember_ro(self, key: CommandItem, member_name: bytes, radius: float, *args: float) -> List[Any]:
member_score = key.value.get(member_name)
lat, long, _, _ = geohash.decode(member_score)
lat, long, _, _ = geo_decode(member_score)
return self.georadius_ro(key, long, lat, radius, *args)

@command(name="GEOSEARCH", fixed=(Key(ZSet),), repeat=(bytes,))
2 changes: 1 addition & 1 deletion fakeredis/commands_mixins/sortedset_mixin.py
@@ -26,7 +26,7 @@
null_terminate,
Database,
)
from fakeredis._zset import ZSet
from fakeredis.model import ZSet

SORTED_SET_METHODS = {
"ZUNIONSTORE": lambda s1, s2: s1 | s2,
2 changes: 1 addition & 1 deletion fakeredis/commands_mixins/streams_mixin.py
@@ -5,7 +5,7 @@
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, CommandItem, Int
from fakeredis._helpers import SimpleError, casematch, OK, current_time, Database, SimpleString
from fakeredis._stream import XStream, StreamRangeTest, StreamGroup, StreamEntryKey
from fakeredis.model import XStream, StreamRangeTest, StreamGroup, StreamEntryKey


class StreamsCommandsMixin:
8 changes: 8 additions & 0 deletions fakeredis/geo/__init__.py
@@ -0,0 +1,8 @@
from .geohash import geo_encode, geo_decode
from .haversine import distance

__all__ = [
"geo_encode",
"geo_decode",
"distance",
]
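
If the refactor keeps the signatures shown in geohash.py and haversine.py, the new package-level re-exports can be exercised roughly as below. This is a sketch, not part of the commit; the coordinates are only illustrative.

# Sketch: round-trip a coordinate through the re-exported helpers (not from the commit).
from fakeredis.geo import geo_encode, geo_decode, distance

h = geo_encode(38.115556, 13.361389, 10)        # encode (lat, long) to a 10-char geohash
lat, lon, lat_err, lon_err = geo_decode(h)      # decode back, with error margins
meters = distance((38.115556, 13.361389), (37.502669, 15.087269))  # haversine, in meters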
16 changes: 7 additions & 9 deletions fakeredis/geo/geohash.py
@@ -7,13 +7,11 @@
decodemap = {base32[i]: i for i in range(len(base32))}


def decode(geohash: str) -> Tuple[float, float, float, float]:
def geo_decode(geohash: str) -> Tuple[float, float, float, float]:
"""
Decode the geohash to its exact values, including the error
margins of the result. Returns four float values: latitude,
longitude, the plus/minus error for latitude (as a positive
number) and the plus/minus error for longitude (as a positive
number).
Decode the geohash to its exact values, including the error margins of the result. Returns four float values:
latitude, longitude, the plus/minus error for latitude (as a positive number) and the plus/minus error for longitude
(as a positive number).
"""
lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0)
lat_err, lon_err = 90.0, 180.0
@@ -51,10 +49,10 @@ def decode(geohash: str) -> Tuple[float, float, float, float]:
return lat, lon, lat_err, lon_err


def encode(latitude: float, longitude: float, precision: int = 12) -> str:
def geo_encode(latitude: float, longitude: float, precision: int = 12) -> str:
"""
Encode a position given in float arguments latitude, longitude to
a geohash which will have the character count precision.
Encode a position given in float arguments latitude, longitude to a geohash which will have the character count
precision.
"""
lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0)
geohash, bits = [], [16, 8, 4, 2, 1] # type: ignore
17 changes: 0 additions & 17 deletions fakeredis/geo/haversine.py
@@ -2,23 +2,6 @@
from typing import Tuple


# class GeoMember:
# def __init__(self, name: bytes, lat: float, long: float):
# self.name = name
# self.long = long
# self.lat = lat
#
# @staticmethod
# def from_bytes_tuple(t: Tuple[bytes, bytes, bytes]) -> 'GeoMember':
# long = Float.decode(t[0])
# lat = Float.decode(t[1])
# name = t[2]
# return GeoMember(name, lat, long)
#
# def geohash(self):
# return geohash.encode(self.lat, self.long)


def distance(origin: Tuple[float, float], destination: Tuple[float, float]) -> float:
"""Calculate the Haversine distance in meters."""
radius = 6372797.560856 # Earth's quadratic mean radius for WGS-84
16 changes: 16 additions & 0 deletions fakeredis/model/__init__.py
@@ -0,0 +1,16 @@
from ._stream import XStream, StreamEntryKey, StreamGroup, StreamRangeTest
from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS
from ._topk import HeavyKeeper
from ._zset import ZSet

__all__ = [
"XStream",
"StreamRangeTest",
"StreamGroup",
"StreamEntryKey",
"ZSet",
"TimeSeries",
"TimeSeriesRule",
"AGGREGATORS",
"HeavyKeeper",
]
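
With the new fakeredis.model package, callers import the data-structure models from one place instead of the private _zset / _stream modules. A minimal sketch of the pattern, assuming the re-exports above; the member name and coordinates are made up.

# Sketch only: consolidated imports after the refactor; values are illustrative.
from fakeredis.model import ZSet
from fakeredis.geo import geo_encode, geo_decode

zset = ZSet()
zset.add(b"Palermo", geo_encode(38.115556, 13.361389, 10))  # same pattern GEOADD uses above
if b"Palermo" in zset:
    lat, lon, _, _ = geo_decode(zset.get(b"Palermo"))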
File renamed without changes.
File renamed without changes.
103 changes: 103 additions & 0 deletions fakeredis/model/_topk.py
@@ -0,0 +1,103 @@
import heapq
import random
import time
from typing import List, Optional, Tuple


class Bucket(object):
def __init__(self, counter: int, fingerprint: int):
self.counter = counter
self.fingerprint = fingerprint

def add(self, fingerprint: int, incr: int, decay: float) -> int:
if self.fingerprint == fingerprint:
self.counter += incr
return self.counter
elif self._decay(decay):
self.counter += incr
self.fingerprint = fingerprint
return self.counter
return 0

def count(self, fingerprint: int) -> int:
if self.fingerprint == fingerprint:
return self.counter
return 0

def _decay(self, decay: float) -> bool:
if self.counter > 0:
probability = decay**self.counter
if probability >= 1 or random.random() < probability:
self.counter -= 1
return self.counter == 0


class HashArray(object):
def __init__(self, width: int, decay: float):
self.width = width
self.decay = decay
self.array = [Bucket(0, 0) for _ in range(width)]
self._seed = random.getrandbits(32)

def count(self, item: bytes) -> int:
return self.get_bucket(item).count(self._hash(item))

def add(self, item: bytes, incr: int) -> int:
bucket = self.get_bucket(item)
return bucket.add(self._hash(item), incr, self.decay)

def get_bucket(self, item: bytes) -> Bucket:
return self.array[self._hash(item) % self.width]

def _hash(self, item: bytes) -> int:
return hash(item) ^ self._seed


class HeavyKeeper(object):
is_topk_initialized = False

def __init__(self, k: int, width: int = 1024, depth: int = 5, decay: float = 0.9) -> None:
if not HeavyKeeper.is_topk_initialized:
random.seed(time.time())
self.k = k
self.width = width
self.depth = depth
self.decay = decay
self.hash_arrays = [HashArray(width, decay) for _ in range(depth)]
self.min_heap: List[Tuple[int, bytes]] = list()

def _index(self, val: bytes) -> int:
for ind, item in enumerate(self.min_heap):
if item[1] == val:
return ind
return -1

def add(self, item: bytes, incr: int) -> Optional[bytes]:
max_count = 0
for i in range(self.depth):
count = self.hash_arrays[i].add(item, incr)
max_count = max(max_count, count)
if len(self.min_heap) < self.k:
heapq.heappush(self.min_heap, (max_count, item))
return None
ind = self._index(item)
if ind >= 0:
self.min_heap[ind] = (max_count, item)
heapq.heapify(self.min_heap)
return None
if max_count > self.min_heap[0][0]:
expelled = heapq.heapreplace(self.min_heap, (max_count, item))
return expelled[1]
return None

def count(self, item: bytes) -> int:
ind = self._index(item)
if ind > 0:
return self.min_heap[ind][0]
return max([ha.count(item) for ha in self.hash_arrays])

def list(self, k: Optional[int] = None) -> List[Tuple[int, bytes]]:
sorted_list = sorted(self.min_heap, key=lambda x: x[0], reverse=True)
if k is None:
return sorted_list
return sorted_list[:k]
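
The new HeavyKeeper is self-contained, so it can be driven directly; a rough usage sketch follows (the item stream is made up).

# Rough usage sketch for the HeavyKeeper class above; item counts are illustrative.
hk = HeavyKeeper(k=3)
for item, n in [(b"a", 100), (b"b", 50), (b"c", 20), (b"d", 5)]:
    for _ in range(n):
        hk.add(item, 1)           # returns the member expelled from the top-k heap, if any
print(hk.list())                  # [(count, item), ...] sorted by count, highest first
print(hk.count(b"a"))             # estimated count for a single item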
File renamed without changes.
2 changes: 1 addition & 1 deletion fakeredis/stack/_json_mixin.py
@@ -14,7 +14,7 @@
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, delete_keys, CommandItem, Int, Float
from fakeredis._helpers import SimpleString
from fakeredis._zset import ZSet
from fakeredis.model import ZSet

JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]

2 changes: 1 addition & 1 deletion fakeredis/stack/_timeseries_mixin.py
@@ -5,7 +5,7 @@
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, CommandItem, Int, Float, Timestamp
from fakeredis._helpers import Database, SimpleString, OK, SimpleError, casematch
from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS
from fakeredis.model import TimeSeries, TimeSeriesRule, AGGREGATORS


class TimeSeriesCommandsMixin: # TimeSeries commands