Skip to content

Commit

Permalink
adding xfix style queries (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
slorello89 authored May 3, 2024
1 parent 5ef3d27 commit 57fe8a2
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
36 changes: 35 additions & 1 deletion aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
TypeVar,
Union,
)
from typing import get_args as typing_get_args, no_type_check
from typing import get_args as typing_get_args
from typing import no_type_check

from more_itertools import ichunked
from redis.commands.json.path import Path
Expand Down Expand Up @@ -112,6 +113,9 @@ class Operators(Enum):
NOT_IN = 11
LIKE = 12
ALL = 13
STARTSWITH = 14
ENDSWITH = 15
CONTAINS = 16

def __str__(self):
return str(self.name)
Expand Down Expand Up @@ -346,6 +350,21 @@ def __rshift__(self, other: Any) -> Expression:
left=self.field, op=Operators.NOT_IN, right=other, parents=self.parents
)

def startswith(self, other: Any) -> Expression:
return Expression(
left=self.field, op=Operators.STARTSWITH, right=other, parents=self.parents
)

def endswith(self, other: Any) -> Expression:
return Expression(
left=self.field, op=Operators.ENDSWITH, right=other, parents=self.parents
)

def contains(self, other: Any) -> Expression:
return Expression(
left=self.field, op=Operators.CONTAINS, right=other, parents=self.parents
)

def __getattr__(self, item):
if item.startswith("__"):
raise AttributeError("cannot invoke __getattr__ with reserved field")
Expand Down Expand Up @@ -691,6 +710,21 @@ def resolve_value(
result += "-(@{field_name}:{{{expanded_value}}})".format(
field_name=field_name, expanded_value=expanded_value
)
elif op is Operators.STARTSWITH:
expanded_value = cls.expand_tag_value(value)
result += "(@{field_name}:{{{expanded_value}*}})".format(
field_name=field_name, expanded_value=expanded_value
)
elif op is Operators.ENDSWITH:
expanded_value = cls.expand_tag_value(value)
result += "(@{field_name}:{{*{expanded_value}}})".format(
field_name=field_name, expanded_value=expanded_value
)
elif op is Operators.CONTAINS:
expanded_value = cls.expand_tag_value(value)
result += "(@{field_name}:{{*{expanded_value}*}})".format(
field_name=field_name, expanded_value=expanded_value
)

return result

Expand Down
23 changes: 23 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,3 +852,26 @@ class TypeWithUuid(HashModel):
item = TypeWithUuid(uuid=uuid.uuid4())

await item.save()


@py_test_mark_asyncio
async def test_xfix_queries(members, m):
member1, member2, member3 = members

result = await m.Member.find(m.Member.first_name.startswith("And")).first()
assert result.first_name == "Andrew"

result = await m.Member.find(m.Member.last_name.endswith("ins")).first()
assert result.first_name == "Andrew"

result = await m.Member.find(m.Member.last_name.contains("ook")).first()
assert result.first_name == "Andrew"

result = await m.Member.find(m.Member.bio % "great*").first()
assert result.first_name == "Andrew"

result = await m.Member.find(m.Member.bio % "*rty").first()
assert result.first_name == "Andrew"

result = await m.Member.find(m.Member.bio % "*eat*").first()
assert result.first_name == "Andrew"
37 changes: 37 additions & 0 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,3 +934,40 @@ class TypeWithUuid(JsonModel):
item = TypeWithUuid(uuid=uuid.uuid4())

await item.save()


@py_test_mark_asyncio
async def test_xfix_queries(m):
await m.Member(
first_name="Steve",
last_name="Lorello",
email="s@example.com",
join_date=today,
bio="Steve is a two-bit hacker who loves Redis.",
address=m.Address(
address_line_1="42 foo bar lane",
city="Satellite Beach",
state="FL",
country="USA",
postal_code="32999",
),
age=34,
).save()

result = await m.Member.find(m.Member.first_name.startswith("Ste")).first()
assert result.first_name == "Steve"

result = await m.Member.find(m.Member.last_name.endswith("llo")).first()
assert result.first_name == "Steve"

result = await m.Member.find(m.Member.address.city.contains("llite")).first()
assert result.first_name == "Steve"

result = await m.Member.find(m.Member.bio % "tw*").first()
assert result.first_name == "Steve"

result = await m.Member.find(m.Member.bio % "*cker").first()
assert result.first_name == "Steve"

result = await m.Member.find(m.Member.bio % "*ack*").first()
assert result.first_name == "Steve"

0 comments on commit 57fe8a2

Please sign in to comment.