Skip to content

Commit

Permalink
support for server type specific commands (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla authored Oct 19, 2024
1 parent 018ecb6 commit 375c1ef
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 44 deletions.
7 changes: 6 additions & 1 deletion docs/about/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ tags:
toc_depth: 2
---

## v2.25.2
## v2.26.0

### 🚀 Features

- Support for server-type specific commands #340
- Support for Dragonfly `SADDEX` command #340

### 🐛 Bug Fixes

Expand Down
15 changes: 9 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
from test.test_hypothesis import server_typefrom test.test_hypothesis import server_type---
toc:
toc_depth: 3
toc_depth: 3
---

fakeredis: A python implementation of redis server
Expand Down Expand Up @@ -46,7 +46,7 @@ from threading import Thread
from fakeredis import TcpFakeServer

server_address = ("127.0.0.1", 6379)
server = TcpFakeServer(server_address)
server = TcpFakeServer(server_address, server_type="redis")
t = Thread(target=server.serve_forever, daemon=True)
t.start()

Expand All @@ -73,14 +73,15 @@ def redis_client(request):

### General usage

FakeRedis can imitate Redis server version 6.x or 7.x. Version 7 is used by default.
FakeRedis can imitate Redis server version 6.x or 7.x, [Valkey server](./valkey-support),
and [dragonfly server][dragonfly]. Redis version 7 is used by default.

The intent is for fakeredis to act as though you're talking to a real redis server.
It does this by storing the state internally. For example:

```pycon
>>> import fakeredis
>>> r = fakeredis.FakeStrictRedis(version=6)
>>> r = fakeredis.FakeStrictRedis(server_type="redis")
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
Expand Down Expand Up @@ -391,4 +392,6 @@ You can support this project by becoming a sponsor using [this link][2].

[8]:https://github.com/jazzband/django-redis

[9]:https://docs.djangoproject.com/en/4.1/topics/testing/tools/#django.test.override_settings
[9]:https://docs.djangoproject.com/en/4.1/topics/testing/tools/#django.test.override_settings

[dragonfly]:https://www.dragonflydb.io/
19 changes: 19 additions & 0 deletions docs/valkey-support.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,23 @@ valkey.set("key", "value")
print(valkey.get("key"))
```

Alternatively, you can start a thread with a Fake Valkey server.

```python
from threading import Thread
from fakeredis import TcpFakeServer

server_address = ("127.0.0.1", 6379)
server = TcpFakeServer(server_address, server_type="valkey")
t = Thread(target=server.serve_forever, daemon=True)
t.start()

import valkey

r = valkey.Valkey(host=server_address[0], port=server_address[1])
r.set("foo", "bar")
assert r.get("foo") == b"bar"

```

[1]: https://github.com/valkey-io/valkey
4 changes: 4 additions & 0 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable[[Any], Any]],
clean_name = cmd_name.replace("\r", " ").replace("\n", " ")
raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name))
sig = SUPPORTED_COMMANDS[cmd_name]
if self._server.server_type not in sig.server_types:
# redis remaps \r or \n in an error to ' ' to make it legal protocol
clean_name = cmd_name.replace("\r", " ").replace("\n", " ")
raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name))
func = getattr(self, sig.func_name, None)
return func, sig

Expand Down
2 changes: 2 additions & 0 deletions fakeredis/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,15 @@ def __init__(
repeat: Tuple[Type[Union[RedisType, bytes]]] = (), # type:ignore
args: Tuple[str] = (), # type:ignore
flags: str = "",
server_types: Tuple[str] = ("redis", "valkey", "dragonfly"), # supported server types: redis, dragonfly, valkey
):
self.name = name
self.func_name = func_name
self.fixed = fixed
self.repeat = repeat
self.flags = set(flags)
self.command_args = args
self.server_types: Set[str] = set(server_types)

def check_arity(self, args: Sequence[Any], version: Tuple[int]) -> None:
if len(args) == len(self.fixed):
Expand Down
42 changes: 18 additions & 24 deletions fakeredis/_fakesocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
from typing import Optional, Set, Any
from typing import Optional, Set

from fakeredis.commands_mixins import (
BitmapCommandsMixin,
ConnectionCommandsMixin,
GenericCommandsMixin,
GeoCommandsMixin,
HashCommandsMixin,
ListCommandsMixin,
PubSubCommandsMixin,
ScriptingCommandsMixin,
ServerCommandsMixin,
StringCommandsMixin,
TransactionsCommandsMixin,
SetCommandsMixin,
StreamsCommandsMixin,
)
from fakeredis.stack import (
JSONCommandsMixin,
BFCommandsMixin,
Expand All @@ -11,30 +26,8 @@
)
from ._basefakesocket import BaseFakeSocket
from ._server import FakeServer
from .commands_mixins.bitmap_mixin import BitmapCommandsMixin
from .commands_mixins.connection_mixin import ConnectionCommandsMixin
from .commands_mixins.generic_mixin import GenericCommandsMixin
from .commands_mixins.geo_mixin import GeoCommandsMixin
from .commands_mixins.hash_mixin import HashCommandsMixin
from .commands_mixins.list_mixin import ListCommandsMixin
from .commands_mixins.pubsub_mixin import PubSubCommandsMixin

try:
from .commands_mixins.scripting_mixin import ScriptingCommandsMixin
except ImportError:

class ScriptingCommandsMixin: # type: ignore # noqa: E303
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs.pop("lua_modules", None)
super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore


from .commands_mixins.server_mixin import ServerCommandsMixin
from .commands_mixins.set_mixin import SetCommandsMixin
from .commands_mixins.sortedset_mixin import SortedSetCommandsMixin
from .commands_mixins.streams_mixin import StreamsCommandsMixin
from .commands_mixins.string_mixin import StringCommandsMixin
from .commands_mixins.transactions_mixin import TransactionsCommandsMixin
from .server_specific_commands import DragonflyCommandsMixin


class FakeSocket(
Expand All @@ -60,6 +53,7 @@ class FakeSocket(
TopkCommandsMixin,
TDigestCommandsMixin,
TimeSeriesCommandsMixin,
DragonflyCommandsMixin,
):
def __init__(
self,
Expand Down
6 changes: 4 additions & 2 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import time
import weakref
from collections import defaultdict
from typing import Dict, Tuple, Any, List, Optional, Union
from typing import Dict, Tuple, Any, List, Optional, Union, Literal

from fakeredis._helpers import Database, FakeSelector

LOGGER = logging.getLogger("fakeredis")

VersionType = Union[Tuple[int, ...], int, str]

ServerType = Literal["redis", "dragonfly", "valkey"]


def _create_version(v: VersionType) -> Tuple[int, ...]:
if isinstance(v, tuple):
Expand All @@ -26,7 +28,7 @@ def _create_version(v: VersionType) -> Tuple[int, ...]:
class FakeServer:
_servers_map: Dict[str, "FakeServer"] = dict()

def __init__(self, version: VersionType = (7,), server_type: str = "redis") -> None:
def __init__(self, version: VersionType = (7,), server_type: ServerType = "redis") -> None:
self.lock = threading.Lock()
self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock))
# Maps channel/pattern to a weak set of sockets
Expand Down
3 changes: 2 additions & 1 deletion fakeredis/_tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fakeredis import FakeRedis
from fakeredis import FakeServer
from fakeredis._server import ServerType

LOGGER = logging.getLogger("fakeredis")

Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
self,
server_address: Tuple[str | bytes | bytearray, int],
bind_and_activate: bool = True,
server_type: str = "redis",
server_type: ServerType = "redis",
server_version: Tuple[int, ...] = (7, 4),
):
super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate)
Expand Down
41 changes: 41 additions & 0 deletions fakeredis/commands_mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any

from .bitmap_mixin import BitmapCommandsMixin
from .connection_mixin import ConnectionCommandsMixin
from .generic_mixin import GenericCommandsMixin
from .geo_mixin import GeoCommandsMixin
from .hash_mixin import HashCommandsMixin
from .list_mixin import ListCommandsMixin
from .pubsub_mixin import PubSubCommandsMixin
from .server_mixin import ServerCommandsMixin
from .set_mixin import SetCommandsMixin
from .streams_mixin import StreamsCommandsMixin
from .string_mixin import StringCommandsMixin

try:
from .scripting_mixin import ScriptingCommandsMixin
except ImportError:

class ScriptingCommandsMixin: # type: ignore # noqa: E303
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs.pop("lua_modules", None)
super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore


from .transactions_mixin import TransactionsCommandsMixin

__all__ = [
"BitmapCommandsMixin",
"ConnectionCommandsMixin",
"GenericCommandsMixin",
"GeoCommandsMixin",
"HashCommandsMixin",
"ListCommandsMixin",
"PubSubCommandsMixin",
"ScriptingCommandsMixin",
"TransactionsCommandsMixin",
"ServerCommandsMixin",
"SetCommandsMixin",
"StreamsCommandsMixin",
"StringCommandsMixin",
]
2 changes: 1 addition & 1 deletion fakeredis/commands_mixins/connection_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ping(self, *args: bytes) -> Union[List[bytes], bytes, SimpleString]:
else:
return args[0] if args else PONG

@command((DbIndex,))
@command(name="SELECT", fixed=(DbIndex,))
def select(self, index: DbIndex) -> SimpleString:
self._db = self._server.dbs[index]
self._db_num = index # type: ignore
Expand Down
6 changes: 1 addition & 5 deletions fakeredis/commands_mixins/generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,7 @@ def exists(self, *keys):
ret += 1
return ret

@command(
name="expire",
fixed=(Key(), Int),
repeat=(bytes,),
)
@command(name="EXPIRE", fixed=(Key(), Int), repeat=(bytes,))
def expire(self, key: CommandItem, seconds: int, *args: bytes) -> int:
res = self._expireat(key, self._db.time + seconds, *args)
return res
Expand Down
5 changes: 5 additions & 0 deletions fakeredis/server_specific_commands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from fakeredis.server_specific_commands.dragonfly_mixin import DragonflyCommandsMixin

__all__ = [
"DragonflyCommandsMixin",
]
20 changes: 20 additions & 0 deletions fakeredis/server_specific_commands/dragonfly_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Callable

from fakeredis._commands import command, Key, Int, CommandItem
from fakeredis._helpers import Database


class DragonflyCommandsMixin(object):
_expireat: Callable[[CommandItem, int], int]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._db: Database

@command(name="SADDEX", fixed=(Key(set), Int, bytes), repeat=(bytes,), server_types=("dragonfly",))
def saddex(self, key: CommandItem, seconds: int, *members: bytes) -> int:
old_size = len(key.value)
key.value.update(members)
key.updated()
self._expireat(key, self._db.time + seconds)
return len(key.value) - old_size
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packages = [
{ include = "fakeredis" },
{ include = "LICENSE", to = "fakeredis" },
]
version = "2.25.2"
version = "2.26.0"
description = "Python implementation of redis API, can be used for testing purposes."
readme = "README.md"
keywords = ["redis", "RedisJson", "RedisBloom", "tests", "redis-stack"]
Expand Down
10 changes: 7 additions & 3 deletions test/test_mixins/test_server_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@ def test_lastsave(r: redis.Redis):
@fake_only
def test_command(r: redis.Redis):
commands_dict = r.command()
one_word_commands = {cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd}
assert one_word_commands - set(commands_dict.keys()) == set()
one_word_commands = {cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd and SUPPORTED_COMMANDS[cmd].server_types}
server_unsupported_commands = one_word_commands - set(commands_dict.keys())
for command in server_unsupported_commands:
assert "redis" not in SUPPORTED_COMMANDS[command].server_types


@fake_only
def test_command_count(r: redis.Redis):
assert r.command_count() >= len([cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd])
assert r.command_count() >= len(
[cmd for (cmd, cmd_info) in SUPPORTED_COMMANDS.items() if " " not in cmd and "redis" in cmd_info.server_types]
)


@pytest.mark.unsupported_server_types("dragonfly")
Expand Down

0 comments on commit 375c1ef

Please sign in to comment.