Skip to content

Commit

Permalink
fix:improve json_mixin code
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Aug 7, 2023
1 parent 46607cf commit 4cf433f
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions fakeredis/stack/_json_mixin.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
"""Command mixin for emulating `redis-py`'s JSON functionality."""

# Future Imports
from __future__ import annotations

import copy
# Standard Library Imports
import json
from json import JSONDecodeError
from typing import Any, Union, Dict
from typing import Any, Union, Dict, List, Optional

from jsonpath_ng import Root, JSONPath
from jsonpath_ng.exceptions import JsonPathParserError
from jsonpath_ng.ext import parse
from redis.commands.json.commands import JsonType

from fakeredis import _helpers as helpers, _msgs as msgs
from fakeredis import _helpers as helpers
from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, delete_keys, CommandItem, Int, Float
from fakeredis._helpers import SimpleError, casematch
from fakeredis._zset import ZSet

JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


def _format_path(path) -> str:
if isinstance(path, bytes):
Expand All @@ -39,7 +36,7 @@ def _parse_jsonpath(path: Union[str, bytes]):
try:
return parse(path)
except JsonPathParserError:
raise SimpleError(msgs.JSON_PATH_DOES_NOT_EXIST.format(path))
raise helpers.SimpleError(msgs.JSON_PATH_DOES_NOT_EXIST.format(path))


def _path_is_root(path: JSONPath) -> bool:
Expand Down Expand Up @@ -73,7 +70,7 @@ def decode(cls, value: bytes) -> Any:
try:
return json.loads(value)
except JSONDecodeError:
raise SimpleError(cls.DECODE_ERROR)
raise helpers.SimpleError(cls.DECODE_ERROR)

@classmethod
def encode(cls, value: Any) -> bytes:
Expand All @@ -87,11 +84,11 @@ def _json_write_iterate(method, key, path_str, **kwargs):
Iterate over values with path_str in key and running method to get new value for path item.
"""
if key.value is None:
raise SimpleError(msgs.JSON_KEY_NOT_FOUND)
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
path = _parse_jsonpath(path_str)
found_matches = path.find(key.value)
if len(found_matches) == 0:
raise SimpleError(msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str))
raise helpers.SimpleError(msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str))

curr_value = copy.deepcopy(key.value)
res = list()
Expand All @@ -117,14 +114,14 @@ def _json_read_iterate(method, key, *args, error_on_zero_matches=False):
path_str = args[0] if len(args) > 0 else '$'
if key.value is None:
if path_str[0] == 36:
raise SimpleError(msgs.JSON_KEY_NOT_FOUND)
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
else:
return None

path = _parse_jsonpath(path_str)
found_matches = path.find(key.value)
if error_on_zero_matches and len(found_matches) == 0 and path_str[0] != 36:
raise SimpleError(msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str))
raise helpers.SimpleError(msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str))
res = list()
for item in found_matches:
res.append(method(item.value))
Expand Down Expand Up @@ -159,6 +156,8 @@ class JSONCommandsMixin:
ZSet: 'zset'
}

_db: helpers.Database

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -199,11 +198,11 @@ def json_del(self, key, path_str) -> int:
def _json_set(key: CommandItem, path_str: bytes, value: JsonType, *args):
path = _parse_jsonpath(path_str)
if key.value is not None and (type(key.value) is not dict) and not _path_is_root(path):
raise SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
old_value = path.find(key.value)
(nx, xx), _ = extract_args(args, ('nx', 'xx'))
if xx and nx:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
raise helpers.SimpleError(msgs.SYNTAX_ERROR_MSG)
if (nx and old_value) or (xx and not old_value):
return None
new_value = path.update_or_create(key.value, value)
Expand All @@ -219,15 +218,15 @@ def json_set(self, key, path_str: bytes, value: JsonType, *args):
return JSONCommandsMixin._json_set(key, path_str, value, *args)

@command(name="JSON.GET", fixed=(Key(),), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_get(self, key, *args) -> bytes:
def json_get(self, key, *args) -> Optional[bytes]:
if key.value is None:
return None
paths = [arg for arg in args if not casematch(b'noescape', arg)]
paths = [arg for arg in args if not helpers.casematch(b'noescape', arg)]
no_wrapping_array = (len(paths) == 1 and paths[0][0] == ord(b'.'))

formatted_paths = [
_format_path(arg) for arg in args
if not casematch(b'noescape', arg)
if not helpers.casematch(b'noescape', arg)
]
path_values = [self._get_single(key, path, len(formatted_paths) > 1) for path in formatted_paths]

Expand All @@ -244,7 +243,7 @@ def json_get(self, key, *args) -> bytes:
@command(name="JSON.MGET", fixed=(bytes,), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_mget(self, *args):
if len(args) < 2:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.mget'))
raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format('json.mget'))
path_str = args[-1]
keys = [CommandItem(key, self._db, item=self._db.get(key), default=[])
for key in args[:-1]]
Expand All @@ -255,7 +254,7 @@ def json_mget(self, *args):
@command(name="JSON.TOGGLE", fixed=(Key(),), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_toggle(self, key, *args):
if key.value is None:
raise SimpleError(msgs.JSON_KEY_NOT_FOUND)
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
path_str = args[0] if len(args) > 0 else '$'
path = _parse_jsonpath(path_str)
found_matches = path.find(key.value)
Expand All @@ -269,7 +268,7 @@ def json_toggle(self, key, *args):
else:
res.append(None)
if all([x is None for x in res]):
raise SimpleError(msgs.JSON_KEY_NOT_FOUND)
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
key.update(curr_value)

if len(res) == 1 and (len(args) == 0 or (len(args) == 1 and args[0] == b'.')):
Expand All @@ -280,7 +279,7 @@ def json_toggle(self, key, *args):
@command(name="JSON.CLEAR", fixed=(Key(),), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_clear(self, key, *args, ):
if key.value is None:
raise SimpleError(msgs.JSON_KEY_NOT_FOUND)
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
path_str = args[0] if len(args) > 0 else '$'
path = _parse_jsonpath(path_str)
found_matches = path.find(key.value)
Expand All @@ -298,7 +297,7 @@ def json_clear(self, key, *args, ):
@command(name="JSON.STRAPPEND", fixed=(Key(), bytes), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_strappend(self, key, path_str, *args):
if len(args) == 0:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.strappend'))
raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format('json.strappend'))
addition = JSONObject.decode(args[0])

def strappend(val):
Expand All @@ -313,7 +312,7 @@ def strappend(val):
@command(name="JSON.ARRAPPEND", fixed=(Key(), bytes,), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_arrappend(self, key, path_str, *args):
if len(args) == 0:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.arrappend'))
raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format('json.arrappend'))

addition = [JSONObject.decode(item) for item in args]

Expand All @@ -329,7 +328,7 @@ def arrappend(val):
@command(name="JSON.ARRINSERT", fixed=(Key(), bytes, Int), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_arrinsert(self, key, path_str, index, *args):
if len(args) == 0:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.arrinsert'))
raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format('json.arrinsert'))

addition = [JSONObject.decode(item) for item in args]

Expand Down Expand Up @@ -377,7 +376,7 @@ def arrtrim(val):
return _json_write_iterate(arrtrim, key, path_str)

@command(name="JSON.NUMINCRBY", fixed=(Key(), bytes, Float), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_numincrby(self, key, path_str, inc_by, *args):
def json_numincrby(self, key, path_str, inc_by, *_):

def numincrby(val):
if type(val) in {int, float}:
Expand All @@ -389,7 +388,7 @@ def numincrby(val):
return _json_write_iterate(numincrby, key, path_str)

@command(name="JSON.NUMMULTBY", fixed=(Key(), bytes, Float), repeat=(bytes,), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_nummultby(self, key, path_str, mult_by, *args):
def json_nummultby(self, key, path_str, mult_by, *_):

def nummultby(val):
if type(val) in {int, float}:
Expand Down Expand Up @@ -449,7 +448,7 @@ def json_objkeys(self, key, *args):
@command(name="JSON.MSET", fixed=(), repeat=(Key(), bytes, JSONObject), flags=msgs.FLAG_LEAVE_EMPTY_VAL)
def json_mset(self, *args):
if len(args) < 3 or len(args) % 3 != 0:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('json.mset'))
raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format('json.mset'))
for i in range(0, len(args), 3):
key, path_str, value = args[i], args[i + 1], args[i + 2]
JSONCommandsMixin._json_set(key, path_str, value)
Expand All @@ -459,7 +458,7 @@ def json_mset(self, *args):
def json_merge(self, key, path_str: bytes, value: JsonType):
path: JSONPath = _parse_jsonpath(path_str)
if key.value is not None and (type(key.value) is not dict) and not _path_is_root(path):
raise SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
matching = path.find(key.value)
for item in matching:
prev_value = item.value if item is not None else dict()
Expand Down

0 comments on commit 4cf433f

Please sign in to comment.