From 5352ea6a394ab1f8791bd4d2505594a3084e23d5 Mon Sep 17 00:00:00 2001
From: Frank Hoffmann <15r10nk-git@polarbit.de>
Date: Fri, 3 May 2024 23:30:12 +0200
Subject: [PATCH 1/4] feat: support for `enum.Enum`, `enum.Flag`, `type` and
omitting of default values (#73)
---
docs/code_generation.md | 105 ++++++++
docs/customize_repr.md | 114 +++++++++
docs/index.md | 103 --------
inline_snapshot/__init__.py | 4 +-
inline_snapshot/_code_repr.py | 243 +++++++++++++++++++
inline_snapshot/_inline_snapshot.py | 4 -
inline_snapshot/_utils.py | 4 +-
inline_snapshot/pytest_plugin.py | 11 +-
mkdocs.yml | 2 +
noxfile.py | 3 +-
pyproject.toml | 4 +
tests/conftest.py | 2 +-
tests/example.py | 107 +++++++--
tests/test_code_repr.py | 359 ++++++++++++++++++++++++++++
tests/test_example.py | 22 +-
tests/test_hasrepr.py | 12 +
tests/test_inline_snapshot.py | 2 +-
tests/test_pydantic.py | 42 ++++
18 files changed, 1012 insertions(+), 131 deletions(-)
create mode 100644 docs/code_generation.md
create mode 100644 docs/customize_repr.md
create mode 100644 inline_snapshot/_code_repr.py
create mode 100644 tests/test_code_repr.py
create mode 100644 tests/test_hasrepr.py
create mode 100644 tests/test_pydantic.py
diff --git a/docs/code_generation.md b/docs/code_generation.md
new file mode 100644
index 0000000..1a8353a
--- /dev/null
+++ b/docs/code_generation.md
@@ -0,0 +1,105 @@
+
+
+You can use almost any python datatype and also complex values like `datatime.date`, because `repr()` is used to convert the values to source code.
+The default `__repr__()` behaviour can be [customized](customize_repr.md).
+It might be necessary to import the right modules to match the `repr()` output.
+
+=== "original code"
+
+ ```python
+ from inline_snapshot import snapshot
+ import datetime
+
+
+ def something():
+ return {
+ "name": "hello",
+ "one number": 5,
+ "numbers": list(range(10)),
+ "sets": {1, 2, 15},
+ "datetime": datetime.date(1, 2, 22),
+ "complex stuff": 5j + 3,
+ "bytes": b"byte abc\n\x16",
+ }
+
+
+ def test_something():
+ assert something() == snapshot()
+ ```
+=== "--inline-snapshot=create"
+
+ ```python
+ from inline_snapshot import snapshot
+ import datetime
+
+
+ def something():
+ return {
+ "name": "hello",
+ "one number": 5,
+ "numbers": list(range(10)),
+ "sets": {1, 2, 15},
+ "datetime": datetime.date(1, 2, 22),
+ "complex stuff": 5j + 3,
+ "bytes": b"byte abc\n\x16",
+ }
+
+
+ def test_something():
+ assert something() == snapshot(
+ {
+ "name": "hello",
+ "one number": 5,
+ "numbers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "sets": {1, 2, 15},
+ "datetime": datetime.date(1, 2, 22),
+ "complex stuff": (3 + 5j),
+ "bytes": b"byte abc\n\x16",
+ }
+ )
+ ```
+
+The code is generated in the following way:
+
+1. The value is copied with `value = copy.deepcopy(value)` and it is checked if the copied value is equal to the original value.
+2. The code is generated with `repr(value)` (which can be [customized](customize_repr.md))
+3. Strings which contain newlines are converted to triple quoted strings.
+
+ !!! note
+ Missing newlines at start or end are escaped (since 0.4.0).
+
+ === "original code"
+
+ ``` python
+ def test_something():
+ assert "first line\nsecond line" == snapshot(
+ """first line
+ second line"""
+ )
+ ```
+
+ === "--inline-snapshot=update"
+
+ ``` python
+ def test_something():
+ assert "first line\nsecond line" == snapshot(
+ """\
+ first line
+ second line\
+ """
+ )
+ ```
+
+
+4. The code is formatted with black.
+
+
+5. The whole file is formatted with black if it was formatted before.
+
+ !!! note
+ The black formatting of the whole file could not work for the following reasons:
+
+ 1. black is configured with cli arguments and not in a configuration file.
+ **Solution:** configure black in a [configuration file](https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file)
+ 2. inline-snapshot uses a different black version.
+ **Solution:** specify which black version inline-snapshot should use by adding black with a specific version to your dependencies.
diff --git a/docs/customize_repr.md b/docs/customize_repr.md
new file mode 100644
index 0000000..5f7e5f0
--- /dev/null
+++ b/docs/customize_repr.md
@@ -0,0 +1,114 @@
+
+
+
+`repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type.
+Here are some examples:
+```pycon
+>>> repr(int)
+""
+
+>>> from enum import Enum
+>>> E = Enum("E", ["a", "b"])
+>>> repr(E.a)
+''
+```
+
+`customize_repr` can be used to overwrite the default `repr()` behaviour.
+
+The implementation for `Enum` looks like this:
+
+```python exec="1" result="python"
+print('--8<-- "inline_snapshot/_code_repr.py:Enum"')
+```
+
+This implementation is then used by inline-snapshot if `repr()` is called during the code generation, but not in normal code.
+
+
+```python
+from enum import Enum
+
+
+def test_enum():
+ E = Enum("E", ["a", "b"])
+
+ # normal repr
+ assert repr(E.a) == ""
+
+ # the special implementation to convert the Enum into a code
+ assert E.a == snapshot(E.a)
+```
+
+## builtin datatypes
+
+inline-snapshot comes with a special implementation for the following types:
+```python exec="1"
+from inline_snapshot._code_repr import code_repr_dispatch, code_repr
+
+for name, obj in sorted(
+ (
+ getattr(
+ obj, "_inline_snapshot_name", f"{obj.__module__}.{obj.__qualname__}"
+ ),
+ obj,
+ )
+ for obj in code_repr_dispatch.registry.keys()
+):
+ if obj is not object:
+ print(f"- `{name}`")
+```
+
+Container types like `dict` or `dataclass` need a special implementation because it is necessary that the implementation uses `repr()` for the child elements.
+
+```python exec="1" result="python"
+print('--8<-- "inline_snapshot/_code_repr.py:list"')
+```
+
+!!! note
+ using `#!python f"{obj!r}"` or `#!c PyObject_Repr()` will not work, because inline-snapshot replaces `#!python builtins.repr` during the code generation.
+
+## customize
+
+You can also use `repr()` inside `__repr__()`, if you want to make your own type compatible with inline-snapshot.
+
+
+```python
+from enum import Enum
+
+
+class Pair:
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+
+ def __repr__(self):
+ # this would not work
+ # return f"Pair({self.a!r}, {self.b!r})"
+
+ # you have to use repr()
+ return f"Pair({repr(self.a)}, {repr(self.b)})"
+
+ def __eq__(self, other):
+ if not isinstance(other, Pair):
+ return NotImplemented
+ return self.a == other.a and self.b == other.b
+
+
+def test_enum():
+ E = Enum("E", ["a", "b"])
+
+ # the special repr implementation is used recursive here
+ # to convert every Enum to the correct representation
+ assert Pair(E.a, [E.b]) == snapshot(Pair(E.a, [E.b]))
+```
+
+you can also customize the representation of datatypes in other libraries:
+
+``` python
+from inline_snapshot import customize_repr
+from other_lib import SomeType
+
+
+@customize_repr
+def _(value: SomeType):
+ return f"SomeType(x={repr(value.x)})"
+```
diff --git a/docs/index.md b/docs/index.md
index f9d1a79..9d6d76a 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -201,108 +201,5 @@ or passed as an argument to a function:
-## Code generation
-
-You can use almost any python datatype and also complex values like `datatime.date`, because `repr()` is used to convert the values to a source code.
-It might be necessary to import the right modules to match the `repr()` output.
-
-=== "original code"
-
- ```python
- from inline_snapshot import snapshot
- import datetime
-
-
- def something():
- return {
- "name": "hello",
- "one number": 5,
- "numbers": list(range(10)),
- "sets": {1, 2, 15},
- "datetime": datetime.date(1, 2, 22),
- "complex stuff": 5j + 3,
- "bytes": b"fglecg\n\x16",
- }
-
-
- def test_something():
- assert something() == snapshot()
- ```
-=== "--inline-snapshot=create"
-
- ```python
- from inline_snapshot import snapshot
- import datetime
-
-
- def something():
- return {
- "name": "hello",
- "one number": 5,
- "numbers": list(range(10)),
- "sets": {1, 2, 15},
- "datetime": datetime.date(1, 2, 22),
- "complex stuff": 5j + 3,
- "bytes": b"fglecg\n\x16",
- }
-
-
- def test_something():
- assert something() == snapshot(
- {
- "name": "hello",
- "one number": 5,
- "numbers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
- "sets": {1, 2, 15},
- "datetime": datetime.date(1, 2, 22),
- "complex stuff": (3 + 5j),
- "bytes": b"fglecg\n\x16",
- }
- )
- ```
-
-The code is generated in the following way:
-
-1. The value is copied with `value = copy.deepcopy(value)`
-2. The code is generated with `repr(value)`
-3. Strings which contain newlines are converted to triple quoted strings.
-
- !!! note
- Missing newlines at start or end are escaped (since 0.4.0).
-
- === "original code"
-
- ``` python
- def test_something():
- assert "first line\nsecond line" == snapshot(
- """first line
- second line"""
- )
- ```
-
- === "--inline-snapshot=update"
-
- ``` python
- def test_something():
- assert "first line\nsecond line" == snapshot(
- """\
- first line
- second line\
- """
- )
- ```
-
-
-4. The code is formatted with black.
-
- !!! note
- The black formatting of the whole file could not work for the following reasons:
-
- 1. black is configured with cli arguments and not in a configuration file.
- **Solution:** configure black in a [configuration file](https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file)
- 2. inline-snapshot uses a different black version.
- **Solution:** specify which black version inline-snapshot should use by adding black with a specific version to your dependencies.
-
-5. The whole file is formatted with black if it was formatted before.
--8<-- "README.md:Feedback"
diff --git a/inline_snapshot/__init__.py b/inline_snapshot/__init__.py
index af0b273..c1fdf7c 100644
--- a/inline_snapshot/__init__.py
+++ b/inline_snapshot/__init__.py
@@ -1,7 +1,9 @@
+from ._code_repr import customize_repr
+from ._code_repr import HasRepr
from ._external import external
from ._external import outsource
from ._inline_snapshot import snapshot
-__all__ = ["snapshot", "external", "outsource"]
+__all__ = ["snapshot", "external", "outsource", "customize_repr", "HasRepr"]
__version__ = "0.10.2"
diff --git a/inline_snapshot/_code_repr.py b/inline_snapshot/_code_repr.py
new file mode 100644
index 0000000..9a5dcd3
--- /dev/null
+++ b/inline_snapshot/_code_repr.py
@@ -0,0 +1,243 @@
+import ast
+from abc import ABC
+from collections import defaultdict
+from dataclasses import fields
+from dataclasses import is_dataclass
+from dataclasses import MISSING
+from enum import Enum
+from enum import Flag
+from functools import singledispatch
+from unittest import mock
+
+real_repr = repr
+
+
+class HasRepr:
+ """This class is used for objects where `__repr__()` returns an non-
+ parsable representation.
+
+ HasRepr uses the type and repr of the value for equal comparison.
+
+ You can change `__repr__()` to return valid python code or use
+ `@customize_repr` to customize repr which is used by inline-
+ snapshot.
+ """
+
+ def __init__(self, type, str_repr: str) -> None:
+ self._type = type
+ self._str_repr = str_repr
+
+ def __repr__(self):
+ return f"HasRepr({self._type.__qualname__}, {self._str_repr!r})"
+
+ def __eq__(self, other):
+ if isinstance(other, HasRepr):
+ if other._type is not self._type:
+ return False
+ else:
+ if type(other) is not self._type:
+ return False
+
+ other_repr = code_repr(other)
+ return other_repr == self._str_repr or other_repr == repr(self)
+
+
+def used_hasrepr(tree):
+ return [
+ n
+ for n in ast.walk(tree)
+ if isinstance(n, ast.Call)
+ and isinstance(n.func, ast.Name)
+ and n.func.id == "HasRepr"
+ and len(n.args) == 2
+ ]
+
+
+@singledispatch
+def code_repr_dispatch(value):
+ return real_repr(value)
+
+
+def customize_repr(f):
+ """Register a funtion which should be used to get the code representation
+ of a object.
+
+ ```python
+ @customize_repr
+ def _(obj: MyCustomClass):
+ return f"MyCustomClass(attr={repr(obj.attr)})"
+ ```
+
+ it is important to use `repr()` inside the implementation, because it is mocked to return the code represenation
+
+ you dont have to provide a custom implementation if:
+ * __repr__() of your class returns a valid code representation,
+ * and __repr__() uses `repr()` to get the representaion of the child objects
+ """
+ code_repr_dispatch.register(f)
+
+
+def code_repr(obj):
+ with mock.patch("builtins.repr", code_repr):
+ result = code_repr_dispatch(obj)
+
+ try:
+ ast.parse(result)
+ except SyntaxError:
+ return real_repr(HasRepr(type(obj), result))
+
+ return result
+
+
+# -8<- [start:Enum]
+@customize_repr
+def _(value: Enum):
+ return f"{type(value).__qualname__}.{value.name}"
+
+
+# -8<- [end:Enum]
+
+
+@customize_repr
+def _(value: Flag):
+ name = type(value).__qualname__
+ return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value)
+
+
+# -8<- [start:list]
+@customize_repr
+def _(value: list):
+ return "[" + ", ".join(map(repr, value)) + "]"
+
+
+# -8<- [end:list]
+
+
+class OnlyTuple(ABC):
+ _inline_snapshot_name = "builtins.tuple"
+
+ @classmethod
+ def __subclasshook__(cls, t):
+ return t is tuple
+
+
+@customize_repr
+def _(value: OnlyTuple):
+ assert isinstance(value, tuple)
+ if len(value) == 1:
+ return f"({repr(value[0])},)"
+ return "(" + ", ".join(map(repr, value)) + ")"
+
+
+class IsNamedTuple(ABC):
+ _inline_snapshot_name = "namedtuple"
+
+ _fields: tuple
+ _field_defaults: dict
+
+ @classmethod
+ def __subclasshook__(cls, t):
+ b = t.__bases__
+ if len(b) != 1 or b[0] != tuple:
+ return False
+ f = getattr(t, "_fields", None)
+ if not isinstance(f, tuple):
+ return False
+ return all(type(n) == str for n in f)
+
+
+@customize_repr
+def _(value: IsNamedTuple):
+ params = ", ".join(
+ f"{field}={repr(getattr(value,field))}"
+ for field in value._fields
+ if field not in value._field_defaults
+ or getattr(value, field) != value._field_defaults[field]
+ )
+ return f"{repr(type(value))}({params})"
+
+
+@customize_repr
+def _(value: set):
+ if len(value) == 0:
+ return "set()"
+
+ return "{" + ", ".join(map(repr, value)) + "}"
+
+
+@customize_repr
+def _(value: frozenset):
+ if len(value) == 0:
+ return "frozenset()"
+
+ return "frozenset({" + ", ".join(map(repr, value)) + "})"
+
+
+@customize_repr
+def _(value: dict):
+ result = (
+ "{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + "}"
+ )
+
+ if type(value) is not dict:
+ result = f"{repr(type(value))}({result})"
+
+ return result
+
+
+@customize_repr
+def _(value: defaultdict):
+ return f"defaultdict({repr(value.default_factory)}, {repr(dict(value))})"
+
+
+@customize_repr
+def _(value: type):
+ return value.__qualname__
+
+
+class IsDataclass(ABC):
+ _inline_snapshot_name = "dataclass"
+
+ @classmethod
+ def __subclasshook__(cls, subclass):
+ return is_dataclass(subclass)
+
+
+@customize_repr
+def _(value: IsDataclass):
+ attrs = []
+ for field in fields(value): # type: ignore
+ if field.repr:
+ field_value = getattr(value, field.name)
+
+ if field.default != MISSING and field.default == field_value:
+ continue
+
+ if (
+ field.default_factory != MISSING
+ and field.default_factory() == field_value
+ ):
+ continue
+
+ attrs.append(f"{field.name}={repr(field_value)}")
+
+ return f"{repr(type(value))}({', '.join(attrs)})"
+
+
+try:
+ from pydantic import BaseModel
+except ImportError: # pragma: no cover
+ pass
+else:
+
+ @customize_repr
+ def _(model: BaseModel):
+ return (
+ type(model).__qualname__
+ + "("
+ + ", ".join(
+ e + "=" + repr(getattr(model, e))
+ for e in sorted(model.__pydantic_fields_set__)
+ )
+ + ")"
+ )
diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py
index 0bdf8eb..3b5c237 100644
--- a/inline_snapshot/_inline_snapshot.py
+++ b/inline_snapshot/_inline_snapshot.py
@@ -716,10 +716,6 @@ def __init__(self, value, expr):
def _changes(self):
if self._value._old_value is undefined:
new_code = self._value._new_code()
- try:
- ast.parse(new_code)
- except:
- return
yield CallArg(
"create",
diff --git a/inline_snapshot/_utils.py b/inline_snapshot/_utils.py
index 2860e16..a5ecf11 100644
--- a/inline_snapshot/_utils.py
+++ b/inline_snapshot/_utils.py
@@ -4,6 +4,8 @@
import tokenize
from collections import namedtuple
+from ._code_repr import code_repr
+
def normalize_strings(token_sequence):
"""Normalize string concattenanion.
@@ -118,7 +120,7 @@ def __eq__(self, other):
def value_to_token(value):
- input = io.StringIO(repr(value))
+ input = io.StringIO(code_repr(value))
def map_string(tok):
"""Convert strings with newlines in triple quoted strings."""
diff --git a/inline_snapshot/pytest_plugin.py b/inline_snapshot/pytest_plugin.py
index c42a76c..d97a440 100644
--- a/inline_snapshot/pytest_plugin.py
+++ b/inline_snapshot/pytest_plugin.py
@@ -15,6 +15,7 @@
from . import _find_external
from . import _inline_snapshot
from ._change import apply_all
+from ._code_repr import used_hasrepr
from ._find_external import ensure_import
from ._inline_snapshot import used_externals
from ._rewrite_code import ChangeRecorder
@@ -312,9 +313,17 @@ def report(flag, message, message_n):
tree = ast.parse(test_file.new_code())
used = used_externals(tree)
+ required_imports = []
+
if used:
+ required_imports.append("external")
+
+ if used_hasrepr(tree):
+ required_imports.append("HasRepr")
+
+ if required_imports:
ensure_import(
- test_file.filename, {"inline_snapshot": ["external"]}
+ test_file.filename, {"inline_snapshot": required_imports}
)
for external_name in used:
diff --git a/mkdocs.yml b/mkdocs.yml
index 782bcc1..7f04a05 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -36,9 +36,11 @@ nav:
- x in snapshot(): in_snapshot.md
- snapshot()[key]: getitem_snapshot.md
- outsource(data): outsource.md
+ - '@customize_repr': customize_repr.md
- pytest integration: pytest.md
- Categories: categories.md
- Configuration: configuration.md
+- Code generation: code_generation.md
- Contributing: contributing.md
- Changelog: changelog.md
diff --git a/noxfile.py b/noxfile.py
index 01a8a02..ec4faf8 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -24,7 +24,7 @@ def coverage(session):
@session(python=python_versions)
def mypy(session):
- session.install("mypy", "pytest", "hypothesis", "dirty-equals", ".")
+ session.install("mypy", "pytest", "hypothesis", "dirty-equals", "pydantic", ".")
args = ["inline_snapshot", "tests"]
if session.posargs:
args = session.posargs
@@ -46,6 +46,7 @@ def test(session):
"time-machine",
"mypy",
"pyright",
+ "pydantic",
)
cmd = []
diff --git a/pyproject.toml b/pyproject.toml
index 950154d..4c86344 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,3 +90,7 @@ Funding = "https://github.com/sponsors/15r10nk"
Homepage = "https://15r10nk.github.io/inline-snapshot"
Issues = "https://github.com/15r10nk/inline-snapshots/issues"
Repository = "https://github.com/15r10nk/inline-snapshot/"
+
+[tool.pyright]
+venv = "test-3-12"
+venvPath = ".nox"
diff --git a/tests/conftest.py b/tests/conftest.py
index 6d0c49d..24b762a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -98,7 +98,7 @@ def run(self, *flags):
error = False
try:
- exec(compile(filename.read_text("utf-8"), filename, "exec"))
+ exec(compile(filename.read_text("utf-8"), filename, "exec"), {})
except AssertionError:
traceback.print_exc()
error = True
diff --git a/tests/example.py b/tests/example.py
index f6a8356..bc9f4b2 100644
--- a/tests/example.py
+++ b/tests/example.py
@@ -1,9 +1,20 @@
+from __future__ import annotations
+
import os
import platform
import re
import subprocess as sp
from pathlib import Path
from tempfile import TemporaryDirectory
+from typing import Any
+
+import inline_snapshot._external
+from .utils import snapshot_env
+from inline_snapshot import _inline_snapshot
+from inline_snapshot._inline_snapshot import Flags
+from inline_snapshot._rewrite_code import ChangeRecorder
+
+pytest_plugins = "pytester"
ansi_escape = re.compile(
@@ -36,10 +47,75 @@ def write_files(self, dir: Path):
def read_files(self, dir: Path):
return {p.name: p.read_text() for p in dir.iterdir() if p.is_file()}
- def run_pytest(self, *args, changed_files=None, report=None, env={}):
+ def run_inline(
+ self, *flags, changes=None, reported_flags=None, changed_files=None
+ ) -> Example:
+
+ with TemporaryDirectory() as dir:
+ tmp_path = Path(dir)
+
+ self.write_files(tmp_path)
+
+ with snapshot_env():
+ with ChangeRecorder().activate() as recorder:
+ _inline_snapshot._update_flags = Flags({*flags})
+ inline_snapshot._external.storage = (
+ inline_snapshot._external.DiscStorage(tmp_path / ".storage")
+ )
+
+ error = False
+
+ try:
+ for filename in tmp_path.glob("*.py"):
+ globals: dict[str, Any] = {}
+ exec(
+ compile(filename.read_text("utf-8"), filename, "exec"),
+ globals,
+ )
+
+ # run all test_* functions
+ for k, v in globals.items():
+ if k.startswith("test_") and callable(v):
+ v()
+
+ finally:
+ _inline_snapshot._active = False
+
+ # number_snapshots = len(_inline_snapshot.snapshots)
+
+ snapshot_flags = set()
+
+ all_changes = []
+
+ for snapshot in _inline_snapshot.snapshots.values():
+ snapshot_flags |= snapshot._flags
+ snapshot._change()
+ all_changes += snapshot._changes()
+
+ if reported_flags is not None:
+ assert sorted(snapshot_flags) == reported_flags
+
+ # if changes is not None:
+ # assert all_changes == changes
+
+ recorder.fix_all()
+
+ if changed_files is not None:
+ current_files = {}
+
+ for name, content in sorted(self.read_files(tmp_path).items()):
+ if name not in self.files or self.files[name] != content:
+ current_files[name] = content
+ assert changed_files == current_files
+
+ return Example(self.read_files(tmp_path))
+
+ def run_pytest(
+ self, *args, changed_files=None, report=None, env={}, returncode=None
+ ) -> Example:
with TemporaryDirectory() as dir:
- dir = Path(dir)
- self.write_files(dir)
+ tmp_path = Path(dir)
+ self.write_files(tmp_path)
cmd = ["pytest", *args]
@@ -54,7 +130,7 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}):
command_env.update(env)
- result = sp.run(cmd, cwd=dir, capture_output=True, env=command_env)
+ result = sp.run(cmd, cwd=tmp_path, capture_output=True, env=command_env)
print("run>", *cmd)
print("stdout:")
@@ -62,9 +138,12 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}):
print("stderr:")
print(result.stderr.decode())
+ if returncode is not None:
+ assert result.returncode == returncode
+
if report is not None:
- new_report = []
+ report_list = []
record = False
for line in result.stdout.decode().splitlines():
line = line.strip()
@@ -72,28 +151,28 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}):
record = False
if record and line:
- new_report.append(line)
+ report_list.append(line)
if line.startswith("====") and "inline snapshot" in line:
record = True
- new_report = "\n".join(new_report)
+ report_str = "\n".join(report_list)
- new_report = ansi_escape.sub("", new_report)
+ report_str = ansi_escape.sub("", report_str)
# fix windows problems
- new_report = new_report.replace("\u2500", "-")
- new_report = new_report.replace("\r", "")
- new_report = new_report.replace(" \n", " ⏎\n")
+ report_str = report_str.replace("\u2500", "-")
+ report_str = report_str.replace("\r", "")
+ report_str = report_str.replace(" \n", " ⏎\n")
- assert new_report == report
+ assert report_str == report
if changed_files is not None:
current_files = {}
- for name, content in sorted(self.read_files(dir).items()):
+ for name, content in sorted(self.read_files(tmp_path).items()):
if name not in self.files or self.files[name] != content:
current_files[name] = content
assert changed_files == current_files
- return Example(self.read_files(dir))
+ return Example(self.read_files(tmp_path))
diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py
new file mode 100644
index 0000000..6b8e250
--- /dev/null
+++ b/tests/test_code_repr.py
@@ -0,0 +1,359 @@
+import dataclasses
+from dataclasses import dataclass
+
+import pytest
+
+from .example import Example
+from inline_snapshot import HasRepr
+from inline_snapshot import snapshot
+from inline_snapshot._code_repr import code_repr
+
+
+def test_enum(check_update):
+
+ assert (
+ check_update(
+ """
+from enum import Enum
+
+class color(Enum):
+ val="val"
+
+
+assert [color.val] == snapshot()
+
+ """,
+ flags="create",
+ )
+ == snapshot(
+ """\
+
+from enum import Enum
+
+class color(Enum):
+ val="val"
+
+
+assert [color.val] == snapshot([color.val])
+
+"""
+ )
+ )
+
+
+def test_snapshot_generates_hasrepr():
+
+ Example(
+ """\
+from inline_snapshot import snapshot
+
+class Thing:
+ def __repr__(self):
+ return ""
+
+def test_thing():
+ assert Thing() == snapshot()
+
+ """
+ ).run_pytest(
+ "--inline-snapshot=create",
+ returncode=snapshot(0),
+ changed_files=snapshot(
+ {
+ "test_something.py": """\
+from inline_snapshot import snapshot
+
+from inline_snapshot import HasRepr
+
+class Thing:
+ def __repr__(self):
+ return ""
+
+def test_thing():
+ assert Thing() == snapshot(HasRepr(Thing, ""))
+
+ \
+"""
+ }
+ ),
+ ).run_pytest(
+ "--inline-snapshot=disable", returncode=0
+ ).run_pytest(
+ returncode=0
+ )
+
+
+def test_hasrepr_type():
+ assert 5 == HasRepr(int, "5")
+ assert not "a" == HasRepr(int, "5")
+ assert not HasRepr(float, "nan") == HasRepr(str, "nan")
+ assert not HasRepr(str, "a") == HasRepr(str, "b")
+
+
+def test_enum_in_dataclass(check_update):
+
+ assert (
+ check_update(
+ """
+from enum import Enum
+from dataclasses import dataclass
+
+class color(Enum):
+ red="red"
+ blue="blue"
+
+@dataclass
+class container:
+ bg: color=color.red
+ fg: color=color.blue
+
+assert container(bg=color.red,fg=color.red) == snapshot()
+
+ """,
+ flags="create",
+ )
+ == snapshot(
+ """\
+
+from enum import Enum
+from dataclasses import dataclass
+
+class color(Enum):
+ red="red"
+ blue="blue"
+
+@dataclass
+class container:
+ bg: color=color.red
+ fg: color=color.blue
+
+assert container(bg=color.red,fg=color.red) == snapshot(container(fg=color.red))
+
+"""
+ )
+ )
+
+
+def test_dataclass_field_repr(check_update):
+
+ Example(
+ """\
+from inline_snapshot import snapshot
+from dataclasses import dataclass,field
+
+@dataclass
+class container:
+ a: int
+ b: int = field(default=5,repr=False)
+
+assert container(a=1,b=5) == snapshot()
+"""
+ ).run_inline(
+ "create",
+ changed_files=snapshot(
+ {
+ "test_something.py": """\
+from inline_snapshot import snapshot
+from dataclasses import dataclass,field
+
+@dataclass
+class container:
+ a: int
+ b: int = field(default=5,repr=False)
+
+assert container(a=1,b=5) == snapshot(container(a=1))
+"""
+ }
+ ),
+ ).run_inline()
+
+
+def test_flag(check_update):
+
+ assert (
+ check_update(
+ """
+from enum import Flag, auto
+
+class Color(Flag):
+ red = auto()
+ green = auto()
+ blue = auto()
+
+assert Color.red | Color.blue == snapshot()
+
+ """,
+ flags="create",
+ )
+ == snapshot(
+ """\
+
+from enum import Flag, auto
+
+class Color(Flag):
+ red = auto()
+ green = auto()
+ blue = auto()
+
+assert Color.red | Color.blue == snapshot(Color.red | Color.blue)
+
+"""
+ )
+ )
+
+
+def test_type(check_update):
+
+ assert (
+ check_update(
+ """\
+class Color:
+ pass
+
+assert [Color,int] == snapshot()
+
+ """,
+ flags="create",
+ )
+ == snapshot(
+ """\
+class Color:
+ pass
+
+assert [Color,int] == snapshot([Color, int])
+
+"""
+ )
+ )
+
+
+def test_qualname():
+
+ Example(
+ """\
+from enum import Enum
+from inline_snapshot import snapshot
+
+
+class Namespace:
+ class Color(Enum):
+ red="red"
+
+assert Namespace.Color.red == snapshot()
+
+ """
+ ).run_inline(
+ "create",
+ changed_files=snapshot(
+ {
+ "test_something.py": """\
+from enum import Enum
+from inline_snapshot import snapshot
+
+
+class Namespace:
+ class Color(Enum):
+ red="red"
+
+assert Namespace.Color.red == snapshot(Namespace.Color.red)
+
+ \
+"""
+ }
+ ),
+ ).run_inline()
+
+
+from collections import (
+ Counter,
+ OrderedDict,
+ UserDict,
+ UserList,
+ defaultdict,
+ namedtuple,
+)
+from typing import NamedTuple
+
+A = namedtuple("A", "a,b", defaults=[0])
+B = namedtuple("B", "a,b", defaults=[0, 0])
+
+
+class C(NamedTuple):
+ a: int
+ b: int = 0
+ c: int = 0
+
+
+@dataclass
+class Dataclass:
+ a: int
+ b: int = dataclasses.field(default=0)
+ c: list = dataclasses.field(default_factory=lambda: [])
+
+
+default_dict = defaultdict(list)
+default_dict[5].append(2)
+default_dict[3].append(1)
+
+
+@pytest.mark.parametrize(
+ "d",
+ [
+ frozenset(["a"]),
+ frozenset(),
+ {"a"},
+ set(),
+ list(),
+ ["a"],
+ {},
+ {1: "1"},
+ (),
+ (1,),
+ (1, 2, 3),
+ A(1, 2),
+ A(1),
+ A(0, 0),
+ B(),
+ B(b=5),
+ C(1),
+ C(1, 2),
+ C(a=1, c=2),
+ Dataclass(a=0, b=0, c=[]),
+ Dataclass(a=1, b=2, c=[3]),
+ default_dict,
+ OrderedDict({1: 2, 3: 4}),
+ UserDict({1: 2}),
+ UserList([1, 2]),
+ ],
+)
+def test_datatypes(d):
+ code = code_repr(d)
+ print("repr: ", repr(d))
+ print("code_repr:", code)
+ assert d == eval(code)
+
+
+def test_datatypes_explicit():
+ assert code_repr(C(a=1, c=2)) == snapshot("C(a=1, c=2)")
+ assert code_repr(B(b=5)) == snapshot("B(b=5)")
+ assert code_repr(B(b=0)) == snapshot("B()")
+
+ assert code_repr(Dataclass(a=0, b=0, c=[])) == snapshot("Dataclass(a=0)")
+ assert code_repr(Dataclass(a=1, b=2, c=[3])) == snapshot(
+ "Dataclass(a=1, b=2, c=[3])"
+ )
+ assert code_repr(Counter([1, 1, 1, 2])) == snapshot("Counter({1: 3, 2: 1})")
+
+ assert code_repr(default_dict) == snapshot("defaultdict(list, {5: [2], 3: [1]})")
+
+
+def test_tuple():
+
+ class FakeTuple(tuple):
+ def __init__(self):
+ self._fields = 5
+
+ def __repr__(self):
+ return "FakeTuple()"
+
+ assert code_repr(FakeTuple()) == snapshot("FakeTuple()")
diff --git a/tests/test_example.py b/tests/test_example.py
index 9f24116..dbcd93e 100644
--- a/tests/test_example.py
+++ b/tests/test_example.py
@@ -1,15 +1,29 @@
from .example import Example
+from inline_snapshot import snapshot
-def test_diff_multiple_files():
+def test_example():
- Example(
- """
+ e = Example(
+ {
+ "test_a.py": """
from inline_snapshot import snapshot
def test_a():
assert 1==snapshot(2)
""",
- ).run_pytest(
+ "test_b.py": "1+1",
+ },
+ )
+
+ e.run_pytest(
"--inline-snapshot=create,fix",
)
+
+ e.run_inline(
+ "fix",
+ reported_flags=snapshot(["fix"]),
+ ).run_inline(
+ "fix",
+ changed_files=snapshot({}),
+ )
diff --git a/tests/test_hasrepr.py b/tests/test_hasrepr.py
new file mode 100644
index 0000000..1ec0899
--- /dev/null
+++ b/tests/test_hasrepr.py
@@ -0,0 +1,12 @@
+from inline_snapshot._code_repr import HasRepr
+
+
+def test_hasrepr_eq():
+
+ assert float("nan") == HasRepr(float, "nan")
+
+ class Thing:
+ def __repr__(self):
+ return ""
+
+ assert Thing() == HasRepr(Thing, "")
diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py
index b2486b9..e9b738f 100644
--- a/tests/test_inline_snapshot.py
+++ b/tests/test_inline_snapshot.py
@@ -787,7 +787,7 @@ class Thing:
def __repr__(self):
return "+++"
-assert Thing() == snapshot()
+assert Thing() == snapshot(HasRepr(Thing, "+++"))
"""
)
)
diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py
new file mode 100644
index 0000000..c6c55a7
--- /dev/null
+++ b/tests/test_pydantic.py
@@ -0,0 +1,42 @@
+from .example import Example
+from inline_snapshot import snapshot
+
+
+def test_pydantic_repr():
+
+ Example(
+ """
+from pydantic import BaseModel
+from inline_snapshot import snapshot
+
+class M(BaseModel):
+ size:int
+ name:str
+ age:int=4
+
+def test_pydantic():
+ assert M(size=5,name="Tom")==snapshot()
+
+ """
+ ).run_inline(
+ "create",
+ changed_files=snapshot(
+ {
+ "test_something.py": """\
+
+from pydantic import BaseModel
+from inline_snapshot import snapshot
+
+class M(BaseModel):
+ size:int
+ name:str
+ age:int=4
+
+def test_pydantic():
+ assert M(size=5,name="Tom")==snapshot(M(name="Tom", size=5))
+
+ \
+"""
+ }
+ ),
+ ).run_inline()
From e8cdfdd3460355b25d4bd935a5acb3ff46afb890 Mon Sep 17 00:00:00 2001
From: Frank Hoffmann <15r10nk-git@polarbit.de>
Date: Sat, 6 Jul 2024 14:29:56 +0200
Subject: [PATCH 2/4] docs: Readme
---
README.md | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 54 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 15cfbe7..d1f6c6d 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ pip install inline-snapshot
- **Preserved Black Formatting:** Retains formatting consistency with Black formatting.
- **External File Storage:** Store snapshots externally using `outsource(data)`.
- **Seamless Pytest Integration:** Integrated seamlessly with pytest for effortless testing.
+- **Customizable:** code generation can be customized with [@customize_repr](https://15r10nk.github.io/inline-snapshot/customize_repr)
- **Comprehensive Documentation:** Access detailed [documentation](https://15r10nk.github.io/inline-snapshot/) for complete guidance.
@@ -51,7 +52,7 @@ def test_something():
You can now run the tests and record the correct values.
```
-$ pytest --inline-snapshot=create
+$ pytest --inline-snapshot=review
```
@@ -63,7 +64,8 @@ def test_something():
assert 1548 * 18489 == snapshot(28620972)
```
-inline-snapshot provides more advanced features like:
+The following examples show how you can use inline-snapshot in your tests. Take a look at the
+[documentation](https://15r10nk.github.io/inline-snapshot/) if you want to know more.
```python
@@ -100,6 +102,56 @@ def test_something():
assert outsource("large string\n" * 1000) == snapshot(
external("8bf10bdf2c30*.txt")
)
+
+ assert "generates\nmultiline\nstrings" == snapshot(
+ """\
+generates
+multiline
+strings\
+"""
+ )
+```
+
+
+`snapshot()` can also be used as parameter for functions:
+
+
+```python
+from inline_snapshot import snapshot
+import subprocess as sp
+import sys
+
+
+def run_python(cmd, stdout=None, stderr=None):
+ result = sp.run([sys.executable, "-c", cmd], capture_output=True)
+ if stdout is not None:
+ assert result.stdout.decode() == stdout
+ if stderr is not None:
+ assert result.stderr.decode() == stderr
+
+
+def test_cmd():
+ run_python(
+ "print('hello world')",
+ stdout=snapshot(
+ """\
+hello world
+"""
+ ),
+ stderr=snapshot(""),
+ )
+
+ run_python(
+ "1/0",
+ stdout=snapshot(""),
+ stderr=snapshot(
+ """\
+Traceback (most recent call last):
+ File "", line 1, in
+ZeroDivisionError: division by zero
+"""
+ ),
+ )
```
From a278fa2224eb25e24386497c8f114c3d8d5dc6ca Mon Sep 17 00:00:00 2001
From: Frank Hoffmann <15r10nk-git@polarbit.de>
Date: Sat, 6 Jul 2024 11:15:32 +0200
Subject: [PATCH 3/4] feat: check if the result of copy.deepcopy() is equal to
the copied value
---
inline_snapshot/_exceptions.py | 2 ++
inline_snapshot/_inline_snapshot.py | 41 +++++++++++++++++++++--------
inline_snapshot/_sentinels.py | 3 ++-
tests/example.py | 4 ++-
tests/test_code_repr.py | 12 +++++++++
tests/test_inline_snapshot.py | 10 +++++++
tests/test_pytest_plugin.py | 33 +++++++++++++++++++++++
7 files changed, 92 insertions(+), 13 deletions(-)
create mode 100644 inline_snapshot/_exceptions.py
diff --git a/inline_snapshot/_exceptions.py b/inline_snapshot/_exceptions.py
new file mode 100644
index 0000000..24bae56
--- /dev/null
+++ b/inline_snapshot/_exceptions.py
@@ -0,0 +1,2 @@
+class UsageError(Exception):
+ pass
diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py
index 3b5c237..84bb253 100644
--- a/inline_snapshot/_inline_snapshot.py
+++ b/inline_snapshot/_inline_snapshot.py
@@ -22,6 +22,8 @@
from ._change import DictInsert
from ._change import ListInsert
from ._change import Replace
+from ._code_repr import code_repr
+from ._exceptions import UsageError
from ._format import format_code
from ._sentinels import undefined
from ._utils import ignore_tokens
@@ -163,7 +165,7 @@ def _change(self, cls):
self.__class__ = cls
def _new_code(self):
- return ""
+ assert False
def _get_changes(self) -> Iterator[Change]:
# generic fallback
@@ -224,6 +226,23 @@ def update_allowed(value):
return not isinstance(value, dirty_equals.DirtyEquals)
+def clone(obj):
+ new = copy.deepcopy(obj)
+ if not obj == new:
+ raise UsageError(
+ f"""\
+inline-snapshot uses `copy.deepcopy` to copy objects,
+but the copied object is not equal to the original one:
+
+original: {code_repr(obj)}
+copied: {code_repr(new)}
+
+Please fix the way your object is copied or your __eq__ implementation.
+"""
+ )
+ return new
+
+
class EqValue(GenericValue):
_current_op = "x == snapshot"
@@ -232,10 +251,8 @@ def __eq__(self, other):
if self._old_value is undefined:
_missing_values += 1
- other = copy.deepcopy(other)
-
if self._new_value is undefined:
- self._new_value = other
+ self._new_value = clone(other)
return self._visible_value() == other
@@ -391,13 +408,12 @@ def _generic_cmp(self, other):
global _missing_values
if self._old_value is undefined:
_missing_values += 1
- other = copy.deepcopy(other)
if self._new_value is undefined:
- self._new_value = other
+ self._new_value = clone(other)
else:
self._new_value = (
- self._new_value if self.cmp(self._new_value, other) else other
+ self._new_value if self.cmp(self._new_value, other) else clone(other)
)
return self.cmp(self._visible_value(), other)
@@ -481,13 +497,11 @@ def __contains__(self, item):
if self._old_value is undefined:
_missing_values += 1
- item = copy.deepcopy(item)
-
if self._new_value is undefined:
- self._new_value = [item]
+ self._new_value = [clone(item)]
else:
if item not in self._new_value:
- self._new_value.append(item)
+ self._new_value.append(clone(item))
if ignore_old_value() or self._old_value is undefined:
return True
@@ -714,7 +728,12 @@ def __init__(self, value, expr):
self._uses_externals = []
def _changes(self):
+
if self._value._old_value is undefined:
+
+ if self._value._new_value is undefined:
+ return
+
new_code = self._value._new_code()
yield CallArg(
diff --git a/inline_snapshot/_sentinels.py b/inline_snapshot/_sentinels.py
index 8f34f7f..8c0d8e1 100644
--- a/inline_snapshot/_sentinels.py
+++ b/inline_snapshot/_sentinels.py
@@ -1,6 +1,7 @@
# sentinels
class Undefined:
- pass
+ def __repr__(self):
+ return "undefined"
undefined = Undefined()
diff --git a/tests/example.py b/tests/example.py
index bc9f4b2..50d168e 100644
--- a/tests/example.py
+++ b/tests/example.py
@@ -48,7 +48,7 @@ def read_files(self, dir: Path):
return {p.name: p.read_text() for p in dir.iterdir() if p.is_file()}
def run_inline(
- self, *flags, changes=None, reported_flags=None, changed_files=None
+ self, *flags, changes=None, reported_flags=None, changed_files=None, raises=None
) -> Example:
with TemporaryDirectory() as dir:
@@ -77,6 +77,8 @@ def run_inline(
for k, v in globals.items():
if k.startswith("test_") and callable(v):
v()
+ except Exception as e:
+ assert raises == f"{type(e).__name__}:\n" + str(e)
finally:
_inline_snapshot._active = False
diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py
index 6b8e250..3a2871c 100644
--- a/tests/test_code_repr.py
+++ b/tests/test_code_repr.py
@@ -7,6 +7,7 @@
from inline_snapshot import HasRepr
from inline_snapshot import snapshot
from inline_snapshot._code_repr import code_repr
+from inline_snapshot._sentinels import undefined
def test_enum(check_update):
@@ -51,6 +52,11 @@ class Thing:
def __repr__(self):
return ""
+ def __eq__(self,other):
+ if not isinstance(other,Thing):
+ return NotImplemented
+ return True
+
def test_thing():
assert Thing() == snapshot()
@@ -69,6 +75,11 @@ class Thing:
def __repr__(self):
return ""
+ def __eq__(self,other):
+ if not isinstance(other,Thing):
+ return NotImplemented
+ return True
+
def test_thing():
assert Thing() == snapshot(HasRepr(Thing, ""))
@@ -324,6 +335,7 @@ class Dataclass:
OrderedDict({1: 2, 3: 4}),
UserDict({1: 2}),
UserList([1, 2]),
+ undefined,
],
)
def test_datatypes(d):
diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py
index e9b738f..95cfb80 100644
--- a/tests/test_inline_snapshot.py
+++ b/tests/test_inline_snapshot.py
@@ -777,6 +777,11 @@ class Thing:
def __repr__(self):
return "+++"
+ def __eq__(self,other):
+ if not isinstance(other,Thing):
+ return NotImplemented
+ return True
+
assert Thing() == snapshot()
""",
flags="create",
@@ -787,6 +792,11 @@ class Thing:
def __repr__(self):
return "+++"
+ def __eq__(self,other):
+ if not isinstance(other,Thing):
+ return NotImplemented
+ return True
+
assert Thing() == snapshot(HasRepr(Thing, "+++"))
"""
)
diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py
index 984ff26..ee321ac 100644
--- a/tests/test_pytest_plugin.py
+++ b/tests/test_pytest_plugin.py
@@ -608,3 +608,36 @@ def test_a():
"""
),
)
+
+
+def test_equal_check():
+
+ Example(
+ {
+ "test_a.py": """
+from inline_snapshot import snapshot
+
+class Thing:
+ def __repr__(self):
+ return "Thing()"
+
+def test_a():
+ assert Thing()==snapshot()
+ """,
+ }
+ ).run_inline(
+ "--inline-snapshot=create",
+ changed_files=snapshot({}),
+ raises=snapshot(
+ """\
+UsageError:
+inline-snapshot uses `copy.deepcopy` to copy objects,
+but the copied object is not equal to the original one:
+
+original: Thing()
+copied: Thing()
+
+Please fix the way your object is copied or your __eq__ implementation.
+"""
+ ),
+ )
From fde435da46e0adf9c79de482306f2c502467aa33 Mon Sep 17 00:00:00 2001
From: Frank Hoffmann <15r10nk-git@polarbit.de>
Date: Sat, 6 Jul 2024 15:22:07 +0200
Subject: [PATCH 4/4] test: skip doc tests on windows
---
tests/test_docs.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/tests/test_docs.py b/tests/test_docs.py
index 83e405d..7e69a83 100644
--- a/tests/test_docs.py
+++ b/tests/test_docs.py
@@ -1,3 +1,4 @@
+import platform
import re
import textwrap
from pathlib import Path
@@ -7,6 +8,10 @@
import inline_snapshot._inline_snapshot
+@pytest.mark.skipif(
+ platform.system() == "Windows",
+ reason="\r in stdout can cause problems in snapshot strings",
+)
@pytest.mark.parametrize(
"file",
[