diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py index 66d63e8..bd4d26b 100644 --- a/src/inline_snapshot/_adapter/adapter.py +++ b/src/inline_snapshot/_adapter/adapter.py @@ -2,23 +2,23 @@ import ast import typing -from dataclasses import is_dataclass from inline_snapshot._source_file import SourceFile def get_adapter_type(value): - if is_dataclass(value): - from .dataclass_adapter import DataclassAdapter + from inline_snapshot._adapter.dataclass_adapter import get_adapter_for_type - return DataclassAdapter + adapter = get_adapter_for_type(type(value)) + if adapter is not None: + return adapter if isinstance(value, list): from .sequence_adapter import ListAdapter return ListAdapter - if isinstance(value, tuple): + if type(value) is tuple: from .sequence_adapter import TupleAdapter return TupleAdapter @@ -56,12 +56,16 @@ def get_adapter(self, old_value, new_value) -> Adapter: assert False def assign(self, old_value, old_node, new_value): - raise NotImplementedError + raise NotImplementedError(cls) @classmethod def map(cls, value, map_function): raise NotImplementedError(cls) + @classmethod + def repr(cls, value): + raise NotImplementedError(cls) + def adapter_map(value, map_function): return get_adapter_type(value).map(value, map_function) diff --git a/src/inline_snapshot/_adapter/dataclass_adapter.py b/src/inline_snapshot/_adapter/dataclass_adapter.py index 5394d36..767dcd5 100644 --- a/src/inline_snapshot/_adapter/dataclass_adapter.py +++ b/src/inline_snapshot/_adapter/dataclass_adapter.py @@ -2,8 +2,12 @@ import ast import warnings +from abc import ABC +from collections import defaultdict from dataclasses import fields +from dataclasses import is_dataclass from dataclasses import MISSING +from typing import Any from inline_snapshot._adapter.value_adapter import ValueAdapter @@ -15,8 +19,42 @@ from .adapter import Item +def get_adapter_for_type(typ): + subclasses = DataclassAdapter.__subclasses__() + options = [cls for cls in subclasses if cls.check_type(typ)] + # print(typ,options) + if not options: + return + + assert len(options) == 1 + return options[0] + + class DataclassAdapter(Adapter): + @classmethod + def check_type(cls, typ) -> bool: + raise NotImplementedError(cls) + + @classmethod + def arguments(cls, value) -> tuple[list[Any], dict[str, Any]]: + raise NotImplementedError(cls) + + @classmethod + def argument(cls, value, pos_or_name) -> Any: + raise NotImplementedError(cls) + + @classmethod + def repr(cls, value): + + args, kwargs = cls.arguments(value) + + arguments = [repr(value) for value in args] + [ + f"{key}={repr(value)}" for key, value in kwargs.items() + ] + + return f"{repr(type(value))}({', '.join(arguments)})" + @classmethod def map(cls, value, map_function): new_args, new_kwargs = cls.arguments(value) @@ -36,35 +74,8 @@ def items(self, value, node): if kw.arg ] - @classmethod - def arguments(cls, value): - - kwargs = {} - - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - - if field.default != MISSING and field.default == field_value: - continue - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - continue - - kwargs[field.name] = field_value - - return ([], kwargs) - - def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) - def assign(self, old_value, old_node, new_value): if old_node is None: - value = yield from ValueAdapter(self.context).assign( old_value, old_node, new_value ) @@ -72,6 +83,18 @@ def assign(self, old_value, old_node, new_value): assert isinstance(old_node, ast.Call) + # positional arguments + for pos_arg in old_node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + # keyword arguments for kw in old_node.keywords: if kw.arg is None: warnings.warn_explicit( @@ -84,6 +107,42 @@ def assign(self, old_value, old_node, new_value): new_args, new_kwargs = self.arguments(new_value) + # positional arguments + + result_args = [] + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): + old_value_element = self.argument(old_value, i) + result = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, node, new_value_element) + result_args.append(result) + + print(old_node.args) + print(new_args) + if len(old_node.args) > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + print("del", arg_pos) + yield Delete( + "fix", + self.context._source, + node, + self.argument(old_value, arg_pos), + ) + + if len(old_node.args) < len(new_args): + for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=None, + new_code=self.context._value_to_code(value), + new_value=value, + ) + + # keyword arguments result_kwargs = {} for kw in old_node.keywords: if not kw.arg in new_kwargs: @@ -143,4 +202,126 @@ def assign(self, old_value, old_node, new_value): new_value=value, ) - return type(old_value)(**result_kwargs) + return type(old_value)(*result_args, **result_kwargs) + + +class DataclassContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return is_dataclass(value) + + @classmethod + def arguments(cls, value): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + + if field.default != MISSING and field.default == field_value: + continue + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + continue + + kwargs[field.name] = field_value + + return ([], kwargs) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +try: + from pydantic import BaseModel +except ImportError: # pragma: no cover + pass +else: + + class PydanticContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, BaseModel) + + @classmethod + def arguments(cls, value): + + return ( + [], + { + name: getattr(value, name) + for name, info in value.model_fields.items() + if getattr(value, name) != info.default + }, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class IsNamedTuple(ABC): + _inline_snapshot_name = "namedtuple" + + _fields: tuple + _field_defaults: dict + + @classmethod + def __subclasshook__(cls, t): + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + + +class NamedTupleContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, IsNamedTuple) + + @classmethod + def arguments(cls, value: IsNamedTuple): + + return ( + [], + { + field: getattr(value, field) + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + }, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class DefaultDictContainer(DataclassAdapter): + @classmethod + def check_type(cls, value): + return issubclass(value, defaultdict) + + @classmethod + def arguments(cls, value: defaultdict): + + return ([value.default_factory, dict(value)], {}) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, int) + if pos_or_name == 0: + return value.default_factory + elif pos_or_name == 1: + return dict(value) + assert False diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py index ed64458..4e0cf94 100644 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -13,6 +13,19 @@ class DictAdapter(Adapter): + @classmethod + def repr(cls, value): + result = ( + "{" + + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + + "}" + ) + + if type(value) is not dict: + result = f"{repr(type(value))}({result})" + + return result + @classmethod def map(cls, value, map_function): return {k: adapter_map(v, map_function) for k, v in value.items()} diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py index 544c380..11ab9f2 100644 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -18,6 +18,16 @@ class SequenceAdapter(Adapter): node_type: type value_type: type + braces: str + trailing_comma: bool + + @classmethod + def repr(cls, value): + if len(value) == 1 and cls.trailing_comma: + seq = repr(value[0]) + "," + else: + seq = ", ".join(map(repr, value)) + return cls.braces[0] + seq + cls.braces[1] @classmethod def map(cls, value, map_function): @@ -91,8 +101,12 @@ def assign(self, old_value, old_node, new_value): class ListAdapter(SequenceAdapter): node_type = ast.List value_type = list + braces = "[]" + trailing_comma = False class TupleAdapter(SequenceAdapter): node_type = ast.Tuple value_type = tuple + braces = "()" + trailing_comma = True diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py index 17d8115..f44d235 100644 --- a/src/inline_snapshot/_adapter/value_adapter.py +++ b/src/inline_snapshot/_adapter/value_adapter.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inline_snapshot._code_repr import value_code_repr from inline_snapshot._unmanaged import Unmanaged from inline_snapshot._unmanaged import update_allowed from inline_snapshot._utils import value_to_token @@ -10,6 +11,10 @@ class ValueAdapter(Adapter): + @classmethod + def repr(cls, value): + return value_code_repr(value) + @classmethod def map(cls, value, map_function): return map_function(value) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 0de5cc8..0f878d7 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,9 +1,4 @@ import ast -from abc import ABC -from collections import defaultdict -from dataclasses import fields -from dataclasses import is_dataclass -from dataclasses import MISSING from enum import Enum from enum import Flag from functools import singledispatch @@ -79,14 +74,27 @@ def _(obj: MyCustomClass): def code_repr(obj): + + with mock.patch("builtins.repr", mocked_code_repr): + return mocked_code_repr(obj) + + +def mocked_code_repr(obj): + from inline_snapshot._adapter.adapter import get_adapter_type + + adapter = get_adapter_type(obj) + assert adapter is not None + return adapter.repr(obj) + + +def value_code_repr(obj): if not type(obj) == type(obj): # dispatch will not work in cases like this return ( f"HasRepr({repr(type(obj))}, '< type(obj) can not be compared with == >')" ) - with mock.patch("builtins.repr", code_repr): - result = code_repr_dispatch(obj) + result = code_repr_dispatch(obj) try: ast.parse(result) @@ -111,59 +119,6 @@ def _(value: Flag): return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) -# -8<- [start:list] -@customize_repr -def _(value: list): - return "[" + ", ".join(map(repr, value)) + "]" - - -# -8<- [end:list] - - -class OnlyTuple(ABC): - _inline_snapshot_name = "builtins.tuple" - - @classmethod - def __subclasshook__(cls, t): - return t is tuple - - -@customize_repr -def _(value: OnlyTuple): - assert isinstance(value, tuple) - if len(value) == 1: - return f"({repr(value[0])},)" - return "(" + ", ".join(map(repr, value)) + ")" - - -class IsNamedTuple(ABC): - _inline_snapshot_name = "namedtuple" - - _fields: tuple - _field_defaults: dict - - @classmethod - def __subclasshook__(cls, t): - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -@customize_repr -def _(value: IsNamedTuple): - params = ", ".join( - f"{field}={repr(getattr(value,field))}" - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - ) - return f"{repr(type(value))}({params})" - - @customize_repr def _(value: set): if len(value) == 0: @@ -180,71 +135,6 @@ def _(value: frozenset): return "frozenset({" + ", ".join(map(repr, value)) + "})" -@customize_repr -def _(value: dict): - result = ( - "{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + "}" - ) - - if type(value) is not dict: - result = f"{repr(type(value))}({result})" - - return result - - -@customize_repr -def _(value: defaultdict): - return f"defaultdict({repr(value.default_factory)}, {repr(dict(value))})" - - @customize_repr def _(value: type): return value.__qualname__ - - -class IsDataclass(ABC): - _inline_snapshot_name = "dataclass" - - @classmethod - def __subclasshook__(cls, subclass): - return is_dataclass(subclass) - - -@customize_repr -def _(value: IsDataclass): - attrs = [] - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - - if field.default != MISSING and field.default == field_value: - continue - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - continue - - attrs.append(f"{field.name}={repr(field_value)}") - - return f"{repr(type(value))}({', '.join(attrs)})" - - -try: - from pydantic import BaseModel -except ImportError: # pragma: no cover - pass -else: - - @customize_repr - def _(model: BaseModel): - return ( - type(model).__qualname__ - + "(" - + ", ".join( - e + "=" + repr(getattr(model, e)) - for e in sorted(model.__pydantic_fields_set__) - ) - + ")" - ) diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 034ee8a..64f6595 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -5,6 +5,7 @@ import platform import re import subprocess as sp +import traceback from argparse import ArgumentParser from pathlib import Path from tempfile import TemporaryDirectory @@ -167,6 +168,7 @@ def run_inline( if k.startswith("test_") and callable(v): v() except Exception as e: + traceback.print_exc() raised_exception = e finally: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index 4d81b7d..a0b1176 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -222,3 +222,170 @@ def test_something(): """ ), ) + + +def test_positional_star_args(): + + with warns( + snapshot( + [ + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots" + ] + ) + ): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(*[],a=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=report"], + ) + + +def test_remove_positional_argument(): + Example( + """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.dataclass_adapter import DataclassAdapter + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(DataclassAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return (value.l,{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.dataclass_adapter import DataclassAdapter + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(DataclassAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return (value.l,{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" +""" + } + ), + ) + + +def test_namedtuple(): + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=2)), "not equal" +""" + } + ), + ) + + +def test_defaultdict(): + Example( + """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [3]})), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [2]})), "not equal" +""" + } + ), + ) diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 47d1e23..7d8cd29 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -31,7 +31,7 @@ class M(BaseModel): age:int=4 def test_pydantic(): - assert M(size=5,name="Tom")==snapshot(M(name="Tom", size=5)) + assert M(size=5,name="Tom")==snapshot(M(size=5, name="Tom")) """ } ),