Skip to content

Commit

Permalink
refactor: use adapter for code_repr
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Nov 3, 2024
1 parent 59cc569 commit b708c40
Show file tree
Hide file tree
Showing 9 changed files with 436 additions and 160 deletions.
16 changes: 10 additions & 6 deletions src/inline_snapshot/_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

import ast
import typing
from dataclasses import is_dataclass

from inline_snapshot._source_file import SourceFile


def get_adapter_type(value):
if is_dataclass(value):
from .dataclass_adapter import DataclassAdapter
from inline_snapshot._adapter.dataclass_adapter import get_adapter_for_type

return DataclassAdapter
adapter = get_adapter_for_type(type(value))
if adapter is not None:
return adapter

if isinstance(value, list):
from .sequence_adapter import ListAdapter

return ListAdapter

if isinstance(value, tuple):
if type(value) is tuple:
from .sequence_adapter import TupleAdapter

return TupleAdapter
Expand Down Expand Up @@ -56,12 +56,16 @@ def get_adapter(self, old_value, new_value) -> Adapter:
assert False

def assign(self, old_value, old_node, new_value):
raise NotImplementedError
raise NotImplementedError(cls)

@classmethod
def map(cls, value, map_function):
raise NotImplementedError(cls)

@classmethod
def repr(cls, value):
raise NotImplementedError(cls)


def adapter_map(value, map_function):
return get_adapter_type(value).map(value, map_function)
237 changes: 209 additions & 28 deletions src/inline_snapshot/_adapter/dataclass_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

import ast
import warnings
from abc import ABC
from collections import defaultdict
from dataclasses import fields
from dataclasses import is_dataclass
from dataclasses import MISSING
from typing import Any

from inline_snapshot._adapter.value_adapter import ValueAdapter

Expand All @@ -15,8 +19,42 @@
from .adapter import Item


def get_adapter_for_type(typ):
subclasses = DataclassAdapter.__subclasses__()
options = [cls for cls in subclasses if cls.check_type(typ)]
# print(typ,options)
if not options:
return

assert len(options) == 1
return options[0]


class DataclassAdapter(Adapter):

@classmethod
def check_type(cls, typ) -> bool:
raise NotImplementedError(cls)

@classmethod
def arguments(cls, value) -> tuple[list[Any], dict[str, Any]]:
raise NotImplementedError(cls)

@classmethod
def argument(cls, value, pos_or_name) -> Any:
raise NotImplementedError(cls)

@classmethod
def repr(cls, value):

args, kwargs = cls.arguments(value)

arguments = [repr(value) for value in args] + [
f"{key}={repr(value)}" for key, value in kwargs.items()
]

return f"{repr(type(value))}({', '.join(arguments)})"

@classmethod
def map(cls, value, map_function):
new_args, new_kwargs = cls.arguments(value)
Expand All @@ -36,42 +74,27 @@ def items(self, value, node):
if kw.arg
]

@classmethod
def arguments(cls, value):

kwargs = {}

for field in fields(value): # type: ignore
if field.repr:
field_value = getattr(value, field.name)

if field.default != MISSING and field.default == field_value:
continue

if (
field.default_factory != MISSING
and field.default_factory() == field_value
):
continue

kwargs[field.name] = field_value

return ([], kwargs)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)

def assign(self, old_value, old_node, new_value):
if old_node is None:

value = yield from ValueAdapter(self.context).assign(
old_value, old_node, new_value
)
return value

assert isinstance(old_node, ast.Call)

# positional arguments
for pos_arg in old_node.args:
if isinstance(pos_arg, ast.Starred):
warnings.warn_explicit(
"star-expressions are not supported inside snapshots",
filename=self.context._source.filename,
lineno=pos_arg.lineno,
category=InlineSnapshotSyntaxWarning,
)
return old_value

# keyword arguments
for kw in old_node.keywords:
if kw.arg is None:
warnings.warn_explicit(
Expand All @@ -84,6 +107,42 @@ def assign(self, old_value, old_node, new_value):

new_args, new_kwargs = self.arguments(new_value)

# positional arguments

result_args = []

for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)):
old_value_element = self.argument(old_value, i)
result = yield from self.get_adapter(
old_value_element, new_value_element
).assign(old_value_element, node, new_value_element)
result_args.append(result)

print(old_node.args)
print(new_args)
if len(old_node.args) > len(new_args):
for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]:
print("del", arg_pos)
yield Delete(
"fix",
self.context._source,
node,
self.argument(old_value, arg_pos),
)

if len(old_node.args) < len(new_args):
for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]:
yield CallArg(
flag="fix",
file=self.context._source,
node=old_node,
arg_pos=insert_pos,
arg_name=None,
new_code=self.context._value_to_code(value),
new_value=value,
)

# keyword arguments
result_kwargs = {}
for kw in old_node.keywords:
if not kw.arg in new_kwargs:
Expand Down Expand Up @@ -143,4 +202,126 @@ def assign(self, old_value, old_node, new_value):
new_value=value,
)

return type(old_value)(**result_kwargs)
return type(old_value)(*result_args, **result_kwargs)


class DataclassContainer(DataclassAdapter):

@classmethod
def check_type(cls, value):
return is_dataclass(value)

@classmethod
def arguments(cls, value):

kwargs = {}

for field in fields(value): # type: ignore
if field.repr:
field_value = getattr(value, field.name)

if field.default != MISSING and field.default == field_value:
continue

if (
field.default_factory != MISSING
and field.default_factory() == field_value
):
continue

kwargs[field.name] = field_value

return ([], kwargs)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)


try:
from pydantic import BaseModel
except ImportError: # pragma: no cover
pass
else:

class PydanticContainer(DataclassAdapter):

@classmethod
def check_type(cls, value):
return issubclass(value, BaseModel)

@classmethod
def arguments(cls, value):

return (
[],
{
name: getattr(value, name)
for name, info in value.model_fields.items()
if getattr(value, name) != info.default
},
)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)


class IsNamedTuple(ABC):
_inline_snapshot_name = "namedtuple"

_fields: tuple
_field_defaults: dict

@classmethod
def __subclasshook__(cls, t):
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)


class NamedTupleContainer(DataclassAdapter):

@classmethod
def check_type(cls, value):
return issubclass(value, IsNamedTuple)

@classmethod
def arguments(cls, value: IsNamedTuple):

return (
[],
{
field: getattr(value, field)
for field in value._fields
if field not in value._field_defaults
or getattr(value, field) != value._field_defaults[field]
},
)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)


class DefaultDictContainer(DataclassAdapter):
@classmethod
def check_type(cls, value):
return issubclass(value, defaultdict)

@classmethod
def arguments(cls, value: defaultdict):

return ([value.default_factory, dict(value)], {})

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, int)
if pos_or_name == 0:
return value.default_factory
elif pos_or_name == 1:
return dict(value)
assert False
13 changes: 13 additions & 0 deletions src/inline_snapshot/_adapter/dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@

class DictAdapter(Adapter):

@classmethod
def repr(cls, value):
result = (
"{"
+ ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items())
+ "}"
)

if type(value) is not dict:
result = f"{repr(type(value))}({result})"

return result

@classmethod
def map(cls, value, map_function):
return {k: adapter_map(v, map_function) for k, v in value.items()}
Expand Down
14 changes: 14 additions & 0 deletions src/inline_snapshot/_adapter/sequence_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
class SequenceAdapter(Adapter):
node_type: type
value_type: type
braces: str
trailing_comma: bool

@classmethod
def repr(cls, value):
if len(value) == 1 and cls.trailing_comma:
seq = repr(value[0]) + ","
else:
seq = ", ".join(map(repr, value))
return cls.braces[0] + seq + cls.braces[1]

@classmethod
def map(cls, value, map_function):
Expand Down Expand Up @@ -91,8 +101,12 @@ def assign(self, old_value, old_node, new_value):
class ListAdapter(SequenceAdapter):
node_type = ast.List
value_type = list
braces = "[]"
trailing_comma = False


class TupleAdapter(SequenceAdapter):
node_type = ast.Tuple
value_type = tuple
braces = "()"
trailing_comma = True
5 changes: 5 additions & 0 deletions src/inline_snapshot/_adapter/value_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from inline_snapshot._code_repr import value_code_repr
from inline_snapshot._unmanaged import Unmanaged
from inline_snapshot._unmanaged import update_allowed
from inline_snapshot._utils import value_to_token
Expand All @@ -10,6 +11,10 @@

class ValueAdapter(Adapter):

@classmethod
def repr(cls, value):
return value_code_repr(value)

@classmethod
def map(cls, value, map_function):
return map_function(value)
Expand Down
Loading

0 comments on commit b708c40

Please sign in to comment.