Skip to content

Commit

Permalink
Implement json.mget (#85)
Browse files Browse the repository at this point in the history
* Implement json.mget

Resolves #81
  • Loading branch information
cunla authored Nov 28, 2022
1 parent 2c13b7b commit 97303fc
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 90 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ jobs:
- name: Install lupa
if: ${{ matrix.lupa }}
run: |
poetry run pip install fakeredis[lupa]
poetry run pip install fakeredis[lua]
- name: Install json
if: ${{ matrix.json }}
run: |
poetry run pip install fakeredis[jsonpath]
poetry run pip install fakeredis[json]
- name: Get version
id: getVersion
shell: bash
Expand Down
5 changes: 1 addition & 4 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
from ._commands import (Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB)
from ._helpers import (
SimpleError, valid_response_type, SimpleString, NoResponse, casematch,
compile_pattern, QUEUED)
compile_pattern, QUEUED, encode_command)


def _extract_command(fields):
def encode_command(s):
return s.decode(encoding='utf-8', errors='replace').lower()

cmd = encode_command(fields[0])
if cmd in COMMANDS_WITH_SUB and len(fields) >= 2:
cmd += ' ' + encode_command(fields[1])
Expand Down
4 changes: 4 additions & 0 deletions fakeredis/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def casematch(a, b):
return casenorm(a) == casenorm(b)


def encode_command(s):
return s.decode(encoding='utf-8', errors='replace').lower()


def compile_pattern(pattern):
"""Compile a glob pattern (e.g. for keys) to a bytes regex.
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/commands_mixins/scripting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fakeredis import _msgs as msgs
from fakeredis._commands import command, Int
from fakeredis._helpers import SimpleError, SimpleString, casenorm, OK
from fakeredis._helpers import SimpleError, SimpleString, casenorm, OK, encode_command

LOGGER = logging.getLogger('fakeredis')
REDIS_LOG_LEVELS = {
Expand Down Expand Up @@ -121,7 +121,7 @@ def _convert_lua_result(self, result, nested=True):
def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
# Check if we've set any global variables before making any change.
_check_for_lua_globals(lua_runtime, expected_globals)
func, sig = self._name_to_func(self._encode_command(op))
func, sig = self._name_to_func(encode_command(op))
args = [self._convert_redis_arg(lua_runtime, arg) for arg in args]
result = self._run_command(func, sig, args, True)
return self._convert_redis_result(lua_runtime, result)
Expand Down
35 changes: 19 additions & 16 deletions fakeredis/stack/_json_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from redis.commands.json.commands import JsonType

from fakeredis import _helpers as helpers, _msgs as msgs
from fakeredis._commands import Key, command, delete_keys
from fakeredis._commands import Key, command, delete_keys, CommandItem
from fakeredis._helpers import SimpleError, casematch

path_pattern: re.Pattern = re.compile(r"^((?<!\$)\.|(\$\.$))")
Expand Down Expand Up @@ -87,10 +87,15 @@ def json_del(self, key, path_str) -> int:
return len(found_matches)

@staticmethod
def _get_single(key, path_str):
def _get_single(key, path_str: str, always_return_list: bool = False, empty_list_as_none: bool = False):
path = _parse_jsonpath(path_str)
res = path.find(key.value)
return res
path_value = path.find(key.value)
val = [i.value for i in path_value]
if empty_list_as_none and len(val) == 0:
val = None
elif len(val) == 1 and not always_return_list:
val = val[0]
return val

@command(name="JSON.GET", fixed=(Key(),), repeat=(bytes,), )
def json_get(self, key, *args) -> bytes:
Expand All @@ -107,17 +112,7 @@ def json_get(self, key, *args) -> bytes:
for arg in args
if not casematch(b'noescape', arg)
]
resolved_paths = [self._get_single(key, path) for path in formatted_paths]

path_values = list()
for lst in resolved_paths:
if len(lst) == 0:
val = []
elif len(lst) > 1 or len(resolved_paths) > 1:
val = [i.value for i in lst]
else:
val = lst[0].value
path_values.append(val)
path_values = [self._get_single(key, path, len(formatted_paths) > 1) for path in formatted_paths]

# Emulate the behavior of `redis-py`:
# - if only one path was supplied => return a single value
Expand Down Expand Up @@ -169,5 +164,13 @@ def json_set(

return helpers.OK

@command(name="JSON.MGET", fixed=(bytes,), repeat=(bytes,), )
def json_mget(self, *args):
pass
if len(args) < 2:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.mget'))
path_str = args[-1]
keys = [CommandItem(key, self._db, item=self._db.get(key), default=[])
for key in args[:-1]]

result = [self._get_single(key, path_str, empty_list_as_none=True) for key in keys]
return result
29 changes: 29 additions & 0 deletions test/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,32 @@ def test_json_set_flags_should_be_mutually_exclusive(r: redis.Redis):
def test_json_unknown_param(r: redis.Redis):
with pytest.raises(redis.ResponseError):
raw_command(r, 'json.set', 'obj', '$', json.dumps({"foo": "bar"}), 'unknown')


def test_mget(r: redis.Redis):
# Test mget with multi paths
r.json().set("doc1", "$", {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}})
r.json().set("doc2", "$", {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}})
# Compare also to single JSON.GET
assert r.json().get("doc1", Path("$..a")) == [1, 3, None]
assert r.json().get("doc2", "$..a") == [4, 6, [None]]

# Test mget with single path
assert r.json().mget(["doc1"], "$..a") == [[1, 3, None]]

# Test mget with multi path
assert r.json().mget(["doc1", "doc2"], "$..a") == [[1, 3, None], [4, 6, [None]]]

# Test missing key
assert r.json().mget(["doc1", "missing_doc"], "$..a") == [[1, 3, None], None]

assert r.json().mget(["missing_doc1", "missing_doc2"], "$..a") == [None, None]


def test_mget_should_succeed(r: redis.Redis) -> None:
r.json().set("1", Path.root_path(), 1)
r.json().set("2", Path.root_path(), 2)

assert r.json().mget(["1"], Path.root_path()) == [1]

assert r.json().mget([1, 2], Path.root_path()) == [1, 2]
70 changes: 4 additions & 66 deletions test/json/test_json_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,30 +105,6 @@ def test_nonascii_setgetdelete(r: redis.Redis) -> None:
assert r.exists("not-ascii") == 0


@pytest.mark.xfail
def test_mget_should_succeed(r: redis.Redis) -> None:
r.json().set(
"1",
Path.root_path(),
1,
)
r.json().set(
"2",
Path.root_path(),
2,
)

assert r.json().mget(
["1"],
Path.root_path(),
) == [1]

assert r.json().mget(
[1, 2],
Path.root_path(),
) == [1, 2]


@pytest.mark.xfail
def test_clear(r: redis.Redis) -> None:
r.json().set(
Expand All @@ -154,27 +130,11 @@ def test_type(r: redis.Redis) -> None:

@pytest.mark.xfail
def test_numincrby(r: redis.Redis) -> None:
r.json().set(
"num",
Path.root_path(),
1,
)
r.json().set("num", Path.root_path(), 1)

assert 2 == r.json().numincrby(
"num",
Path.root_path(),
1,
)
assert 2.5 == r.json().numincrby(
"num",
Path.root_path(),
0.5,
)
assert 1.25 == r.json().numincrby(
"num",
Path.root_path(),
-1.25,
)
assert 2 == r.json().numincrby("num", Path.root_path(), 1)
assert 2.5 == r.json().numincrby("num", Path.root_path(), 0.5)
assert 1.25 == r.json().numincrby("num", Path.root_path(), -1.25)


@pytest.mark.xfail
Expand Down Expand Up @@ -641,28 +601,6 @@ def test_json_forget_with_dollar(r: redis.Redis) -> None:
r.json().forget("not_a_document", "..a")


@pytest.mark.xfail
def test_json_mget_dollar(r: redis.Redis) -> None:
# Test mget with multi paths
r.json().set("doc1", "$", {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}})
r.json().set("doc2", "$", {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}})
# Compare also to single JSON.GET
assert r.json().get("doc1", "$..a") == [1, 3, None]
assert r.json().get("doc2", "$..a") == [4, 6, [None]]

# Test mget with single path
assert r.json().mget("doc1", "$..a") == [1, 3, None]

# Test mget with multi path
assert r.json().mget(["doc1", "doc2"], "$..a") == [[1, 3, None], [4, 6, [None]]]

# Test missing key
assert r.json().mget(["doc1", "missing_doc"], "$..a") == [[1, 3, None], None]
res = r.json().mget(["missing_doc1", "missing_doc2"], "$..a")

assert res == [None, None]


@pytest.mark.xfail
def test_numby_commands_dollar(r: redis.Redis) -> None:
# Test NUMINCRBY
Expand Down

0 comments on commit 97303fc

Please sign in to comment.