Skip to content

Commit

Permalink
Separate LUA support to a different file (#55)
Browse files Browse the repository at this point in the history
* Separate LUA support to a different file
  • Loading branch information
cunla authored Oct 15, 2022
1 parent cc8c2e7 commit ac471ca
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 174 deletions.
167 changes: 167 additions & 0 deletions fakeredis/_basefakeluasupport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import functools
import hashlib
import itertools

from fakeredis._helpers import SimpleError, SimpleString
from . import _msgs as msgs
from ._commands import command, Int
from ._helpers import REDIS_LOG_LEVELS, REDIS_LOG_LEVELS_TO_LOGGING, LOGGER


def _ensure_str(s, encoding, replaceerr):
if isinstance(s, bytes):
res = s.decode(encoding=encoding, errors=replaceerr)
else:
res = str(s).encode(encoding=encoding, errors=replaceerr)
return res


def _check_for_lua_globals(lua_runtime, expected_globals):
unexpected_globals = set(lua_runtime.globals().keys()) - expected_globals
if len(unexpected_globals) > 0:
unexpected = [_ensure_str(var, 'utf-8', 'replace') for var in unexpected_globals]
raise SimpleError(msgs.GLOBAL_VARIABLE_MSG.format(", ".join(unexpected)))


def _lua_redis_log(lua_runtime, expected_globals, lvl, *args):
_check_for_lua_globals(lua_runtime, expected_globals)
if len(args) < 1:
raise SimpleError(msgs.REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two"))
if lvl not in REDIS_LOG_LEVELS.values():
raise SimpleError(msgs.LOG_INVALID_DEBUG_LEVEL_MSG)
msg = ' '.join([x.decode('utf-8')
if isinstance(x, bytes) else str(x)
for x in args if not isinstance(x, bool)])
LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], msg)


class BaseFakeLuaSocket:

# Script commands
# script debug and script kill will probably not be supported

def _convert_redis_arg(self, lua_runtime, value):
# Type checks are exact to avoid issues like bool being a subclass of int.
if type(value) is bytes:
return value
elif type(value) in {int, float}:
return '{:.17g}'.format(value).encode()
else:
# TODO: add the context
msg = msgs.LUA_COMMAND_ARG_MSG6 if self.version < 7 else msgs.LUA_COMMAND_ARG_MSG
raise SimpleError(msg)

def _convert_redis_result(self, lua_runtime, result):
if isinstance(result, (bytes, int)):
return result
elif isinstance(result, SimpleString):
return lua_runtime.table_from({b"ok": result.value})
elif result is None:
return False
elif isinstance(result, list):
converted = [
self._convert_redis_result(lua_runtime, item)
for item in result
]
return lua_runtime.table_from(converted)
elif isinstance(result, SimpleError):
raise result
else:
raise RuntimeError("Unexpected return type from redis: {}".format(type(result)))

def _convert_lua_result(self, result, nested=True):
from lupa import lua_type
if lua_type(result) == 'table':
for key in (b'ok', b'err'):
if key in result:
msg = self._convert_lua_result(result[key])
if not isinstance(msg, bytes):
raise SimpleError(msgs.LUA_WRONG_NUMBER_ARGS_MSG)
if key == b'ok':
return SimpleString(msg)
elif nested:
return SimpleError(msg.decode('utf-8', 'replace'))
else:
raise SimpleError(msg.decode('utf-8', 'replace'))
# Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis.
result_list = []
for index in itertools.count(1):
if index not in result:
break
item = result[index]
result_list.append(self._convert_lua_result(item))
return result_list
elif isinstance(result, str):
return result.encode()
elif isinstance(result, float):
return int(result)
elif isinstance(result, bool):
return 1 if result else None
return result

def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
# Check if we've set any global variables before making any change.
_check_for_lua_globals(lua_runtime, expected_globals)
func, func_name = self._name_to_func(op)
args = [self._convert_redis_arg(lua_runtime, arg) for arg in args]
result = self._run_command(func, func._fakeredis_sig, args, True)
return self._convert_redis_result(lua_runtime, result)

def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args):
try:
return self._lua_redis_call(lua_runtime, expected_globals, op, *args)
except Exception as ex:
return lua_runtime.table_from({b"err": str(ex)})

@command((bytes, Int), (bytes,), flags='s')
def eval(self, script, numkeys, *keys_and_args):
from lupa import LuaError, LuaRuntime, as_attrgetter

if numkeys > len(keys_and_args):
raise SimpleError(msgs.TOO_MANY_KEYS_MSG)
if numkeys < 0:
raise SimpleError(msgs.NEGATIVE_KEYS_MSG)
sha1 = hashlib.sha1(script).hexdigest().encode()
self._server.script_cache[sha1] = script
lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True)

set_globals = lua_runtime.eval(
"""
function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels)
redis = {}
redis.call = redis_call
redis.pcall = redis_pcall
redis.log = redis_log
for level, pylevel in python.iterex(redis_log_levels.items()) do
redis[level] = pylevel
end
redis.error_reply = function(msg) return {err=msg} end
redis.status_reply = function(msg) return {ok=msg} end
KEYS = keys
ARGV = argv
end
"""
)
expected_globals = set()
set_globals(
lua_runtime.table_from(keys_and_args[:numkeys]),
lua_runtime.table_from(keys_and_args[numkeys:]),
functools.partial(self._lua_redis_call, lua_runtime, expected_globals),
functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals),
functools.partial(_lua_redis_log, lua_runtime, expected_globals),
as_attrgetter(REDIS_LOG_LEVELS)
)
expected_globals.update(lua_runtime.globals().keys())

try:
result = lua_runtime.execute(script)
except SimpleError as ex:
if self.version == 6:
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))
raise SimpleError(ex.value)
except LuaError as ex:
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))

_check_for_lua_globals(lua_runtime, expected_globals)

return self._convert_lua_result(result, nested=False)
158 changes: 3 additions & 155 deletions fakeredis/_fakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
import redis

from . import _msgs as msgs
from ._basefakeluasupport import BaseFakeLuaSocket
from ._basefakesocket import BaseFakeSocket
from ._commands import (
Key, command, DbIndex, Int, CommandItem, BeforeAny, SortFloat, Float, BitOffset, BitValue, Hash,
StringTest, ScoreTest, Timeout)
from ._helpers import (
PONG, OK, MAX_STRING_SIZE, SimpleError, SimpleString, casematch,
BGSAVE_STARTED, REDIS_LOG_LEVELS_TO_LOGGING, LOGGER, REDIS_LOG_LEVELS, casenorm, compile_pattern)
from ._msgs import LUA_COMMAND_ARG_MSG, LUA_COMMAND_ARG_MSG6
BGSAVE_STARTED, casenorm, compile_pattern)
from ._zset import ZSet


class FakeSocket(BaseFakeSocket):
class FakeSocket(BaseFakeSocket, BaseFakeLuaSocket):
_connection_error_class = redis.ConnectionError

def __init__(self, server):
Expand Down Expand Up @@ -1455,158 +1455,6 @@ def time(self):
now_us %= 1000000
return [str(now_s).encode(), str(now_us).encode()]

# Script commands
# script debug and script kill will probably not be supported

def _convert_redis_arg(self, lua_runtime, value):
# Type checks are exact to avoid issues like bool being a subclass of int.
if type(value) is bytes:
return value
elif type(value) in {int, float}:
return '{:.17g}'.format(value).encode()
else:
# TODO: add the context
msg = LUA_COMMAND_ARG_MSG6 if self.version < 7 else LUA_COMMAND_ARG_MSG
raise SimpleError(msg)

def _convert_redis_result(self, lua_runtime, result):
if isinstance(result, (bytes, int)):
return result
elif isinstance(result, SimpleString):
return lua_runtime.table_from({b"ok": result.value})
elif result is None:
return False
elif isinstance(result, list):
converted = [
self._convert_redis_result(lua_runtime, item)
for item in result
]
return lua_runtime.table_from(converted)
elif isinstance(result, SimpleError):
raise result
else:
raise RuntimeError("Unexpected return type from redis: {}".format(type(result)))

def _convert_lua_result(self, result, nested=True):
from lupa import lua_type
if lua_type(result) == 'table':
for key in (b'ok', b'err'):
if key in result:
msg = self._convert_lua_result(result[key])
if not isinstance(msg, bytes):
raise SimpleError(msgs.LUA_WRONG_NUMBER_ARGS_MSG)
if key == b'ok':
return SimpleString(msg)
elif nested:
return SimpleError(msg.decode('utf-8', 'replace'))
else:
raise SimpleError(msg.decode('utf-8', 'replace'))
# Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis.
result_list = []
for index in itertools.count(1):
if index not in result:
break
item = result[index]
result_list.append(self._convert_lua_result(item))
return result_list
elif isinstance(result, str):
return result.encode()
elif isinstance(result, float):
return int(result)
elif isinstance(result, bool):
return 1 if result else None
return result

def ensure_str(self, s):
return (s.decode(encoding='utf-8', errors='replace')
if isinstance(s, bytes)
else str(s).encode(encoding='utf-8', errors='replace'))

def _check_for_lua_globals(self, lua_runtime, expected_globals):
actual_globals = set(lua_runtime.globals().keys())
if actual_globals != expected_globals:
unexpected = [self.ensure_str(var, 'utf-8', 'replace')
for var in actual_globals - expected_globals]
raise SimpleError(msgs.GLOBAL_VARIABLE_MSG.format(", ".join(unexpected)))

def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
# Check if we've set any global variables before making any change.
self._check_for_lua_globals(lua_runtime, expected_globals)
func, func_name = self._name_to_func(op)
args = [self._convert_redis_arg(lua_runtime, arg) for arg in args]
result = self._run_command(func, func._fakeredis_sig, args, True)
return self._convert_redis_result(lua_runtime, result)

def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args):
try:
return self._lua_redis_call(lua_runtime, expected_globals, op, *args)
except Exception as ex:
return lua_runtime.table_from({b"err": str(ex)})

def _lua_redis_log(self, lua_runtime, expected_globals, lvl, *args):
self._check_for_lua_globals(lua_runtime, expected_globals)
if len(args) < 1:
raise SimpleError(msgs.REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two"))
if lvl not in REDIS_LOG_LEVELS.values():
raise SimpleError(msgs.LOG_INVALID_DEBUG_LEVEL_MSG)
msg = ' '.join([x.decode('utf-8')
if isinstance(x, bytes) else str(x)
for x in args if not isinstance(x, bool)])
LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], msg)

@command((bytes, Int), (bytes,), flags='s')
def eval(self, script, numkeys, *keys_and_args):
from lupa import LuaError, LuaRuntime, as_attrgetter

if numkeys > len(keys_and_args):
raise SimpleError(msgs.TOO_MANY_KEYS_MSG)
if numkeys < 0:
raise SimpleError(msgs.NEGATIVE_KEYS_MSG)
sha1 = hashlib.sha1(script).hexdigest().encode()
self._server.script_cache[sha1] = script
lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True)

set_globals = lua_runtime.eval(
"""
function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels)
redis = {}
redis.call = redis_call
redis.pcall = redis_pcall
redis.log = redis_log
for level, pylevel in python.iterex(redis_log_levels.items()) do
redis[level] = pylevel
end
redis.error_reply = function(msg) return {err=msg} end
redis.status_reply = function(msg) return {ok=msg} end
KEYS = keys
ARGV = argv
end
"""
)
expected_globals = set()
set_globals(
lua_runtime.table_from(keys_and_args[:numkeys]),
lua_runtime.table_from(keys_and_args[numkeys:]),
functools.partial(self._lua_redis_call, lua_runtime, expected_globals),
functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals),
functools.partial(self._lua_redis_log, lua_runtime, expected_globals),
as_attrgetter(REDIS_LOG_LEVELS)
)
expected_globals.update(lua_runtime.globals().keys())

try:
result = lua_runtime.execute(script)
except SimpleError as ex:
if self.version == 6:
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))
raise SimpleError(ex.value)
except LuaError as ex:
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))

self._check_for_lua_globals(lua_runtime, expected_globals)

return self._convert_lua_result(result, nested=False)

@command((bytes, Int), (bytes,), flags='s')
def evalsha(self, sha1, numkeys, *keys_and_args):
try:
Expand Down
Loading

0 comments on commit ac471ca

Please sign in to comment.