Skip to content

Commit

Permalink
fix: show a warning if the user uses star-expressions inside Lists
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Sep 22, 2024
1 parent 4b9f337 commit 58e0a00
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ dependencies=["cogapp","lxml","requests"]
scripts.update="cog -r docs/**.md"

[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.13", "3.12", "3.11", "3.10", "3.9", "3.8"]
python = ["3.12", "3.11", "3.10", "3.9", "3.8"]

[tool.hatch.envs.hatch-test]
extra-dependencies = [
Expand Down
47 changes: 28 additions & 19 deletions src/inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import tokenize
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -170,34 +171,36 @@ def _new_code(self):

def _get_changes(self) -> Iterator[Change]:

def handle(node, value):
if isinstance(value, list):
def handle(node, obj):
if isinstance(obj, list):
if not isinstance(node, ast.List):
return
for n, v in zip(node.elts, value):
yield from handle(n, v)
elif isinstance(value, tuple):
for node_value, value in zip(node.elts, obj):
yield from handle(node_value, value)
elif isinstance(obj, tuple):
if not isinstance(node, ast.Tuple):
return
for n, v in zip(node.elts, value):
yield from handle(n, v)
for node_value, value in zip(node.elts, obj):
yield from handle(node_value, value)

elif isinstance(value, dict):
elif isinstance(obj, dict):
if not isinstance(node, ast.Dict):
return
for vk, nk, n in zip(value.keys(), node.keys, node.values):
for value_key, node_key, node_value in zip(
obj.keys(), node.keys, node.values
):
try:
# this is just a sanity check, dicts should be ordered
node_key = ast.literal_eval(nk)
node_key = ast.literal_eval(node_key)
except Exception:
assert False
pass
else:
assert node_key == vk
assert node_key == value_key

yield from handle(n, value[vk])
yield from handle(node_value, obj[value_key])
else:
if update_allowed(value):
new_token = value_to_token(value)
if update_allowed(obj):
new_token = value_to_token(obj)
if self._token_of_node(node) != new_token:
new_code = self._token_to_code(new_token)

Expand Down Expand Up @@ -345,6 +348,15 @@ def check(old_value, old_node, new_value):
and isinstance(new_value, tuple)
and isinstance(old_value, tuple)
):
for e in old_node.elts:
if isinstance(e, ast.Starred):
warnings.showwarning(
"starred-expressions are not supported inside snapshots",
filename=self._source.filename,
lineno=e.lineno,
category=SyntaxWarning,
)
return
diff = add_x(align(old_value, new_value))
old = zip(old_value, old_node.elts)
new = iter(new_value)
Expand Down Expand Up @@ -840,8 +852,5 @@ def _change(self):
@property
def _flags(self):

if self._value._old_value is undefined:
return {"create"}

changes = self._value._get_changes()
changes = self._changes()
return {change.flag for change in changes}
28 changes: 19 additions & 9 deletions src/inline_snapshot/testing/_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@

import inline_snapshot._external
import inline_snapshot._external as external
from inline_snapshot import _inline_snapshot
from inline_snapshot._inline_snapshot import Flags
from inline_snapshot._rewrite_code import ChangeRecorder
from inline_snapshot._types import Category
from inline_snapshot._types import Snapshot

from .. import _inline_snapshot
from .._change import apply_all
from .._inline_snapshot import Flags
from .._rewrite_code import ChangeRecorder
from .._types import Category
from .._types import Snapshot


@contextlib.contextmanager
Expand Down Expand Up @@ -160,11 +162,19 @@ def run_inline(
finally:
_inline_snapshot._active = False

snapshot_flags = set()

changes = []
for snapshot in _inline_snapshot.snapshots.values():
snapshot_flags |= snapshot._flags
snapshot._change()
changes += snapshot._changes()

snapshot_flags = {change.flag for change in changes}

apply_all(
[
change
for change in changes
if change.flag in _inline_snapshot._update_flags.to_set()
]
)

if reported_categories is not None:
assert sorted(snapshot_flags) == reported_categories
Expand Down
62 changes: 58 additions & 4 deletions tests/test_inline_snapshot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import ast
import contextlib
import itertools
import warnings
from collections import namedtuple
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Union

import pytest
from hypothesis import given
Expand Down Expand Up @@ -575,9 +579,7 @@ def test_assert(check_update):
def test_plain(check_update, executing_used):
assert check_update("s = snapshot(5)", flags="") == snapshot("s = snapshot(5)")

assert check_update(
"s = snapshot()", flags="", reported_flags="create"
) == snapshot("s = snapshot()")
assert check_update("s = snapshot()", flags="") == snapshot("s = snapshot()")


def test_string_update(check_update):
Expand Down Expand Up @@ -805,7 +807,10 @@ def test_format_value(check_update):


def test_unused_snapshot(check_update):
assert check_update("snapshot()\n", flags="create") == "snapshot()\n"
assert (
check_update("snapshot()\n", flags="create", reported_flags="")
== "snapshot()\n"
)


def test_type_error(check_update):
Expand Down Expand Up @@ -1094,6 +1099,7 @@ def test_dirty_equals_in_unused_snapshot() -> None:
snapshot([IsStr(),3])
snapshot((IsStr(),3))
snapshot({1:IsStr(),2:3})
snapshot({1+1:2})
t=(1,2)
d={1:2}
Expand All @@ -1104,3 +1110,51 @@ def test_dirty_equals_in_unused_snapshot() -> None:
["--inline-snapshot=fix"],
changed_files=snapshot({}),
)


@dataclass
class Warning:
category: type
message: str
filename: Union[str, None] = None
line: Union[int, None] = None


@contextlib.contextmanager
def warns(expected_warnings=[], include_line=False, include_file=False):
with warnings.catch_warnings(record=True) as result:
warnings.simplefilter("always")
yield

assert [
Warning(
category=w.category,
line=w.lineno if include_line else None,
message=str(w.message),
filename=w.filename if include_file else None,
)
for w in result
] == expected_warnings


def test_starred_warns():

with warns(
snapshot(
[
Warning(
category=SyntaxWarning,
message="starred-expressions are not supported inside snapshots",
line=4,
)
]
),
include_line=True,
):
Example(
"""
from inline_snapshot import snapshot
assert [5] == snapshot([*[4]])
"""
).run_inline(["--inline-snapshot=fix"])

0 comments on commit 58e0a00

Please sign in to comment.