Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] codebase refactor. #332

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions aredis_om/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,18 @@
@lru_cache(maxsize=None)
async def check_for_command(conn, cmd):
cmd_info = await conn.execute_command("command", "info", cmd)
return None not in cmd_info
return all(cmd_info)


@lru_cache(maxsize=None)
async def has_redis_json(conn=None):
if conn is None:
conn = get_redis_connection()
command_exists = await check_for_command(conn, "json.set")
command_exists = await check_for_command(conn or get_redis_connection(), "json.set")
return command_exists


@lru_cache(maxsize=None)
async def has_redisearch(conn=None):
if conn is None:
if not conn:
conn = get_redis_connection()
if has_redis_json(conn):
return True
Expand Down
10 changes: 4 additions & 6 deletions aredis_om/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
def get_redis_connection(**kwargs) -> aioredis.Redis:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL)
if url:
return aioredis.Redis.from_url(url, **kwargs)

if not kwargs.get("url", None) and URL:
kwargs["url"] = URL
# Decode from UTF-8 by default
if "decode_responses" not in kwargs:
if not kwargs.get("decode_responses", None):
kwargs["decode_responses"] = True
return aioredis.Redis(**kwargs)
return aioredis.from_url(**kwargs)
6 changes: 3 additions & 3 deletions aredis_om/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def jsonable_encoder(
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, (set, dict)):
if include and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
if exclude and not isinstance(exclude, (set, dict)):
exclude = set(exclude)

if isinstance(obj, BaseModel):
Expand Down Expand Up @@ -107,7 +107,7 @@ def jsonable_encoder(
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and (value or not exclude_none)
and ((include and key in include) or not exclude or key not in exclude)
):
encoded_key = jsonable_encoder(
Expand Down
54 changes: 25 additions & 29 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def is_supported_container_type(typ: Optional[type]) -> bool:


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
for field_name in field_values.keys():
for field_name in field_values:
if "__" in field_name:
obj = model
for sub_field in field_name.split("__"):
Expand Down Expand Up @@ -432,11 +432,11 @@ def validate_sort_fields(self, sort_fields: List[str]):

@staticmethod
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
if getattr(field.field_info, "primary_key", None) is True:
if getattr(field.field_info, "primary_key", None):
return RediSearchFieldTypes.TAG
elif op is Operators.LIKE:
fts = getattr(field.field_info, "full_text_search", None)
if fts is not True: # Could be PydanticUndefined
if not fts: # Could be PydanticUndefined
raise QuerySyntaxError(
f"You tried to do a full-text search on the field '{field.name}', "
f"but the field is not indexed for full-text search. Use the "
Expand Down Expand Up @@ -464,7 +464,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes
# is not itself directly indexed, but instead, we index any fields
# within the model inside the list marked as `index=True`.
return RediSearchFieldTypes.TAG
elif container_type is not None:
elif container_type:
raise QuerySyntaxError(
"Only lists and tuples are supported for multi-value fields. "
f"Docs: {ERRORS_URL}#E4"
Expand Down Expand Up @@ -567,7 +567,7 @@ def resolve_value(
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
values: filter = filter(None, value.split(separator_char))
values: List[str] = [val for val in value.split(separator_char) if val]
for value in values:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
Expand Down Expand Up @@ -1131,7 +1131,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
raise NotImplementedError

async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1195,15 +1195,11 @@ def to_string(s):
step = 2 # Because the result has content
offset = 1 # The first item is the count of total matches.

for i in xrange(1, len(res), step):
fields_offset = offset

for i in range(1, len(res), step):
fields = dict(
dict(
izip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
zip(
map(to_string, res[i + offset][::2]),
map(to_string, res[i + offset][1::2]),
)
)

Expand Down Expand Up @@ -1244,7 +1240,7 @@ async def add(
pipeline: Optional[Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
if pipeline is None:
if not pipeline:
# By default, send commands in a pipeline. Saving each model will
# be atomic, but Redis may process other commands in between
# these saves.
Expand All @@ -1261,7 +1257,7 @@ async def add(

# If the user didn't give us a pipeline, then we need to execute
# the one we just created.
if pipeline is None:
if not pipeline:
result = await db.execute()
pipeline_verifier(result, expected_responses=len(models))

Expand Down Expand Up @@ -1303,7 +1299,7 @@ def __init_subclass__(cls, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
self.check()
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1356,7 +1352,7 @@ def _get_value(cls, *args, **kwargs) -> Any:
values. Is there a better way?
"""
val = super()._get_value(*args, **kwargs)
if val is None:
if not val:
return ""
return val

Expand Down Expand Up @@ -1392,7 +1388,7 @@ def schema_for_fields(cls):
name, _type, field.field_info
)
schema_parts.append(redisearch_field)
elif getattr(field.field_info, "index", None) is True:
elif getattr(field.field_info, "index", None):
schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
elif is_subscripted_type:
# Ignore subscripted types (usually containers!) that we don't
Expand Down Expand Up @@ -1437,7 +1433,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, "full_text_search", False) is True:
if getattr(field_info, "full_text_search", False):
schema = (
f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
f"{name} AS {name}_fts TEXT"
Expand All @@ -1455,7 +1451,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = " ".join(sub_fields)
else:
schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if schema and sortable is True:
if schema and sortable:
schema += " SORTABLE"
return schema

Expand All @@ -1475,7 +1471,7 @@ def __init__(self, *args, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
self.check()
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1633,7 +1629,7 @@ def schema_for_type(
parent_type=typ,
)
)
return " ".join(filter(None, sub_fields))
return " ".join([sub_field for sub_field in sub_fields if sub_field])
# NOTE: This is the termination point for recursion. We've descended
# into models and lists until we found an actual value to index.
elif should_index:
Expand All @@ -1660,23 +1656,23 @@ def schema_for_type(
"In this Preview release, list and tuple fields can only "
f"contain strings. Problem field: {name}. See docs: TODO"
)
if full_text_search is True:
if full_text_search:
raise RedisModelError(
"List and tuple fields cannot be indexed for full-text "
f"search. Problem field: {name}. See docs: TODO"
)
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
if full_text_search is True:
if full_text_search:
schema = (
f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
f"{path} AS {index_field_name}_fts TEXT"
)
if sortable is True:
if sortable:
# NOTE: With the current preview release, making a field
# full-text searchable and sortable only makes the TEXT
# field sortable. This means that results for full-text
Expand All @@ -1685,11 +1681,11 @@ def schema_for_type(
schema += " SORTABLE"
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
return schema
return ""
Expand Down
18 changes: 8 additions & 10 deletions aredis_om/model/render_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def render_tree(
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
if not buffer:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr) # noqa: E731
Expand All @@ -31,11 +31,9 @@ def render_tree(
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)

if up is not None:
if up:
next_last = "up"
next_indent = "{0}{1}{2}".format(
indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
)
next_indent = f'{indent}{" " if "up" in last else "|"}{" " * len(str(name(current_node)))}'
render_tree(
up, nameattr, left_child, right_child, next_indent, next_last, buffer
)
Expand All @@ -49,7 +47,7 @@ def render_tree(
else:
start_shape = "├"

if up is not None and down is not None:
if up and down:
end_shape = "┤"
elif up:
end_shape = "┘"
Expand All @@ -59,14 +57,14 @@ def render_tree(
end_shape = ""

print(
"{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
f"{indent}{start_shape}{name(current_node)}{end_shape}",
file=buffer,
)

if down is not None:
if down:
next_last = "down"
next_indent = "{0}{1}{2}".format(
indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
next_indent = (
f'{indent}{" " if "down" in last else "|"}{len(str(name(current_node)))}'
)
render_tree(
down, nameattr, left_child, right_child, next_indent, next_last, buffer
Expand Down
2 changes: 1 addition & 1 deletion aredis_om/unasync_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def f():
return None

obj = f()
if obj is None:
if not obj:
return False
else:
obj.close() # prevent unawaited coroutine warning
Expand Down
2 changes: 1 addition & 1 deletion tests/test_oss_redis_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def members(m):
async def test_all_keys(members, m):
pks = sorted([pk async for pk in await m.Member.all_pks()])
assert len(pks) == 3
assert pks == sorted([m.pk for m in members])
assert pks == sorted(m.pk for m in members)


@py_test_mark_asyncio
Expand Down