From 5352ea6a394ab1f8791bd4d2505594a3084e23d5 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 3 May 2024 23:30:12 +0200 Subject: [PATCH 1/4] feat: support for `enum.Enum`, `enum.Flag`, `type` and omitting of default values (#73) --- docs/code_generation.md | 105 ++++++++ docs/customize_repr.md | 114 +++++++++ docs/index.md | 103 -------- inline_snapshot/__init__.py | 4 +- inline_snapshot/_code_repr.py | 243 +++++++++++++++++++ inline_snapshot/_inline_snapshot.py | 4 - inline_snapshot/_utils.py | 4 +- inline_snapshot/pytest_plugin.py | 11 +- mkdocs.yml | 2 + noxfile.py | 3 +- pyproject.toml | 4 + tests/conftest.py | 2 +- tests/example.py | 107 +++++++-- tests/test_code_repr.py | 359 ++++++++++++++++++++++++++++ tests/test_example.py | 22 +- tests/test_hasrepr.py | 12 + tests/test_inline_snapshot.py | 2 +- tests/test_pydantic.py | 42 ++++ 18 files changed, 1012 insertions(+), 131 deletions(-) create mode 100644 docs/code_generation.md create mode 100644 docs/customize_repr.md create mode 100644 inline_snapshot/_code_repr.py create mode 100644 tests/test_code_repr.py create mode 100644 tests/test_hasrepr.py create mode 100644 tests/test_pydantic.py diff --git a/docs/code_generation.md b/docs/code_generation.md new file mode 100644 index 0000000..1a8353a --- /dev/null +++ b/docs/code_generation.md @@ -0,0 +1,105 @@ + + +You can use almost any python datatype and also complex values like `datatime.date`, because `repr()` is used to convert the values to source code. +The default `__repr__()` behaviour can be [customized](customize_repr.md). +It might be necessary to import the right modules to match the `repr()` output. + +=== "original code" + + ```python + from inline_snapshot import snapshot + import datetime + + + def something(): + return { + "name": "hello", + "one number": 5, + "numbers": list(range(10)), + "sets": {1, 2, 15}, + "datetime": datetime.date(1, 2, 22), + "complex stuff": 5j + 3, + "bytes": b"byte abc\n\x16", + } + + + def test_something(): + assert something() == snapshot() + ``` +=== "--inline-snapshot=create" + + ```python + from inline_snapshot import snapshot + import datetime + + + def something(): + return { + "name": "hello", + "one number": 5, + "numbers": list(range(10)), + "sets": {1, 2, 15}, + "datetime": datetime.date(1, 2, 22), + "complex stuff": 5j + 3, + "bytes": b"byte abc\n\x16", + } + + + def test_something(): + assert something() == snapshot( + { + "name": "hello", + "one number": 5, + "numbers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "sets": {1, 2, 15}, + "datetime": datetime.date(1, 2, 22), + "complex stuff": (3 + 5j), + "bytes": b"byte abc\n\x16", + } + ) + ``` + +The code is generated in the following way: + +1. The value is copied with `value = copy.deepcopy(value)` and it is checked if the copied value is equal to the original value. +2. The code is generated with `repr(value)` (which can be [customized](customize_repr.md)) +3. Strings which contain newlines are converted to triple quoted strings. + + !!! note + Missing newlines at start or end are escaped (since 0.4.0). + + === "original code" + + ``` python + def test_something(): + assert "first line\nsecond line" == snapshot( + """first line + second line""" + ) + ``` + + === "--inline-snapshot=update" + + ``` python + def test_something(): + assert "first line\nsecond line" == snapshot( + """\ + first line + second line\ + """ + ) + ``` + + +4. The code is formatted with black. + + +5. The whole file is formatted with black if it was formatted before. + + !!! note + The black formatting of the whole file could not work for the following reasons: + + 1. black is configured with cli arguments and not in a configuration file.
+ **Solution:** configure black in a [configuration file](https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file) + 2. inline-snapshot uses a different black version.
+ **Solution:** specify which black version inline-snapshot should use by adding black with a specific version to your dependencies. diff --git a/docs/customize_repr.md b/docs/customize_repr.md new file mode 100644 index 0000000..5f7e5f0 --- /dev/null +++ b/docs/customize_repr.md @@ -0,0 +1,114 @@ + + + +`repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type. +Here are some examples: +```pycon +>>> repr(int) +"" + +>>> from enum import Enum +>>> E = Enum("E", ["a", "b"]) +>>> repr(E.a) +'' +``` + +`customize_repr` can be used to overwrite the default `repr()` behaviour. + +The implementation for `Enum` looks like this: + +```python exec="1" result="python" +print('--8<-- "inline_snapshot/_code_repr.py:Enum"') +``` + +This implementation is then used by inline-snapshot if `repr()` is called during the code generation, but not in normal code. + + +```python +from enum import Enum + + +def test_enum(): + E = Enum("E", ["a", "b"]) + + # normal repr + assert repr(E.a) == "" + + # the special implementation to convert the Enum into a code + assert E.a == snapshot(E.a) +``` + +## builtin datatypes + +inline-snapshot comes with a special implementation for the following types: +```python exec="1" +from inline_snapshot._code_repr import code_repr_dispatch, code_repr + +for name, obj in sorted( + ( + getattr( + obj, "_inline_snapshot_name", f"{obj.__module__}.{obj.__qualname__}" + ), + obj, + ) + for obj in code_repr_dispatch.registry.keys() +): + if obj is not object: + print(f"- `{name}`") +``` + +Container types like `dict` or `dataclass` need a special implementation because it is necessary that the implementation uses `repr()` for the child elements. + +```python exec="1" result="python" +print('--8<-- "inline_snapshot/_code_repr.py:list"') +``` + +!!! note + using `#!python f"{obj!r}"` or `#!c PyObject_Repr()` will not work, because inline-snapshot replaces `#!python builtins.repr` during the code generation. + +## customize + +You can also use `repr()` inside `__repr__()`, if you want to make your own type compatible with inline-snapshot. + + +```python +from enum import Enum + + +class Pair: + def __init__(self, a, b): + self.a = a + self.b = b + + def __repr__(self): + # this would not work + # return f"Pair({self.a!r}, {self.b!r})" + + # you have to use repr() + return f"Pair({repr(self.a)}, {repr(self.b)})" + + def __eq__(self, other): + if not isinstance(other, Pair): + return NotImplemented + return self.a == other.a and self.b == other.b + + +def test_enum(): + E = Enum("E", ["a", "b"]) + + # the special repr implementation is used recursive here + # to convert every Enum to the correct representation + assert Pair(E.a, [E.b]) == snapshot(Pair(E.a, [E.b])) +``` + +you can also customize the representation of datatypes in other libraries: + +``` python +from inline_snapshot import customize_repr +from other_lib import SomeType + + +@customize_repr +def _(value: SomeType): + return f"SomeType(x={repr(value.x)})" +``` diff --git a/docs/index.md b/docs/index.md index f9d1a79..9d6d76a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -201,108 +201,5 @@ or passed as an argument to a function: -## Code generation - -You can use almost any python datatype and also complex values like `datatime.date`, because `repr()` is used to convert the values to a source code. -It might be necessary to import the right modules to match the `repr()` output. - -=== "original code" - - ```python - from inline_snapshot import snapshot - import datetime - - - def something(): - return { - "name": "hello", - "one number": 5, - "numbers": list(range(10)), - "sets": {1, 2, 15}, - "datetime": datetime.date(1, 2, 22), - "complex stuff": 5j + 3, - "bytes": b"fglecg\n\x16", - } - - - def test_something(): - assert something() == snapshot() - ``` -=== "--inline-snapshot=create" - - ```python - from inline_snapshot import snapshot - import datetime - - - def something(): - return { - "name": "hello", - "one number": 5, - "numbers": list(range(10)), - "sets": {1, 2, 15}, - "datetime": datetime.date(1, 2, 22), - "complex stuff": 5j + 3, - "bytes": b"fglecg\n\x16", - } - - - def test_something(): - assert something() == snapshot( - { - "name": "hello", - "one number": 5, - "numbers": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "sets": {1, 2, 15}, - "datetime": datetime.date(1, 2, 22), - "complex stuff": (3 + 5j), - "bytes": b"fglecg\n\x16", - } - ) - ``` - -The code is generated in the following way: - -1. The value is copied with `value = copy.deepcopy(value)` -2. The code is generated with `repr(value)` -3. Strings which contain newlines are converted to triple quoted strings. - - !!! note - Missing newlines at start or end are escaped (since 0.4.0). - - === "original code" - - ``` python - def test_something(): - assert "first line\nsecond line" == snapshot( - """first line - second line""" - ) - ``` - - === "--inline-snapshot=update" - - ``` python - def test_something(): - assert "first line\nsecond line" == snapshot( - """\ - first line - second line\ - """ - ) - ``` - - -4. The code is formatted with black. - - !!! note - The black formatting of the whole file could not work for the following reasons: - - 1. black is configured with cli arguments and not in a configuration file.
- **Solution:** configure black in a [configuration file](https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file) - 2. inline-snapshot uses a different black version.
- **Solution:** specify which black version inline-snapshot should use by adding black with a specific version to your dependencies. - -5. The whole file is formatted with black if it was formatted before. --8<-- "README.md:Feedback" diff --git a/inline_snapshot/__init__.py b/inline_snapshot/__init__.py index af0b273..c1fdf7c 100644 --- a/inline_snapshot/__init__.py +++ b/inline_snapshot/__init__.py @@ -1,7 +1,9 @@ +from ._code_repr import customize_repr +from ._code_repr import HasRepr from ._external import external from ._external import outsource from ._inline_snapshot import snapshot -__all__ = ["snapshot", "external", "outsource"] +__all__ = ["snapshot", "external", "outsource", "customize_repr", "HasRepr"] __version__ = "0.10.2" diff --git a/inline_snapshot/_code_repr.py b/inline_snapshot/_code_repr.py new file mode 100644 index 0000000..9a5dcd3 --- /dev/null +++ b/inline_snapshot/_code_repr.py @@ -0,0 +1,243 @@ +import ast +from abc import ABC +from collections import defaultdict +from dataclasses import fields +from dataclasses import is_dataclass +from dataclasses import MISSING +from enum import Enum +from enum import Flag +from functools import singledispatch +from unittest import mock + +real_repr = repr + + +class HasRepr: + """This class is used for objects where `__repr__()` returns an non- + parsable representation. + + HasRepr uses the type and repr of the value for equal comparison. + + You can change `__repr__()` to return valid python code or use + `@customize_repr` to customize repr which is used by inline- + snapshot. + """ + + def __init__(self, type, str_repr: str) -> None: + self._type = type + self._str_repr = str_repr + + def __repr__(self): + return f"HasRepr({self._type.__qualname__}, {self._str_repr!r})" + + def __eq__(self, other): + if isinstance(other, HasRepr): + if other._type is not self._type: + return False + else: + if type(other) is not self._type: + return False + + other_repr = code_repr(other) + return other_repr == self._str_repr or other_repr == repr(self) + + +def used_hasrepr(tree): + return [ + n + for n in ast.walk(tree) + if isinstance(n, ast.Call) + and isinstance(n.func, ast.Name) + and n.func.id == "HasRepr" + and len(n.args) == 2 + ] + + +@singledispatch +def code_repr_dispatch(value): + return real_repr(value) + + +def customize_repr(f): + """Register a funtion which should be used to get the code representation + of a object. + + ```python + @customize_repr + def _(obj: MyCustomClass): + return f"MyCustomClass(attr={repr(obj.attr)})" + ``` + + it is important to use `repr()` inside the implementation, because it is mocked to return the code represenation + + you dont have to provide a custom implementation if: + * __repr__() of your class returns a valid code representation, + * and __repr__() uses `repr()` to get the representaion of the child objects + """ + code_repr_dispatch.register(f) + + +def code_repr(obj): + with mock.patch("builtins.repr", code_repr): + result = code_repr_dispatch(obj) + + try: + ast.parse(result) + except SyntaxError: + return real_repr(HasRepr(type(obj), result)) + + return result + + +# -8<- [start:Enum] +@customize_repr +def _(value: Enum): + return f"{type(value).__qualname__}.{value.name}" + + +# -8<- [end:Enum] + + +@customize_repr +def _(value: Flag): + name = type(value).__qualname__ + return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) + + +# -8<- [start:list] +@customize_repr +def _(value: list): + return "[" + ", ".join(map(repr, value)) + "]" + + +# -8<- [end:list] + + +class OnlyTuple(ABC): + _inline_snapshot_name = "builtins.tuple" + + @classmethod + def __subclasshook__(cls, t): + return t is tuple + + +@customize_repr +def _(value: OnlyTuple): + assert isinstance(value, tuple) + if len(value) == 1: + return f"({repr(value[0])},)" + return "(" + ", ".join(map(repr, value)) + ")" + + +class IsNamedTuple(ABC): + _inline_snapshot_name = "namedtuple" + + _fields: tuple + _field_defaults: dict + + @classmethod + def __subclasshook__(cls, t): + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + + +@customize_repr +def _(value: IsNamedTuple): + params = ", ".join( + f"{field}={repr(getattr(value,field))}" + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + ) + return f"{repr(type(value))}({params})" + + +@customize_repr +def _(value: set): + if len(value) == 0: + return "set()" + + return "{" + ", ".join(map(repr, value)) + "}" + + +@customize_repr +def _(value: frozenset): + if len(value) == 0: + return "frozenset()" + + return "frozenset({" + ", ".join(map(repr, value)) + "})" + + +@customize_repr +def _(value: dict): + result = ( + "{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + "}" + ) + + if type(value) is not dict: + result = f"{repr(type(value))}({result})" + + return result + + +@customize_repr +def _(value: defaultdict): + return f"defaultdict({repr(value.default_factory)}, {repr(dict(value))})" + + +@customize_repr +def _(value: type): + return value.__qualname__ + + +class IsDataclass(ABC): + _inline_snapshot_name = "dataclass" + + @classmethod + def __subclasshook__(cls, subclass): + return is_dataclass(subclass) + + +@customize_repr +def _(value: IsDataclass): + attrs = [] + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + + if field.default != MISSING and field.default == field_value: + continue + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + continue + + attrs.append(f"{field.name}={repr(field_value)}") + + return f"{repr(type(value))}({', '.join(attrs)})" + + +try: + from pydantic import BaseModel +except ImportError: # pragma: no cover + pass +else: + + @customize_repr + def _(model: BaseModel): + return ( + type(model).__qualname__ + + "(" + + ", ".join( + e + "=" + repr(getattr(model, e)) + for e in sorted(model.__pydantic_fields_set__) + ) + + ")" + ) diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py index 0bdf8eb..3b5c237 100644 --- a/inline_snapshot/_inline_snapshot.py +++ b/inline_snapshot/_inline_snapshot.py @@ -716,10 +716,6 @@ def __init__(self, value, expr): def _changes(self): if self._value._old_value is undefined: new_code = self._value._new_code() - try: - ast.parse(new_code) - except: - return yield CallArg( "create", diff --git a/inline_snapshot/_utils.py b/inline_snapshot/_utils.py index 2860e16..a5ecf11 100644 --- a/inline_snapshot/_utils.py +++ b/inline_snapshot/_utils.py @@ -4,6 +4,8 @@ import tokenize from collections import namedtuple +from ._code_repr import code_repr + def normalize_strings(token_sequence): """Normalize string concattenanion. @@ -118,7 +120,7 @@ def __eq__(self, other): def value_to_token(value): - input = io.StringIO(repr(value)) + input = io.StringIO(code_repr(value)) def map_string(tok): """Convert strings with newlines in triple quoted strings.""" diff --git a/inline_snapshot/pytest_plugin.py b/inline_snapshot/pytest_plugin.py index c42a76c..d97a440 100644 --- a/inline_snapshot/pytest_plugin.py +++ b/inline_snapshot/pytest_plugin.py @@ -15,6 +15,7 @@ from . import _find_external from . import _inline_snapshot from ._change import apply_all +from ._code_repr import used_hasrepr from ._find_external import ensure_import from ._inline_snapshot import used_externals from ._rewrite_code import ChangeRecorder @@ -312,9 +313,17 @@ def report(flag, message, message_n): tree = ast.parse(test_file.new_code()) used = used_externals(tree) + required_imports = [] + if used: + required_imports.append("external") + + if used_hasrepr(tree): + required_imports.append("HasRepr") + + if required_imports: ensure_import( - test_file.filename, {"inline_snapshot": ["external"]} + test_file.filename, {"inline_snapshot": required_imports} ) for external_name in used: diff --git a/mkdocs.yml b/mkdocs.yml index 782bcc1..7f04a05 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,9 +36,11 @@ nav: - x in snapshot(): in_snapshot.md - snapshot()[key]: getitem_snapshot.md - outsource(data): outsource.md + - '@customize_repr': customize_repr.md - pytest integration: pytest.md - Categories: categories.md - Configuration: configuration.md +- Code generation: code_generation.md - Contributing: contributing.md - Changelog: changelog.md diff --git a/noxfile.py b/noxfile.py index 01a8a02..ec4faf8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -24,7 +24,7 @@ def coverage(session): @session(python=python_versions) def mypy(session): - session.install("mypy", "pytest", "hypothesis", "dirty-equals", ".") + session.install("mypy", "pytest", "hypothesis", "dirty-equals", "pydantic", ".") args = ["inline_snapshot", "tests"] if session.posargs: args = session.posargs @@ -46,6 +46,7 @@ def test(session): "time-machine", "mypy", "pyright", + "pydantic", ) cmd = [] diff --git a/pyproject.toml b/pyproject.toml index 950154d..4c86344 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,3 +90,7 @@ Funding = "https://github.com/sponsors/15r10nk" Homepage = "https://15r10nk.github.io/inline-snapshot" Issues = "https://github.com/15r10nk/inline-snapshots/issues" Repository = "https://github.com/15r10nk/inline-snapshot/" + +[tool.pyright] +venv = "test-3-12" +venvPath = ".nox" diff --git a/tests/conftest.py b/tests/conftest.py index 6d0c49d..24b762a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,7 +98,7 @@ def run(self, *flags): error = False try: - exec(compile(filename.read_text("utf-8"), filename, "exec")) + exec(compile(filename.read_text("utf-8"), filename, "exec"), {}) except AssertionError: traceback.print_exc() error = True diff --git a/tests/example.py b/tests/example.py index f6a8356..bc9f4b2 100644 --- a/tests/example.py +++ b/tests/example.py @@ -1,9 +1,20 @@ +from __future__ import annotations + import os import platform import re import subprocess as sp from pathlib import Path from tempfile import TemporaryDirectory +from typing import Any + +import inline_snapshot._external +from .utils import snapshot_env +from inline_snapshot import _inline_snapshot +from inline_snapshot._inline_snapshot import Flags +from inline_snapshot._rewrite_code import ChangeRecorder + +pytest_plugins = "pytester" ansi_escape = re.compile( @@ -36,10 +47,75 @@ def write_files(self, dir: Path): def read_files(self, dir: Path): return {p.name: p.read_text() for p in dir.iterdir() if p.is_file()} - def run_pytest(self, *args, changed_files=None, report=None, env={}): + def run_inline( + self, *flags, changes=None, reported_flags=None, changed_files=None + ) -> Example: + + with TemporaryDirectory() as dir: + tmp_path = Path(dir) + + self.write_files(tmp_path) + + with snapshot_env(): + with ChangeRecorder().activate() as recorder: + _inline_snapshot._update_flags = Flags({*flags}) + inline_snapshot._external.storage = ( + inline_snapshot._external.DiscStorage(tmp_path / ".storage") + ) + + error = False + + try: + for filename in tmp_path.glob("*.py"): + globals: dict[str, Any] = {} + exec( + compile(filename.read_text("utf-8"), filename, "exec"), + globals, + ) + + # run all test_* functions + for k, v in globals.items(): + if k.startswith("test_") and callable(v): + v() + + finally: + _inline_snapshot._active = False + + # number_snapshots = len(_inline_snapshot.snapshots) + + snapshot_flags = set() + + all_changes = [] + + for snapshot in _inline_snapshot.snapshots.values(): + snapshot_flags |= snapshot._flags + snapshot._change() + all_changes += snapshot._changes() + + if reported_flags is not None: + assert sorted(snapshot_flags) == reported_flags + + # if changes is not None: + # assert all_changes == changes + + recorder.fix_all() + + if changed_files is not None: + current_files = {} + + for name, content in sorted(self.read_files(tmp_path).items()): + if name not in self.files or self.files[name] != content: + current_files[name] = content + assert changed_files == current_files + + return Example(self.read_files(tmp_path)) + + def run_pytest( + self, *args, changed_files=None, report=None, env={}, returncode=None + ) -> Example: with TemporaryDirectory() as dir: - dir = Path(dir) - self.write_files(dir) + tmp_path = Path(dir) + self.write_files(tmp_path) cmd = ["pytest", *args] @@ -54,7 +130,7 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}): command_env.update(env) - result = sp.run(cmd, cwd=dir, capture_output=True, env=command_env) + result = sp.run(cmd, cwd=tmp_path, capture_output=True, env=command_env) print("run>", *cmd) print("stdout:") @@ -62,9 +138,12 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}): print("stderr:") print(result.stderr.decode()) + if returncode is not None: + assert result.returncode == returncode + if report is not None: - new_report = [] + report_list = [] record = False for line in result.stdout.decode().splitlines(): line = line.strip() @@ -72,28 +151,28 @@ def run_pytest(self, *args, changed_files=None, report=None, env={}): record = False if record and line: - new_report.append(line) + report_list.append(line) if line.startswith("====") and "inline snapshot" in line: record = True - new_report = "\n".join(new_report) + report_str = "\n".join(report_list) - new_report = ansi_escape.sub("", new_report) + report_str = ansi_escape.sub("", report_str) # fix windows problems - new_report = new_report.replace("\u2500", "-") - new_report = new_report.replace("\r", "") - new_report = new_report.replace(" \n", " ⏎\n") + report_str = report_str.replace("\u2500", "-") + report_str = report_str.replace("\r", "") + report_str = report_str.replace(" \n", " ⏎\n") - assert new_report == report + assert report_str == report if changed_files is not None: current_files = {} - for name, content in sorted(self.read_files(dir).items()): + for name, content in sorted(self.read_files(tmp_path).items()): if name not in self.files or self.files[name] != content: current_files[name] = content assert changed_files == current_files - return Example(self.read_files(dir)) + return Example(self.read_files(tmp_path)) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py new file mode 100644 index 0000000..6b8e250 --- /dev/null +++ b/tests/test_code_repr.py @@ -0,0 +1,359 @@ +import dataclasses +from dataclasses import dataclass + +import pytest + +from .example import Example +from inline_snapshot import HasRepr +from inline_snapshot import snapshot +from inline_snapshot._code_repr import code_repr + + +def test_enum(check_update): + + assert ( + check_update( + """ +from enum import Enum + +class color(Enum): + val="val" + + +assert [color.val] == snapshot() + + """, + flags="create", + ) + == snapshot( + """\ + +from enum import Enum + +class color(Enum): + val="val" + + +assert [color.val] == snapshot([color.val]) + +""" + ) + ) + + +def test_snapshot_generates_hasrepr(): + + Example( + """\ +from inline_snapshot import snapshot + +class Thing: + def __repr__(self): + return "" + +def test_thing(): + assert Thing() == snapshot() + + """ + ).run_pytest( + "--inline-snapshot=create", + returncode=snapshot(0), + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot import HasRepr + +class Thing: + def __repr__(self): + return "" + +def test_thing(): + assert Thing() == snapshot(HasRepr(Thing, "")) + + \ +""" + } + ), + ).run_pytest( + "--inline-snapshot=disable", returncode=0 + ).run_pytest( + returncode=0 + ) + + +def test_hasrepr_type(): + assert 5 == HasRepr(int, "5") + assert not "a" == HasRepr(int, "5") + assert not HasRepr(float, "nan") == HasRepr(str, "nan") + assert not HasRepr(str, "a") == HasRepr(str, "b") + + +def test_enum_in_dataclass(check_update): + + assert ( + check_update( + """ +from enum import Enum +from dataclasses import dataclass + +class color(Enum): + red="red" + blue="blue" + +@dataclass +class container: + bg: color=color.red + fg: color=color.blue + +assert container(bg=color.red,fg=color.red) == snapshot() + + """, + flags="create", + ) + == snapshot( + """\ + +from enum import Enum +from dataclasses import dataclass + +class color(Enum): + red="red" + blue="blue" + +@dataclass +class container: + bg: color=color.red + fg: color=color.blue + +assert container(bg=color.red,fg=color.red) == snapshot(container(fg=color.red)) + +""" + ) + ) + + +def test_dataclass_field_repr(check_update): + + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass,field + +@dataclass +class container: + a: int + b: int = field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot() +""" + ).run_inline( + "create", + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass,field + +@dataclass +class container: + a: int + b: int = field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot(container(a=1)) +""" + } + ), + ).run_inline() + + +def test_flag(check_update): + + assert ( + check_update( + """ +from enum import Flag, auto + +class Color(Flag): + red = auto() + green = auto() + blue = auto() + +assert Color.red | Color.blue == snapshot() + + """, + flags="create", + ) + == snapshot( + """\ + +from enum import Flag, auto + +class Color(Flag): + red = auto() + green = auto() + blue = auto() + +assert Color.red | Color.blue == snapshot(Color.red | Color.blue) + +""" + ) + ) + + +def test_type(check_update): + + assert ( + check_update( + """\ +class Color: + pass + +assert [Color,int] == snapshot() + + """, + flags="create", + ) + == snapshot( + """\ +class Color: + pass + +assert [Color,int] == snapshot([Color, int]) + +""" + ) + ) + + +def test_qualname(): + + Example( + """\ +from enum import Enum +from inline_snapshot import snapshot + + +class Namespace: + class Color(Enum): + red="red" + +assert Namespace.Color.red == snapshot() + + """ + ).run_inline( + "create", + changed_files=snapshot( + { + "test_something.py": """\ +from enum import Enum +from inline_snapshot import snapshot + + +class Namespace: + class Color(Enum): + red="red" + +assert Namespace.Color.red == snapshot(Namespace.Color.red) + + \ +""" + } + ), + ).run_inline() + + +from collections import ( + Counter, + OrderedDict, + UserDict, + UserList, + defaultdict, + namedtuple, +) +from typing import NamedTuple + +A = namedtuple("A", "a,b", defaults=[0]) +B = namedtuple("B", "a,b", defaults=[0, 0]) + + +class C(NamedTuple): + a: int + b: int = 0 + c: int = 0 + + +@dataclass +class Dataclass: + a: int + b: int = dataclasses.field(default=0) + c: list = dataclasses.field(default_factory=lambda: []) + + +default_dict = defaultdict(list) +default_dict[5].append(2) +default_dict[3].append(1) + + +@pytest.mark.parametrize( + "d", + [ + frozenset(["a"]), + frozenset(), + {"a"}, + set(), + list(), + ["a"], + {}, + {1: "1"}, + (), + (1,), + (1, 2, 3), + A(1, 2), + A(1), + A(0, 0), + B(), + B(b=5), + C(1), + C(1, 2), + C(a=1, c=2), + Dataclass(a=0, b=0, c=[]), + Dataclass(a=1, b=2, c=[3]), + default_dict, + OrderedDict({1: 2, 3: 4}), + UserDict({1: 2}), + UserList([1, 2]), + ], +) +def test_datatypes(d): + code = code_repr(d) + print("repr: ", repr(d)) + print("code_repr:", code) + assert d == eval(code) + + +def test_datatypes_explicit(): + assert code_repr(C(a=1, c=2)) == snapshot("C(a=1, c=2)") + assert code_repr(B(b=5)) == snapshot("B(b=5)") + assert code_repr(B(b=0)) == snapshot("B()") + + assert code_repr(Dataclass(a=0, b=0, c=[])) == snapshot("Dataclass(a=0)") + assert code_repr(Dataclass(a=1, b=2, c=[3])) == snapshot( + "Dataclass(a=1, b=2, c=[3])" + ) + assert code_repr(Counter([1, 1, 1, 2])) == snapshot("Counter({1: 3, 2: 1})") + + assert code_repr(default_dict) == snapshot("defaultdict(list, {5: [2], 3: [1]})") + + +def test_tuple(): + + class FakeTuple(tuple): + def __init__(self): + self._fields = 5 + + def __repr__(self): + return "FakeTuple()" + + assert code_repr(FakeTuple()) == snapshot("FakeTuple()") diff --git a/tests/test_example.py b/tests/test_example.py index 9f24116..dbcd93e 100644 --- a/tests/test_example.py +++ b/tests/test_example.py @@ -1,15 +1,29 @@ from .example import Example +from inline_snapshot import snapshot -def test_diff_multiple_files(): +def test_example(): - Example( - """ + e = Example( + { + "test_a.py": """ from inline_snapshot import snapshot def test_a(): assert 1==snapshot(2) """, - ).run_pytest( + "test_b.py": "1+1", + }, + ) + + e.run_pytest( "--inline-snapshot=create,fix", ) + + e.run_inline( + "fix", + reported_flags=snapshot(["fix"]), + ).run_inline( + "fix", + changed_files=snapshot({}), + ) diff --git a/tests/test_hasrepr.py b/tests/test_hasrepr.py new file mode 100644 index 0000000..1ec0899 --- /dev/null +++ b/tests/test_hasrepr.py @@ -0,0 +1,12 @@ +from inline_snapshot._code_repr import HasRepr + + +def test_hasrepr_eq(): + + assert float("nan") == HasRepr(float, "nan") + + class Thing: + def __repr__(self): + return "" + + assert Thing() == HasRepr(Thing, "") diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index b2486b9..e9b738f 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -787,7 +787,7 @@ class Thing: def __repr__(self): return "+++" -assert Thing() == snapshot() +assert Thing() == snapshot(HasRepr(Thing, "+++")) """ ) ) diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py new file mode 100644 index 0000000..c6c55a7 --- /dev/null +++ b/tests/test_pydantic.py @@ -0,0 +1,42 @@ +from .example import Example +from inline_snapshot import snapshot + + +def test_pydantic_repr(): + + Example( + """ +from pydantic import BaseModel +from inline_snapshot import snapshot + +class M(BaseModel): + size:int + name:str + age:int=4 + +def test_pydantic(): + assert M(size=5,name="Tom")==snapshot() + + """ + ).run_inline( + "create", + changed_files=snapshot( + { + "test_something.py": """\ + +from pydantic import BaseModel +from inline_snapshot import snapshot + +class M(BaseModel): + size:int + name:str + age:int=4 + +def test_pydantic(): + assert M(size=5,name="Tom")==snapshot(M(name="Tom", size=5)) + + \ +""" + } + ), + ).run_inline() From e8cdfdd3460355b25d4bd935a5acb3ff46afb890 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 6 Jul 2024 14:29:56 +0200 Subject: [PATCH 2/4] docs: Readme --- README.md | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 15cfbe7..d1f6c6d 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ pip install inline-snapshot - **Preserved Black Formatting:** Retains formatting consistency with Black formatting. - **External File Storage:** Store snapshots externally using `outsource(data)`. - **Seamless Pytest Integration:** Integrated seamlessly with pytest for effortless testing. +- **Customizable:** code generation can be customized with [@customize_repr](https://15r10nk.github.io/inline-snapshot/customize_repr) - **Comprehensive Documentation:** Access detailed [documentation](https://15r10nk.github.io/inline-snapshot/) for complete guidance. @@ -51,7 +52,7 @@ def test_something(): You can now run the tests and record the correct values. ``` -$ pytest --inline-snapshot=create +$ pytest --inline-snapshot=review ``` @@ -63,7 +64,8 @@ def test_something(): assert 1548 * 18489 == snapshot(28620972) ``` -inline-snapshot provides more advanced features like: +The following examples show how you can use inline-snapshot in your tests. Take a look at the +[documentation](https://15r10nk.github.io/inline-snapshot/) if you want to know more. ```python @@ -100,6 +102,56 @@ def test_something(): assert outsource("large string\n" * 1000) == snapshot( external("8bf10bdf2c30*.txt") ) + + assert "generates\nmultiline\nstrings" == snapshot( + """\ +generates +multiline +strings\ +""" + ) +``` + + +`snapshot()` can also be used as parameter for functions: + + +```python +from inline_snapshot import snapshot +import subprocess as sp +import sys + + +def run_python(cmd, stdout=None, stderr=None): + result = sp.run([sys.executable, "-c", cmd], capture_output=True) + if stdout is not None: + assert result.stdout.decode() == stdout + if stderr is not None: + assert result.stderr.decode() == stderr + + +def test_cmd(): + run_python( + "print('hello world')", + stdout=snapshot( + """\ +hello world +""" + ), + stderr=snapshot(""), + ) + + run_python( + "1/0", + stdout=snapshot(""), + stderr=snapshot( + """\ +Traceback (most recent call last): + File "", line 1, in +ZeroDivisionError: division by zero +""" + ), + ) ``` From a278fa2224eb25e24386497c8f114c3d8d5dc6ca Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 6 Jul 2024 11:15:32 +0200 Subject: [PATCH 3/4] feat: check if the result of copy.deepcopy() is equal to the copied value --- inline_snapshot/_exceptions.py | 2 ++ inline_snapshot/_inline_snapshot.py | 41 +++++++++++++++++++++-------- inline_snapshot/_sentinels.py | 3 ++- tests/example.py | 4 ++- tests/test_code_repr.py | 12 +++++++++ tests/test_inline_snapshot.py | 10 +++++++ tests/test_pytest_plugin.py | 33 +++++++++++++++++++++++ 7 files changed, 92 insertions(+), 13 deletions(-) create mode 100644 inline_snapshot/_exceptions.py diff --git a/inline_snapshot/_exceptions.py b/inline_snapshot/_exceptions.py new file mode 100644 index 0000000..24bae56 --- /dev/null +++ b/inline_snapshot/_exceptions.py @@ -0,0 +1,2 @@ +class UsageError(Exception): + pass diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py index 3b5c237..84bb253 100644 --- a/inline_snapshot/_inline_snapshot.py +++ b/inline_snapshot/_inline_snapshot.py @@ -22,6 +22,8 @@ from ._change import DictInsert from ._change import ListInsert from ._change import Replace +from ._code_repr import code_repr +from ._exceptions import UsageError from ._format import format_code from ._sentinels import undefined from ._utils import ignore_tokens @@ -163,7 +165,7 @@ def _change(self, cls): self.__class__ = cls def _new_code(self): - return "" + assert False def _get_changes(self) -> Iterator[Change]: # generic fallback @@ -224,6 +226,23 @@ def update_allowed(value): return not isinstance(value, dirty_equals.DirtyEquals) +def clone(obj): + new = copy.deepcopy(obj) + if not obj == new: + raise UsageError( + f"""\ +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +original: {code_repr(obj)} +copied: {code_repr(new)} + +Please fix the way your object is copied or your __eq__ implementation. +""" + ) + return new + + class EqValue(GenericValue): _current_op = "x == snapshot" @@ -232,10 +251,8 @@ def __eq__(self, other): if self._old_value is undefined: _missing_values += 1 - other = copy.deepcopy(other) - if self._new_value is undefined: - self._new_value = other + self._new_value = clone(other) return self._visible_value() == other @@ -391,13 +408,12 @@ def _generic_cmp(self, other): global _missing_values if self._old_value is undefined: _missing_values += 1 - other = copy.deepcopy(other) if self._new_value is undefined: - self._new_value = other + self._new_value = clone(other) else: self._new_value = ( - self._new_value if self.cmp(self._new_value, other) else other + self._new_value if self.cmp(self._new_value, other) else clone(other) ) return self.cmp(self._visible_value(), other) @@ -481,13 +497,11 @@ def __contains__(self, item): if self._old_value is undefined: _missing_values += 1 - item = copy.deepcopy(item) - if self._new_value is undefined: - self._new_value = [item] + self._new_value = [clone(item)] else: if item not in self._new_value: - self._new_value.append(item) + self._new_value.append(clone(item)) if ignore_old_value() or self._old_value is undefined: return True @@ -714,7 +728,12 @@ def __init__(self, value, expr): self._uses_externals = [] def _changes(self): + if self._value._old_value is undefined: + + if self._value._new_value is undefined: + return + new_code = self._value._new_code() yield CallArg( diff --git a/inline_snapshot/_sentinels.py b/inline_snapshot/_sentinels.py index 8f34f7f..8c0d8e1 100644 --- a/inline_snapshot/_sentinels.py +++ b/inline_snapshot/_sentinels.py @@ -1,6 +1,7 @@ # sentinels class Undefined: - pass + def __repr__(self): + return "undefined" undefined = Undefined() diff --git a/tests/example.py b/tests/example.py index bc9f4b2..50d168e 100644 --- a/tests/example.py +++ b/tests/example.py @@ -48,7 +48,7 @@ def read_files(self, dir: Path): return {p.name: p.read_text() for p in dir.iterdir() if p.is_file()} def run_inline( - self, *flags, changes=None, reported_flags=None, changed_files=None + self, *flags, changes=None, reported_flags=None, changed_files=None, raises=None ) -> Example: with TemporaryDirectory() as dir: @@ -77,6 +77,8 @@ def run_inline( for k, v in globals.items(): if k.startswith("test_") and callable(v): v() + except Exception as e: + assert raises == f"{type(e).__name__}:\n" + str(e) finally: _inline_snapshot._active = False diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 6b8e250..3a2871c 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -7,6 +7,7 @@ from inline_snapshot import HasRepr from inline_snapshot import snapshot from inline_snapshot._code_repr import code_repr +from inline_snapshot._sentinels import undefined def test_enum(check_update): @@ -51,6 +52,11 @@ class Thing: def __repr__(self): return "" + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + def test_thing(): assert Thing() == snapshot() @@ -69,6 +75,11 @@ class Thing: def __repr__(self): return "" + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + def test_thing(): assert Thing() == snapshot(HasRepr(Thing, "")) @@ -324,6 +335,7 @@ class Dataclass: OrderedDict({1: 2, 3: 4}), UserDict({1: 2}), UserList([1, 2]), + undefined, ], ) def test_datatypes(d): diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index e9b738f..95cfb80 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -777,6 +777,11 @@ class Thing: def __repr__(self): return "+++" + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + assert Thing() == snapshot() """, flags="create", @@ -787,6 +792,11 @@ class Thing: def __repr__(self): return "+++" + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + assert Thing() == snapshot(HasRepr(Thing, "+++")) """ ) diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 984ff26..ee321ac 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -608,3 +608,36 @@ def test_a(): """ ), ) + + +def test_equal_check(): + + Example( + { + "test_a.py": """ +from inline_snapshot import snapshot + +class Thing: + def __repr__(self): + return "Thing()" + +def test_a(): + assert Thing()==snapshot() + """, + } + ).run_inline( + "--inline-snapshot=create", + changed_files=snapshot({}), + raises=snapshot( + """\ +UsageError: +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +original: Thing() +copied: Thing() + +Please fix the way your object is copied or your __eq__ implementation. +""" + ), + ) From fde435da46e0adf9c79de482306f2c502467aa33 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Sat, 6 Jul 2024 15:22:07 +0200 Subject: [PATCH 4/4] test: skip doc tests on windows --- tests/test_docs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_docs.py b/tests/test_docs.py index 83e405d..7e69a83 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,3 +1,4 @@ +import platform import re import textwrap from pathlib import Path @@ -7,6 +8,10 @@ import inline_snapshot._inline_snapshot +@pytest.mark.skipif( + platform.system() == "Windows", + reason="\r in stdout can cause problems in snapshot strings", +) @pytest.mark.parametrize( "file", [