From f20db1479b9df1bb4c50bdff197a0941c929c35f Mon Sep 17 00:00:00 2001 From: Daniel M Date: Wed, 13 Dec 2023 14:45:58 +0100 Subject: [PATCH] mypy --- docs/about/changelog.md | 1 + fakeredis/_basefakesocket.py | 17 ++++++----- fakeredis/_commands.py | 14 ++++----- fakeredis/_fakesocket.py | 9 +++++- fakeredis/_server.py | 2 +- fakeredis/_stream.py | 4 +-- fakeredis/aioredis.py | 4 +-- fakeredis/commands_mixins/scripting_mixin.py | 20 ++++++------- fakeredis/commands_mixins/server_mixin.py | 31 +++++++++++--------- 9 files changed, 57 insertions(+), 45 deletions(-) diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 213a7f4a..f62158a2 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -14,6 +14,7 @@ description: Change log of all fakeredis releases ### 🧰 Maintenance - Testing for python 3.12 +- Dependencies update ## v2.20.0 diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index f068780f..9a0aef76 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -25,15 +25,15 @@ def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]: - """Extracts the command and command arguments from a list of bytes fields. + """Extracts the command and command arguments from a list of `bytes` fields. - :param fields: A list of bytes fields containing the command and command arguments. + :param fields: A list of `bytes` fields containing the command and command arguments. :return: A tuple of the command and command arguments. Example: fields = [b'GET', b'key1'] result = _extract_command(fields) - print(result) # ('GET', ['key1']) + print(result) # ('GET', ['key1']) """ cmd = encode_command(fields[0]) if cmd in COMMANDS_WITH_SUB and len(fields) >= 2: @@ -71,7 +71,8 @@ class BaseFakeSocket: def __init__(self, server, db, *args, **kwargs): super(BaseFakeSocket, self).__init__(*args, **kwargs) - self._server = server + from fakeredis import FakeServer + self._server: FakeServer = server self._db_num = db self._db = server.dbs[self._db_num] self.responses: Optional[queue.Queue] = queue.Queue() @@ -84,7 +85,7 @@ def __init__(self, server, db, *args, **kwargs): self.version = server.version def put_response(self, msg: Any) -> None: - """Put a response message into the responses queue. + """Put a response message into the queue of responses. :param msg: The response message. """ @@ -169,7 +170,7 @@ def _parse_commands(self): buf = buf[length + 2:] # +2 to skip the CRLF self._process_command(fields) - def _run_command(self, func: Callable[..., Any], sig: Signature, args: Tuple[Any], from_script: bool) -> Any: + def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any], from_script: bool) -> Any: command_items = {} try: ret = sig.apply(args, self._db, self.version) @@ -320,8 +321,8 @@ def _scan(self, keys, cursor, *args): it has the following drawbacks: - A given element may be returned multiple times. It is up to the application to handle the case of duplicated - elements, for example, only using the returned elements in order to perform operations that are safe when - re-applied multiple times. + elements, for example, only using the returned elements to perform operations that are safe when re-applied + multiple times. - Elements that were not constantly present in the collection during a full iteration may be returned or not: it is undefined. diff --git a/fakeredis/_commands.py b/fakeredis/_commands.py index 3cd68523..98342e35 100644 --- a/fakeredis/_commands.py +++ b/fakeredis/_commands.py @@ -5,7 +5,7 @@ import functools import math import re -from typing import Tuple, Union, Optional, Any, Type, List, Callable +from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence from . import _msgs as msgs from ._helpers import null_terminate, SimpleError, Database @@ -179,7 +179,7 @@ class Float(RedisType): Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT) and double for others (zset scores), but Python doesn't support - long double. + `long double`. """ DECODE_ERROR = msgs.INVALID_FLOAT_MSG @@ -194,7 +194,7 @@ def decode( crop_null: bool = False, decode_error: Optional[str] = None, ) -> float: - # redis has some quirks in float parsing, with several variants. + # Redis has some quirks in float parsing, with several variants. # See https://github.com/antirez/redis/issues/5706 try: if crop_null: @@ -209,7 +209,7 @@ def decode( if math.isnan(out): raise ValueError if not allow_erange: - # Values that over- or underflow- are explicitly rejected by + # Values that over- or under-flow are explicitly rejected by # redis. This is a crude hack to determine whether the input # may have been such a value. if out in (math.inf, -math.inf, 0.0) and re.match(b"^[^a-zA-Z]*[1-9]", value): @@ -223,7 +223,7 @@ def encode(cls, value: float, humanfriendly: bool) -> bytes: if math.isinf(value): return str(value).encode() elif humanfriendly: - # Algorithm from ld2string in redis + # Algorithm from `ld2string` in redis out = "{:.17f}".format(value) out = re.sub(r"\.?0+$", "", out) return out.encode() @@ -360,7 +360,7 @@ def __init__( self.flags = set(flags) self.command_args = args - def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None: + def check_arity(self, args: Sequence[Any], version: Tuple[int]) -> None: if len(args) == len(self.fixed): return delta = len(args) - len(self.fixed) @@ -376,7 +376,7 @@ def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None: raise SimpleError(msg) def apply( - self, args: Tuple[Any], db: Database, version: Tuple[int] + self, args: Sequence[Any], db: Database, version: Tuple[int] ) -> Union[Tuple[Any], Tuple[List[Any], List[CommandItem]]]: """Returns a tuple, which is either: - transformed args and a dict of CommandItems; or diff --git a/fakeredis/_fakesocket.py b/fakeredis/_fakesocket.py index 8bd4f708..18a58733 100644 --- a/fakeredis/_fakesocket.py +++ b/fakeredis/_fakesocket.py @@ -7,7 +7,14 @@ from .commands_mixins.hash_mixin import HashCommandsMixin from .commands_mixins.list_mixin import ListCommandsMixin from .commands_mixins.pubsub_mixin import PubSubCommandsMixin -from .commands_mixins.scripting_mixin import ScriptingCommandsMixin + +try: + from lupa import LuaRuntime + from .commands_mixins.scripting_mixin import ScriptingCommandsMixin +except ImportError: + class ScriptingCommandsMixin: + pass + from .commands_mixins.server_mixin import ServerCommandsMixin from .commands_mixins.set_mixin import SetCommandsMixin from .commands_mixins.sortedset_mixin import SortedSetCommandsMixin diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 5f949a3b..0b5f54d2 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -39,7 +39,7 @@ def __init__(self, version: Tuple[int] = (7,)): self.subscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet) self.psubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet) self.ssubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet) - self.lastsave = int(time.time()) + self.lastsave: int = int(time.time()) self.connected = True # List of weakrefs to sockets that are being closed lazily self.closed_sockets: List[Any] = [] diff --git a/fakeredis/_stream.py b/fakeredis/_stream.py index 86a3c769..e96fdfc5 100644 --- a/fakeredis/_stream.py +++ b/fakeredis/_stream.py @@ -380,7 +380,7 @@ def delete(self, lst: List[Union[str, bytes]]) -> int: def add(self, fields: List, entry_key: str = "*") -> Union[None, bytes]: """Add entry to a stream. - If the entry_key can not be added (because its timestamp is before the last entry, etc.), + If the entry_key cannot be added (because its timestamp is before the last entry, etc.), nothing is added. :param fields: List of fields to add, must [key1, value1, key2, value2, ... ] @@ -493,7 +493,7 @@ def trim( """Trim a stream :param max_length: Max length of the resulting stream after trimming (number of last values to keep) - :param start_entry_key: Min entry-key to keep, can not be given together with max_length. + :param start_entry_key: Min entry-key to keep, cannot be given together with max_length. :param limit: Number of entries to keep from minid. :returns: The resulting stream after trimming. :raises ValueError: When both max_length and start_entry_key are passed. diff --git a/fakeredis/aioredis.py b/fakeredis/aioredis.py index d8a2a4ce..868ea386 100644 --- a/fakeredis/aioredis.py +++ b/fakeredis/aioredis.py @@ -115,7 +115,7 @@ async def disconnect(self, **kwargs): await super().disconnect(**kwargs) self._sock = None - async def can_read(self, timeout: float = 0): + async def can_read(self, timeout: Optional[float] = 0): if not self.is_connected: await self.connect() if timeout == 0: @@ -154,7 +154,7 @@ async def read_response(self, **kwargs): await self.disconnect() raise redis_async.ConnectionError(msgs.CONNECTION_ERROR_MSG) else: - timeout = kwargs.pop("timeout", None) + timeout: Optional[float] = kwargs.pop("timeout", None) can_read = await self.can_read(timeout) response = await self._reader.read(0) if can_read else None if isinstance(response, redis_async.ResponseError): diff --git a/fakeredis/commands_mixins/scripting_mixin.py b/fakeredis/commands_mixins/scripting_mixin.py index 3990105b..17296015 100644 --- a/fakeredis/commands_mixins/scripting_mixin.py +++ b/fakeredis/commands_mixins/scripting_mixin.py @@ -3,7 +3,7 @@ import itertools import logging from typing import Tuple, Callable, AnyStr, Set, Any - +from lupa import LuaRuntime from fakeredis import _msgs as msgs from fakeredis._commands import command, Int from fakeredis._helpers import ( @@ -29,15 +29,15 @@ } -def _ensure_str(s: AnyStr, encoding: str, replaceerr: str): +def _ensure_str(s: AnyStr, encoding: str, replaceerr: str) -> str: if isinstance(s, bytes): res = s.decode(encoding=encoding, errors=replaceerr) else: - res = str(s).encode(encoding=encoding, errors=replaceerr) + res = str(s) return res -def _check_for_lua_globals(lua_runtime, expected_globals): +def _check_for_lua_globals(lua_runtime: LuaRuntime, expected_globals: Set[Any]) -> None: unexpected_globals = set(lua_runtime.globals().keys()) - expected_globals if len(unexpected_globals) > 0: unexpected = [ @@ -46,7 +46,7 @@ def _check_for_lua_globals(lua_runtime, expected_globals): raise SimpleError(msgs.GLOBAL_VARIABLE_MSG.format(", ".join(unexpected))) -def _lua_redis_log(lua_runtime, expected_globals, lvl, *args): +def _lua_redis_log(lua_runtime: LuaRuntime, expected_globals: Set[Any], lvl, *args) -> None: _check_for_lua_globals(lua_runtime, expected_globals) if len(args) < 1: raise SimpleError(msgs.REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two")) @@ -72,7 +72,7 @@ def __init__(self, *args, **kwargs): # Maps SHA1 to the script source self.script_cache = {} - def _convert_redis_arg(self, lua_runtime, value): + def _convert_redis_arg(self, lua_runtime: LuaRuntime, value): # Type checks are exact to avoid issues like bool being a subclass of int. if type(value) is bytes: return value @@ -87,7 +87,7 @@ def _convert_redis_arg(self, lua_runtime, value): ) raise SimpleError(msg) - def _convert_redis_result(self, lua_runtime, result): + def _convert_redis_result(self, lua_runtime: LuaRuntime, result): if isinstance(result, (bytes, int)): return result elif isinstance(result, SimpleString): @@ -139,7 +139,7 @@ def _convert_lua_result(self, result, nested=True): return 1 if result else None return result - def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): + def _lua_redis_call(self, lua_runtime: LuaRuntime, expected_globals: Set[Any], op, *args): # Check if we've set any global variables before making any change. _check_for_lua_globals(lua_runtime, expected_globals) func, sig = self._name_to_func(encode_command(op)) @@ -147,7 +147,7 @@ def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): result = self._run_command(func, sig, new_args, True) return self._convert_redis_result(lua_runtime, result) - def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args): + def _lua_redis_pcall(self, lua_runtime: LuaRuntime, expected_globals: Set[Any], op, *args): try: return self._lua_redis_call(lua_runtime, expected_globals, op, *args) except Exception as ex: @@ -163,7 +163,7 @@ def eval(self, script, numkeys, *keys_and_args): raise SimpleError(msgs.NEGATIVE_KEYS_MSG) sha1 = hashlib.sha1(script).hexdigest().encode() self.script_cache[sha1] = script - lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True) + lua_runtime: LuaRuntime = LuaRuntime(encoding=None, unpack_returned_tuples=True) set_globals = lua_runtime.eval( """ diff --git a/fakeredis/commands_mixins/server_mixin.py b/fakeredis/commands_mixins/server_mixin.py index 1ef1f91a..dd9e7f1d 100644 --- a/fakeredis/commands_mixins/server_mixin.py +++ b/fakeredis/commands_mixins/server_mixin.py @@ -5,7 +5,7 @@ from fakeredis import _msgs as msgs from fakeredis._commands import command, DbIndex -from fakeredis._helpers import OK, SimpleError, casematch, BGSAVE_STARTED, Database +from fakeredis._helpers import OK, SimpleError, casematch, BGSAVE_STARTED, Database, SimpleString _COMMAND_INFO: Optional[Dict[bytes, List[Any]]] = None @@ -28,8 +28,11 @@ def _load_command_info() -> None: class ServerCommandsMixin: - _server: Any - _db: Database + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + from fakeredis._server import FakeServer + self._server: "FakeServer" + self._db: Database @staticmethod def _get_command_info(cmd: bytes) -> Optional[List[Any]]: @@ -39,25 +42,25 @@ def _get_command_info(cmd: bytes) -> Optional[List[Any]]: return _COMMAND_INFO.get(cmd, None) @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT) - def bgsave(self, *args): + def bgsave(self, *args: bytes) -> SimpleString: if len(args) > 1 or (len(args) == 1 and not casematch(args[0], b"schedule")): raise SimpleError(msgs.SYNTAX_ERROR_MSG) self._server.lastsave = int(time.time()) return BGSAVE_STARTED @command(()) - def dbsize(self): + def dbsize(self) -> int: return len(self._db) @command((), (bytes,)) - def flushdb(self, *args): + def flushdb(self, *args: bytes) -> SimpleString: if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b"async")): raise SimpleError(msgs.SYNTAX_ERROR_MSG) self._db.clear() return OK @command((), (bytes,)) - def flushall(self, *args): + def flushall(self, *args: bytes) -> SimpleString: if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b"async")): raise SimpleError(msgs.SYNTAX_ERROR_MSG) for db in self._server.dbs.values(): @@ -66,23 +69,23 @@ def flushall(self, *args): return OK @command(()) - def lastsave(self): + def lastsave(self) -> int: return self._server.lastsave @command((), flags=msgs.FLAG_NO_SCRIPT) - def save(self): + def save(self) -> SimpleString: self._server.lastsave = int(time.time()) return OK @command(()) - def time(self): + def time(self) -> List[bytes]: now_us = round(time.time() * 1_000_000) now_s = now_us // 1_000_000 now_us %= 1_000_000 return [str(now_s).encode(), str(now_us).encode()] @command((DbIndex, DbIndex)) - def swapdb(self, index1, index2): + def swapdb(self, index1: int, index2: int) -> SimpleString: if index1 != index2: db1 = self._server.dbs[index1] db2 = self._server.dbs[index2] @@ -90,17 +93,17 @@ def swapdb(self, index1, index2): return OK @command(name="COMMAND INFO", fixed=(), repeat=(bytes,)) - def command_info(self, *commands): + def command_info(self, *commands: bytes) -> List[Any]: res = [self._get_command_info(cmd) for cmd in commands] return res @command(name="COMMAND COUNT", fixed=(), repeat=()) - def command_count(self): + def command_count(self) -> int: _load_command_info() return len(_COMMAND_INFO) if _COMMAND_INFO is not None else 0 @command(name="COMMAND", fixed=(), repeat=()) - def command_(self): + def command_(self) -> List[Any]: _load_command_info() if _COMMAND_INFO is None: return []