Skip to content

Commit

Permalink
feat: support for more datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jul 5, 2024
1 parent c656dce commit ebd28a3
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 5 deletions.
83 changes: 79 additions & 4 deletions inline_snapshot/_code_repr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast
import dataclasses
from abc import ABC
from collections import defaultdict
from dataclasses import fields
from dataclasses import is_dataclass
from enum import Enum
Expand Down Expand Up @@ -103,6 +105,49 @@ def _(v: list):
# -8<- [end:list]


class OnlyTuple(ABC):

@classmethod
def __subclasshook__(cls, t):
return t is tuple


@register_repr
def _(v: OnlyTuple):
assert isinstance(v, tuple)
if len(v) == 1:
return f"({repr(v[0])},)"
return "(" + ", ".join(map(repr, v)) + ")"


class IsNamedTuple(ABC):
_inline_snapshot_name = "namedtuple"

_fields: tuple
_field_defaults: dict

@classmethod
def __subclasshook__(cls, t):
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)


@register_repr
def _(v: IsNamedTuple):
params = ", ".join(
f"{field}={repr(getattr(v,field))}"
for field in v._fields
if field not in v._field_defaults
or getattr(v, field) != v._field_defaults[field]
)
return f"{repr(type(v))}({params})"


@register_repr
def _(v: set):
if len(v) == 0:
Expand All @@ -111,9 +156,29 @@ def _(v: set):
return "{" + ", ".join(map(repr, v)) + "}"


@register_repr
def _(v: frozenset):
if len(v) == 0:
return "frozenset()"

return "frozenset({" + ", ".join(map(repr, v)) + "})"


@register_repr
def _(v: dict):
return "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}"
result = (
"{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in v.items()) + "}"
)

if type(v) is not dict:
result = f"{repr(type(v))}({result})"

return result


@register_repr
def _(v: defaultdict):
return f"defaultdict({repr(v.default_factory)}, {repr(dict(v))})"


@register_repr
Expand All @@ -124,18 +189,28 @@ def _(v: type):
class IsDataclass(ABC):
_inline_snapshot_name = "dataclasses"

@staticmethod
def __subclasshook__(subclass):
@classmethod
def __subclasshook__(cls, subclass):
return is_dataclass(subclass)


@register_repr
def _(v: IsDataclass):
attrs = []
for field in fields(v): # type: ignore

if field.repr:
value = getattr(v, field.name)
attrs.append(f"{field.name} = {repr(value)}")

if dataclasses.MISSING is not field.default == value:
continue
if (
dataclasses.MISSING is not field.default_factory
and field.default_factory() == value
):
continue

attrs.append(f"{field.name}={repr(value)}")

return f"{repr(type(v))}({', '.join(attrs)})"

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ Funding = "https://github.com/sponsors/15r10nk"
Homepage = "https://15r10nk.github.io/inline-snapshot"
Issues = "https://github.com/15r10nk/inline-snapshots/issues"
Repository = "https://github.com/15r10nk/inline-snapshot/"

[tool.pyright]
venv = "test-3-12"
venvPath = ".nox"
103 changes: 102 additions & 1 deletion tests/test_code_repr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import dataclasses
from dataclasses import dataclass

import pytest

from .example import Example
from inline_snapshot import HasRepr
from inline_snapshot import snapshot
from inline_snapshot._code_repr import code_repr


def test_enum(check_update):
Expand Down Expand Up @@ -121,7 +127,7 @@ class container:
bg: color=color.red
fg: color=color.blue
assert container(bg=color.red,fg=color.red) == snapshot(container(bg=color.red, fg=color.red))
assert container(bg=color.red,fg=color.red) == snapshot(container(fg=color.red))
"""
)
Expand Down Expand Up @@ -256,3 +262,98 @@ class Color(Enum):
}
),
).run_inline()


from collections import (
Counter,
OrderedDict,
UserDict,
UserList,
defaultdict,
namedtuple,
)
from typing import NamedTuple

A = namedtuple("A", "a,b", defaults=[0])
B = namedtuple("B", "a,b", defaults=[0, 0])


class C(NamedTuple):
a: int
b: int = 0
c: int = 0


@dataclass
class Dataclass:
a: int
b: int = dataclasses.field(default=0)
c: list = dataclasses.field(default_factory=lambda: [])


default_dict = defaultdict(list)
default_dict[5].append(2)
default_dict[3].append(1)


@pytest.mark.parametrize(
"d",
[
frozenset(["a"]),
frozenset(),
{"a"},
set(),
list(),
["a"],
{},
{1: "1"},
(),
(1,),
(1, 2, 3),
A(1, 2),
A(1),
A(0, 0),
B(),
B(b=5),
C(1),
C(1, 2),
C(a=1, c=2),
Dataclass(a=0, b=0, c=[]),
Dataclass(a=1, b=2, c=[3]),
default_dict,
OrderedDict({1: 2, 3: 4}),
UserDict({1: 2}),
UserList([1, 2]),
],
)
def test_datatypes(d):
code = code_repr(d)
print("repr: ", repr(d))
print("code_repr:", code)
assert d == eval(code)


def test_datatypes_explicit():
assert code_repr(C(a=1, c=2)) == snapshot("C(a=1, c=2)")
assert code_repr(B(b=5)) == snapshot("B(b=5)")
assert code_repr(B(b=0)) == snapshot("B()")

assert code_repr(Dataclass(a=0, b=0, c=[])) == snapshot("Dataclass(a=0)")
assert code_repr(Dataclass(a=1, b=2, c=[3])) == snapshot(
"Dataclass(a=1, b=2, c=[3])"
)
assert code_repr(Counter([1, 1, 1, 2])) == snapshot("Counter({1: 3, 2: 1})")

assert code_repr(default_dict) == snapshot("defaultdict(list, {5: [2], 3: [1]})")


def test_tuple():

class FakeTuple(tuple):
def __init__(self):
self._fields = 5

def __repr__(self):
return "FakeTuple()"

assert code_repr(FakeTuple()) == snapshot("FakeTuple()")

0 comments on commit ebd28a3

Please sign in to comment.