Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Dec 13, 2023
1 parent 6a5ae53 commit f20db14
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/about/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ description: Change log of all fakeredis releases
### 🧰 Maintenance

- Testing for python 3.12
- Dependencies update

## v2.20.0

Expand Down
17 changes: 9 additions & 8 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@


def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
"""Extracts the command and command arguments from a list of bytes fields.
"""Extracts the command and command arguments from a list of `bytes` fields.
:param fields: A list of bytes fields containing the command and command arguments.
:param fields: A list of `bytes` fields containing the command and command arguments.
:return: A tuple of the command and command arguments.
Example:
fields = [b'GET', b'key1']
result = _extract_command(fields)
print(result) # ('GET', ['key1'])
print(result) # ('GET', ['key1'])
"""
cmd = encode_command(fields[0])
if cmd in COMMANDS_WITH_SUB and len(fields) >= 2:
Expand Down Expand Up @@ -71,7 +71,8 @@ class BaseFakeSocket:

def __init__(self, server, db, *args, **kwargs):
super(BaseFakeSocket, self).__init__(*args, **kwargs)
self._server = server
from fakeredis import FakeServer
self._server: FakeServer = server
self._db_num = db
self._db = server.dbs[self._db_num]
self.responses: Optional[queue.Queue] = queue.Queue()
Expand All @@ -84,7 +85,7 @@ def __init__(self, server, db, *args, **kwargs):
self.version = server.version

def put_response(self, msg: Any) -> None:
"""Put a response message into the responses queue.
"""Put a response message into the queue of responses.
:param msg: The response message.
"""
Expand Down Expand Up @@ -169,7 +170,7 @@ def _parse_commands(self):
buf = buf[length + 2:] # +2 to skip the CRLF
self._process_command(fields)

def _run_command(self, func: Callable[..., Any], sig: Signature, args: Tuple[Any], from_script: bool) -> Any:
def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any], from_script: bool) -> Any:
command_items = {}
try:
ret = sig.apply(args, self._db, self.version)
Expand Down Expand Up @@ -320,8 +321,8 @@ def _scan(self, keys, cursor, *args):
it has the following drawbacks:
- A given element may be returned multiple times. It is up to the application to handle the case of duplicated
elements, for example, only using the returned elements in order to perform operations that are safe when
re-applied multiple times.
elements, for example, only using the returned elements to perform operations that are safe when re-applied
multiple times.
- Elements that were not constantly present in the collection during a full iteration may be returned or not:
it is undefined.
Expand Down
14 changes: 7 additions & 7 deletions fakeredis/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import functools
import math
import re
from typing import Tuple, Union, Optional, Any, Type, List, Callable
from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError, Database
Expand Down Expand Up @@ -179,7 +179,7 @@ class Float(RedisType):
Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT)
and double for others (zset scores), but Python doesn't support
long double.
`long double`.
"""

DECODE_ERROR = msgs.INVALID_FLOAT_MSG
Expand All @@ -194,7 +194,7 @@ def decode(
crop_null: bool = False,
decode_error: Optional[str] = None,
) -> float:
# redis has some quirks in float parsing, with several variants.
# Redis has some quirks in float parsing, with several variants.
# See https://github.com/antirez/redis/issues/5706
try:
if crop_null:
Expand All @@ -209,7 +209,7 @@ def decode(
if math.isnan(out):
raise ValueError
if not allow_erange:
# Values that over- or underflow- are explicitly rejected by
# Values that over- or under-flow are explicitly rejected by
# redis. This is a crude hack to determine whether the input
# may have been such a value.
if out in (math.inf, -math.inf, 0.0) and re.match(b"^[^a-zA-Z]*[1-9]", value):
Expand All @@ -223,7 +223,7 @@ def encode(cls, value: float, humanfriendly: bool) -> bytes:
if math.isinf(value):
return str(value).encode()
elif humanfriendly:
# Algorithm from ld2string in redis
# Algorithm from `ld2string` in redis
out = "{:.17f}".format(value)
out = re.sub(r"\.?0+$", "", out)
return out.encode()
Expand Down Expand Up @@ -360,7 +360,7 @@ def __init__(
self.flags = set(flags)
self.command_args = args

def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None:
def check_arity(self, args: Sequence[Any], version: Tuple[int]) -> None:
if len(args) == len(self.fixed):
return
delta = len(args) - len(self.fixed)
Expand All @@ -376,7 +376,7 @@ def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None:
raise SimpleError(msg)

def apply(
self, args: Tuple[Any], db: Database, version: Tuple[int]
self, args: Sequence[Any], db: Database, version: Tuple[int]
) -> Union[Tuple[Any], Tuple[List[Any], List[CommandItem]]]:
"""Returns a tuple, which is either:
- transformed args and a dict of CommandItems; or
Expand Down
9 changes: 8 additions & 1 deletion fakeredis/_fakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from .commands_mixins.hash_mixin import HashCommandsMixin
from .commands_mixins.list_mixin import ListCommandsMixin
from .commands_mixins.pubsub_mixin import PubSubCommandsMixin
from .commands_mixins.scripting_mixin import ScriptingCommandsMixin

try:
from lupa import LuaRuntime
from .commands_mixins.scripting_mixin import ScriptingCommandsMixin
except ImportError:
class ScriptingCommandsMixin:
pass

from .commands_mixins.server_mixin import ServerCommandsMixin
from .commands_mixins.set_mixin import SetCommandsMixin
from .commands_mixins.sortedset_mixin import SortedSetCommandsMixin
Expand Down
2 changes: 1 addition & 1 deletion fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, version: Tuple[int] = (7,)):
self.subscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
self.psubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
self.ssubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
self.lastsave = int(time.time())
self.lastsave: int = int(time.time())
self.connected = True
# List of weakrefs to sockets that are being closed lazily
self.closed_sockets: List[Any] = []
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def delete(self, lst: List[Union[str, bytes]]) -> int:
def add(self, fields: List, entry_key: str = "*") -> Union[None, bytes]:
"""Add entry to a stream.
If the entry_key can not be added (because its timestamp is before the last entry, etc.),
If the entry_key cannot be added (because its timestamp is before the last entry, etc.),
nothing is added.
:param fields: List of fields to add, must [key1, value1, key2, value2, ... ]
Expand Down Expand Up @@ -493,7 +493,7 @@ def trim(
"""Trim a stream
:param max_length: Max length of the resulting stream after trimming (number of last values to keep)
:param start_entry_key: Min entry-key to keep, can not be given together with max_length.
:param start_entry_key: Min entry-key to keep, cannot be given together with max_length.
:param limit: Number of entries to keep from minid.
:returns: The resulting stream after trimming.
:raises ValueError: When both max_length and start_entry_key are passed.
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/aioredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def disconnect(self, **kwargs):
await super().disconnect(**kwargs)
self._sock = None

async def can_read(self, timeout: float = 0):
async def can_read(self, timeout: Optional[float] = 0):
if not self.is_connected:
await self.connect()
if timeout == 0:
Expand Down Expand Up @@ -154,7 +154,7 @@ async def read_response(self, **kwargs):
await self.disconnect()
raise redis_async.ConnectionError(msgs.CONNECTION_ERROR_MSG)
else:
timeout = kwargs.pop("timeout", None)
timeout: Optional[float] = kwargs.pop("timeout", None)
can_read = await self.can_read(timeout)
response = await self._reader.read(0) if can_read else None
if isinstance(response, redis_async.ResponseError):
Expand Down
20 changes: 10 additions & 10 deletions fakeredis/commands_mixins/scripting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import logging
from typing import Tuple, Callable, AnyStr, Set, Any

from lupa import LuaRuntime
from fakeredis import _msgs as msgs
from fakeredis._commands import command, Int
from fakeredis._helpers import (
Expand All @@ -29,15 +29,15 @@
}


def _ensure_str(s: AnyStr, encoding: str, replaceerr: str):
def _ensure_str(s: AnyStr, encoding: str, replaceerr: str) -> str:
if isinstance(s, bytes):
res = s.decode(encoding=encoding, errors=replaceerr)
else:
res = str(s).encode(encoding=encoding, errors=replaceerr)
res = str(s)
return res


def _check_for_lua_globals(lua_runtime, expected_globals):
def _check_for_lua_globals(lua_runtime: LuaRuntime, expected_globals: Set[Any]) -> None:
unexpected_globals = set(lua_runtime.globals().keys()) - expected_globals
if len(unexpected_globals) > 0:
unexpected = [
Expand All @@ -46,7 +46,7 @@ def _check_for_lua_globals(lua_runtime, expected_globals):
raise SimpleError(msgs.GLOBAL_VARIABLE_MSG.format(", ".join(unexpected)))


def _lua_redis_log(lua_runtime, expected_globals, lvl, *args):
def _lua_redis_log(lua_runtime: LuaRuntime, expected_globals: Set[Any], lvl, *args) -> None:
_check_for_lua_globals(lua_runtime, expected_globals)
if len(args) < 1:
raise SimpleError(msgs.REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two"))
Expand All @@ -72,7 +72,7 @@ def __init__(self, *args, **kwargs):
# Maps SHA1 to the script source
self.script_cache = {}

def _convert_redis_arg(self, lua_runtime, value):
def _convert_redis_arg(self, lua_runtime: LuaRuntime, value):
# Type checks are exact to avoid issues like bool being a subclass of int.
if type(value) is bytes:
return value
Expand All @@ -87,7 +87,7 @@ def _convert_redis_arg(self, lua_runtime, value):
)
raise SimpleError(msg)

def _convert_redis_result(self, lua_runtime, result):
def _convert_redis_result(self, lua_runtime: LuaRuntime, result):
if isinstance(result, (bytes, int)):
return result
elif isinstance(result, SimpleString):
Expand Down Expand Up @@ -139,15 +139,15 @@ def _convert_lua_result(self, result, nested=True):
return 1 if result else None
return result

def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
def _lua_redis_call(self, lua_runtime: LuaRuntime, expected_globals: Set[Any], op, *args):
# Check if we've set any global variables before making any change.
_check_for_lua_globals(lua_runtime, expected_globals)
func, sig = self._name_to_func(encode_command(op))
new_args = [self._convert_redis_arg(lua_runtime, arg) for arg in args]
result = self._run_command(func, sig, new_args, True)
return self._convert_redis_result(lua_runtime, result)

def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args):
def _lua_redis_pcall(self, lua_runtime: LuaRuntime, expected_globals: Set[Any], op, *args):
try:
return self._lua_redis_call(lua_runtime, expected_globals, op, *args)
except Exception as ex:
Expand All @@ -163,7 +163,7 @@ def eval(self, script, numkeys, *keys_and_args):
raise SimpleError(msgs.NEGATIVE_KEYS_MSG)
sha1 = hashlib.sha1(script).hexdigest().encode()
self.script_cache[sha1] = script
lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True)
lua_runtime: LuaRuntime = LuaRuntime(encoding=None, unpack_returned_tuples=True)

set_globals = lua_runtime.eval(
"""
Expand Down
31 changes: 17 additions & 14 deletions fakeredis/commands_mixins/server_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fakeredis import _msgs as msgs
from fakeredis._commands import command, DbIndex
from fakeredis._helpers import OK, SimpleError, casematch, BGSAVE_STARTED, Database
from fakeredis._helpers import OK, SimpleError, casematch, BGSAVE_STARTED, Database, SimpleString

_COMMAND_INFO: Optional[Dict[bytes, List[Any]]] = None

Expand All @@ -28,8 +28,11 @@ def _load_command_info() -> None:


class ServerCommandsMixin:
_server: Any
_db: Database
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
from fakeredis._server import FakeServer
self._server: "FakeServer"
self._db: Database

@staticmethod
def _get_command_info(cmd: bytes) -> Optional[List[Any]]:
Expand All @@ -39,25 +42,25 @@ def _get_command_info(cmd: bytes) -> Optional[List[Any]]:
return _COMMAND_INFO.get(cmd, None)

@command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def bgsave(self, *args):
def bgsave(self, *args: bytes) -> SimpleString:
if len(args) > 1 or (len(args) == 1 and not casematch(args[0], b"schedule")):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
self._server.lastsave = int(time.time())
return BGSAVE_STARTED

@command(())
def dbsize(self):
def dbsize(self) -> int:
return len(self._db)

@command((), (bytes,))
def flushdb(self, *args):
def flushdb(self, *args: bytes) -> SimpleString:
if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b"async")):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
self._db.clear()
return OK

@command((), (bytes,))
def flushall(self, *args):
def flushall(self, *args: bytes) -> SimpleString:
if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b"async")):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
for db in self._server.dbs.values():
Expand All @@ -66,41 +69,41 @@ def flushall(self, *args):
return OK

@command(())
def lastsave(self):
def lastsave(self) -> int:
return self._server.lastsave

@command((), flags=msgs.FLAG_NO_SCRIPT)
def save(self):
def save(self) -> SimpleString:
self._server.lastsave = int(time.time())
return OK

@command(())
def time(self):
def time(self) -> List[bytes]:
now_us = round(time.time() * 1_000_000)
now_s = now_us // 1_000_000
now_us %= 1_000_000
return [str(now_s).encode(), str(now_us).encode()]

@command((DbIndex, DbIndex))
def swapdb(self, index1, index2):
def swapdb(self, index1: int, index2: int) -> SimpleString:
if index1 != index2:
db1 = self._server.dbs[index1]
db2 = self._server.dbs[index2]
db1.swap(db2)
return OK

@command(name="COMMAND INFO", fixed=(), repeat=(bytes,))
def command_info(self, *commands):
def command_info(self, *commands: bytes) -> List[Any]:
res = [self._get_command_info(cmd) for cmd in commands]
return res

@command(name="COMMAND COUNT", fixed=(), repeat=())
def command_count(self):
def command_count(self) -> int:
_load_command_info()
return len(_COMMAND_INFO) if _COMMAND_INFO is not None else 0

@command(name="COMMAND", fixed=(), repeat=())
def command_(self):
def command_(self) -> List[Any]:
_load_command_info()
if _COMMAND_INFO is None:
return []
Expand Down

0 comments on commit f20db14

Please sign in to comment.