diff --git a/README.md b/README.md index c822801..2aab34f 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ class Address(CompositeType): country = models.CharField(max_length=50) class Meta: - db_type = 'x_address' # Required + db_table = 'x_address' # Required class Person(models.Model): @@ -76,7 +76,7 @@ class Card(CompositeType): rank = models.CharField(max_length=2) class Meta: - db_type = 'card' + db_table = 'card' class Hand(models.Model): @@ -94,13 +94,13 @@ class Point(CompositeType): y = models.IntegerField() class Meta: - db_type = 'x_point' # Postgres already has a point type + db_table = 'x_point' # Postgres already has a point type class Box(CompositeType): """An axis-aligned box on the cartesian plane.""" class Meta: - db_type = 'x_box' # Postgres already has a box type + db_table = 'x_box' # Postgres already has a box type top_left = Point.Field() bottom_right = Point.Field() diff --git a/postgres_composite_types/caster.py b/postgres_composite_types/caster.py index b3e714b..3080cbe 100644 --- a/postgres_composite_types/caster.py +++ b/postgres_composite_types/caster.py @@ -1,5 +1,10 @@ +from typing import TYPE_CHECKING, Type + from psycopg2.extras import CompositeCaster +if TYPE_CHECKING: + from .composite_type import CompositeType + __all__ = ["BaseCaster"] @@ -9,7 +14,7 @@ class BaseCaster(CompositeCaster): instance. """ - Meta = None + _composite_type_model: Type["CompositeType"] def make(self, values): - return self.Meta.model(*values) + return self._composite_type_model(*values) diff --git a/postgres_composite_types/composite_type.py b/postgres_composite_types/composite_type.py index 22e418e..340e7e4 100644 --- a/postgres_composite_types/composite_type.py +++ b/postgres_composite_types/composite_type.py @@ -1,19 +1,17 @@ -import inspect import logging -import sys +from typing import Type -from django.db import models -from django.db.backends.postgresql.base import ( - DatabaseWrapper as PostgresDatabaseWrapper, -) +from django.db import connections, models from django.db.backends.signals import connection_created +from django.db.models.base import ModelBase +from django.db.models.manager import EmptyManager +from django.db.models.signals import post_migrate from psycopg2 import ProgrammingError from psycopg2.extensions import ISQLQuote, register_adapter -from psycopg2.extras import CompositeCaster, register_composite +from psycopg2.extras import register_composite from .caster import BaseCaster -from .fields import BaseField -from .operations import BaseOperation +from .fields import BaseField, DummyField from .quoting import QuotedCompositeType LOGGER = logging.getLogger(__name__) @@ -21,13 +19,7 @@ __all__ = ["CompositeType"] -def _add_class_to_module(cls, module_name): - cls.__module__ = module_name - module = sys.modules[module_name] - setattr(module, cls.__name__, cls) - - -class CompositeTypeMeta(type): +class CompositeTypeMeta(ModelBase): """Metaclass for Type.""" @classmethod @@ -40,6 +32,7 @@ def __prepare__(cls, name, bases): """ return {} + # pylint:disable=arguments-differ def __new__(cls, name, bases, attrs): # Only apply the metaclass to our subclasses if name == "CompositeType": @@ -52,94 +45,62 @@ def __new__(cls, name, bases, attrs): raise TypeError("Composite types cannot contain " "related fields") if isinstance(value, models.Field): - field = attrs.pop(field_name) + field = attrs[field_name] field.set_attributes_from_name(field_name) fields.append((field_name, field)) # retrieve the Meta from our declaration try: - meta_obj = attrs.pop("Meta") + meta_obj = attrs["Meta"] except KeyError as exc: raise TypeError(f'{name} has no "Meta" class') from exc try: - meta_obj.db_type + meta_obj.db_table except AttributeError as exc: - raise TypeError(f"{name}.Meta.db_type is required.") from exc - - meta_obj.fields = fields + raise TypeError(f"{name}.Meta.db_table is required.") from exc # create the field for this Type - attrs["Field"] = type(f"{name}Field", (BaseField,), {"Meta": meta_obj}) - - # add field class to the module in which the composite type class lives - # this is required for migrations to work - _add_class_to_module(attrs["Field"], attrs["__module__"]) - - # create the database operation for this type - attrs["Operation"] = type( - f"Create{name}Type", (BaseOperation,), {"Meta": meta_obj} - ) - - # create the caster for this type - attrs["Caster"] = type(f"{name}Caster", (BaseCaster,), {"Meta": meta_obj}) + attrs["Field"] = type(f"{name}.Field", (BaseField,), {}) - new_cls = super().__new__(cls, name, bases, attrs) - new_cls._meta = meta_obj + attrs[DummyField.name] = DummyField(primary_key=True, serialize=False) - meta_obj.model = new_cls + # Use an EmptyManager for everything as types cannot be queried. + meta_obj.default_manager_name = "objects" + meta_obj.base_manager_name = "objects" + attrs["objects"] = EmptyManager(model=None) # type: ignore - return new_cls + ret = super().__new__(cls, name, bases, attrs) + ret.Field._composite_type_model = ret # type: ignore + return ret def __init__(cls, name, bases, attrs): super().__init__(name, bases, attrs) if name == "CompositeType": return - cls._capture_descriptors() # pylint:disable=no-value-for-parameter + # pylint:disable=no-value-for-parameter + cls._connect_signals() - # Register the type on the first database connection - connection_created.connect( - receiver=cls.database_connected, dispatch_uid=cls._meta.db_type - ) - - def _capture_descriptors(cls): - """Work around for not being able to call contribute_to_class. - - Too much code to fake in our meta objects etc to be able to call - contribute_to_class directly, but we still want fields to be able - to set custom type descriptors. So we fake a model instead, with the - same fields as the composite type, and extract any custom descriptors - on that. + def _on_signal_register_type(cls, signal, sender, connection=None, **kwargs): """ + Attempt registering the type after a migration succeeds. + """ + from django.db.backends.postgresql.base import DatabaseWrapper - attrs = dict(cls._meta.fields) - - # we need to build a unique app label and model name combination for - # every composite type so django doesn't complain about model reloads - class Meta: - app_label = cls.__module__ - - attrs["__module__"] = cls.__module__ - attrs["Meta"] = Meta - model_name = f"_Fake{cls.__name__}Model" + if connection is None: + connection = connections["default"] - fake_model = type(model_name, (models.Model,), attrs) - for field_name, _ in cls._meta.fields: - attr = getattr(fake_model, field_name) - if inspect.isdatadescriptor(attr): - setattr(cls, field_name, attr) + if isinstance(connection, DatabaseWrapper): + # On-connect, register the QuotedCompositeType with psycopg2. + # This is what to do when the type is going in to the database + register_adapter(cls, QuotedCompositeType) - def database_connected(cls, signal, sender, connection, **kwargs): - """ - Register this type with the database the first time a connection is - made. - """ - if isinstance(connection, PostgresDatabaseWrapper): - # Try to register the type. If the type has not been created in a - # migration, the registration will fail. The type will be + # Now try to register the type. If the type has not been created + # in a migration, the registration will fail. The type will be # registered as part of the migration, so hopefully the migration # will run soon. + try: cls.register_composite(connection) except ProgrammingError as exc: @@ -150,11 +111,35 @@ def database_connected(cls, signal, sender, connection, **kwargs): cls.__name__, exc, ) + else: + # Registration succeeded.Disconnect the signals now. + cls._disconnect_signals() # pylint:disable=no-value-for-parameter + + def _connect_signals(cls): + type_id = cls._meta.db_table - # Disconnect the signal now - only need to register types on the - # initial connection + # Register the type on the first database connection + connection_created.connect( + receiver=cls._on_signal_register_type, dispatch_uid=f"connect:{type_id}" + ) + + # Also register on post-migrate. + # This ensures that, if the on-connect signal failed due to a migration + # not having run yet, running the migration will still register it, + # even if in the same session (this can happen in tests for example). + # dispatch_uid needs to be distinct from the one on connection_created. + post_migrate.connect( + receiver=cls._on_signal_register_type, + dispatch_uid=f"post_migrate:{type_id}", + ) + + def _disconnect_signals(cls): + type_id = cls._meta.db_table connection_created.disconnect( - cls.database_connected, dispatch_uid=cls._meta.db_type + cls._on_signal_register_type, dispatch_uid=f"connect:{type_id}" + ) + post_migrate.disconnect( + cls._on_signal_register_type, dispatch_uid=f"post_migrate:{type_id}" ) @@ -163,22 +148,21 @@ class CompositeType(metaclass=CompositeTypeMeta): A new composite type stored in Postgres. """ - _meta = None - # The database connection this type is registered with registered_connection = None + _meta: Type def __init__(self, *args, **kwargs): if args and kwargs: raise RuntimeError("Specify either args or kwargs but not both.") - # Initialise blank values for anyone expecting them - for name, _ in self._meta.fields: - setattr(self, name, None) + fields = self.get_fields() + for field in fields: + setattr(self, field.name, None) # Unpack any args as if they came from the type - for (name, _), arg in zip(self._meta.fields, args): - setattr(self, name, arg) + for field, arg in zip(fields, args): + setattr(self, field.name, arg) for name, value in kwargs.items(): setattr(self, name, value) @@ -189,14 +173,14 @@ def __repr__(self): def __to_tuple__(self): return tuple( - field.get_prep_value(getattr(self, name)) - for name, field in self._meta.fields + field.get_prep_value(getattr(self, field.name)) + for field in self.get_fields() ) def __to_dict__(self): return { - name: field.get_prep_value(getattr(self, name)) - for name, field in self._meta.fields + field.name: field.get_prep_value(getattr(self, field.name)) + for field in self.get_fields() } def __eq__(self, other): @@ -204,8 +188,8 @@ def __eq__(self, other): return False if self._meta.model != other._meta.model: return False - for name, _ in self._meta.fields: - if getattr(self, name) != getattr(other, name): + for field in self.get_fields(): + if getattr(self, field.name) != getattr(other, field.name): return False return True @@ -227,11 +211,10 @@ def register_composite(cls, connection): with connection.temporary_connection() as cur: # This is what to do when the type is coming out of the database - register_composite( - cls._meta.db_type, cur, globally=True, factory=cls.Caster - ) - # This is what to do when the type is going in to the database - register_adapter(cls, QuotedCompositeType) + # We create a custom class subclassing BaseCaster (see caster.py), + # and set _composite_type_model attribute accordingly. + caster = type("Caster", (BaseCaster,), {"_composite_type_model": cls}) + register_composite(cls._meta.db_table, cur, globally=True, factory=caster) def __conform__(self, protocol): """ @@ -251,12 +234,13 @@ class Field(BaseField): Placeholder for the field that will be produced for this type. """ - class Operation(BaseOperation): - """ - Placeholder for the DB operation that will be produced for this type. - """ + # pylint:disable=invalid-name + def _get_next_or_previous_by_FIELD(self): + pass - class Caster(CompositeCaster): - """ - Placeholder for the caster that will be produced for this type - """ + @classmethod + def check(cls, **kwargs): + return [] + + def get_fields(self): + return self._meta.fields diff --git a/postgres_composite_types/fields.py b/postgres_composite_types/fields.py index 1529459..272d5d6 100644 --- a/postgres_composite_types/fields.py +++ b/postgres_composite_types/fields.py @@ -1,4 +1,5 @@ import json +from typing import TYPE_CHECKING, Type from django.core.exceptions import ValidationError from django.db.backends.postgresql.base import ( @@ -6,23 +7,45 @@ ) from django.db.models import Field +if TYPE_CHECKING: + from .composite_type import CompositeType + + __all__ = ["BaseField"] +class DummyField(Field): + """ + A dummy field added on every CompositeType, that behaves as the + type's primary key. This is a hack due to Django's requirement for + all models to have a primary key. + """ + + name = "_id_not_used" + + class BaseField(Field): """Base class for the field that relates to this type.""" - Meta = None + _composite_type_model: Type["CompositeType"] default_error_messages = { "bad_json": "to_python() received a string that was not valid JSON", } + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + + name = path.replace("postgres_composite_types.composite_type.", "") + path = self._composite_type_model.__module__ + "." + name + + return name, path, args, kwargs + def db_type(self, connection): if not isinstance(connection, PostgresDatabaseWrapper): raise RuntimeError("Composite types are only available for postgres") - return self.Meta.db_type + return self._composite_type_model._meta.db_table def formfield(self, **kwargs): # pylint:disable=arguments-differ """Form field for address.""" @@ -30,7 +53,7 @@ def formfield(self, **kwargs): # pylint:disable=arguments-differ defaults = { "form_class": CompositeTypeField, - "model": self.Meta.model, + "model": self._composite_type_model, } defaults.update(kwargs) @@ -55,10 +78,11 @@ def to_python(self, value): code="bad_json", ) from exc - return self.Meta.model( + return self._composite_type_model( **{ - name: field.to_python(value.get(name)) - for name, field in self.Meta.fields + field.name: field.to_python(value.get(field.name)) + for field in self._composite_type_model._meta.fields + if field.name != DummyField.name } ) @@ -71,5 +95,9 @@ def value_to_string(self, obj): """ value = self.value_from_object(obj) return json.dumps( - {name: field.value_to_string(value) for name, field in self.Meta.fields} + { + field.name: field.value_to_string(value) + for field in self._composite_type_model._meta.fields + if field.name != DummyField.name + } ) diff --git a/postgres_composite_types/forms.py b/postgres_composite_types/forms.py index 7966fa1..d4ae85c 100644 --- a/postgres_composite_types/forms.py +++ b/postgres_composite_types/forms.py @@ -12,6 +12,7 @@ from django.utils.translation import gettext as _ from . import CompositeType +from .fields import DummyField LOGGER = logging.getLogger(__name__) @@ -61,13 +62,16 @@ class CompositeTypeField(forms.Field): } def __init__(self, *args, fields=None, model=None, **kwargs): - if fields is None: - fields = {name: field.formfield() for name, field in model._meta.fields} - else: - fields = dict(fields) + fields = { + field.name: field.formfield() + for field in fields or model._meta.fields + if field.name != DummyField.name + } widget = CompositeTypeWidget( - widgets=[(name, field.widget) for name, field in fields.items()] + widgets=[ + (name, getattr(field, "widget", None)) for name, field in fields.items() + ] ) super().__init__(*args, widget=widget, **kwargs) @@ -75,7 +79,8 @@ def __init__(self, *args, fields=None, model=None, **kwargs): self.model = model for field, widget in zip(fields.values(), self.widget.widgets.values()): - widget.attrs["placeholder"] = field.label + if widget: + widget.attrs["placeholder"] = getattr(field, "label", "") def prepare_value(self, value): """ diff --git a/postgres_composite_types/operations.py b/postgres_composite_types/operations.py index 5de9590..a73309a 100644 --- a/postgres_composite_types/operations.py +++ b/postgres_composite_types/operations.py @@ -1,20 +1,22 @@ -from django.db.migrations.operations.base import Operation +from django.db.migrations import CreateModel +from django.db.migrations.state import ModelState -from .signals import composite_type_created +from .fields import DummyField -__all__ = ["BaseOperation"] +__all__ = ["CreateType"] def sql_field_definition(field_name, field, schema_editor): quoted_name = schema_editor.quote_name(field_name) - db_type = field.db_type(schema_editor.connection) - return f"{quoted_name} {db_type}" + type_name = field.db_type(schema_editor.connection) + return f"{quoted_name} {type_name}" def sql_create_type(type_name, fields, schema_editor): fields_list = ", ".join( sql_field_definition(field_name, field, schema_editor) for field_name, field in fields + if field_name != DummyField.name ) quoted_name = schema_editor.quote_name(type_name) return f"CREATE TYPE {quoted_name} AS ({fields_list})" @@ -25,28 +27,32 @@ def sql_drop_type(type_name, schema_editor): return f"DROP TYPE {quoted_name}" -class BaseOperation(Operation): +class CreateType(CreateModel): """Base class for the DB operation that relates to this type.""" reversible = True - Meta = None - def state_forwards(self, app_label, state): - pass + def __init__(self, *, name: str, fields, options) -> None: + fields = [ + (DummyField.name, DummyField(primary_key=True, serialize=False)), + *fields, + ] + super().__init__(name, fields, options) def describe(self): - return f"Creates type {self.Meta.db_type}" + return f"Creates type {self.name}" + + def state_forwards(self, app_label, state) -> None: + state.add_model( + ModelState(app_label, self.name, list(self.fields), dict(self.options)) + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): schema_editor.execute( - sql_create_type(self.Meta.db_type, self.Meta.fields, schema_editor) - ) - self.Meta.model.register_composite(schema_editor.connection) - composite_type_created.send( - self.Meta.model, connection=schema_editor.connection + sql_create_type(self.options["db_table"], self.fields, schema_editor) ) def database_backwards(self, app_label, schema_editor, from_state, to_state): schema_editor.execute( - sql_drop_type(self.Meta.db_type, schema_editor=schema_editor) + sql_drop_type(self.options["db_table"], schema_editor=schema_editor) ) diff --git a/postgres_composite_types/quoting.py b/postgres_composite_types/quoting.py index 7f5f291..8571564 100644 --- a/postgres_composite_types/quoting.py +++ b/postgres_composite_types/quoting.py @@ -1,5 +1,7 @@ from psycopg2.extensions import ISQLQuote, adapt +from .fields import DummyField + __all__ = ["QuotedCompositeType"] @@ -19,9 +21,11 @@ def __init__(self, obj): self.value = adapt( tuple( field.get_db_prep_value( - field.value_from_object(self.obj), self.model.registered_connection + field.value_from_object(self.obj), + self.model.registered_connection, ) - for _, field in self.model._meta.fields + for field in self.model._meta.fields + if field.name != DummyField.name ) ) @@ -58,5 +62,5 @@ def getquoted(self): f"{name}.prepare() must be called before {name}.getquoted()" ) - db_type = self.model._meta.db_type.encode("ascii") + db_type = self.model._meta.db_table.encode("ascii") return self.value.getquoted() + b"::" + db_type diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index a3ed8eb..49baf84 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -1,31 +1,199 @@ -""" -Migration to create custom types -""" +# Generated by Django 3.2.16 on 2022-12-11 05:38 -from django.db import migrations +import django.contrib.postgres.fields +from django.db import migrations, models -from ..models import ( - Box, - Card, - DateRange, - DescriptorType, - OptionalBits, - Point, - SimpleType, -) +import postgres_composite_types.fields +import postgres_composite_types.operations +import tests.fields +import tests.models class Migration(migrations.Migration): - """Migration.""" + + initial = True dependencies = [] operations = [ - SimpleType.Operation(), - OptionalBits.Operation(), - Card.Operation(), - Point.Operation(), - Box.Operation(), - DateRange.Operation(), - DescriptorType.Operation(), + postgres_composite_types.operations.CreateType( + name="Point", + fields=[ + ("x", models.IntegerField()), + ("y", models.IntegerField()), + ], + options={ + "db_table": "test_point", + }, + ), + postgres_composite_types.operations.CreateType( + name="Box", + fields=[ + ("top_left", tests.models.Point.Field()), + ("bottom_right", tests.models.Point.Field()), + ], + options={ + "db_table": "test_box", + }, + ), + postgres_composite_types.operations.CreateType( + name="Card", + fields=[ + ("suit", models.CharField(max_length=1)), + ("rank", models.CharField(max_length=2)), + ], + options={ + "db_table": "card", + }, + ), + postgres_composite_types.operations.CreateType( + name="DateRange", + fields=[ + ("start", models.DateTimeField()), + ("end", models.DateTimeField()), + ], + options={ + "db_table": "test_date_range", + }, + ), + postgres_composite_types.operations.CreateType( + name="DescriptorType", + fields=[ + ("value", tests.fields.TriplingIntegerField()), + ], + options={ + "db_table": "test_custom_descriptor", + }, + ), + migrations.CreateModel( + name="DescriptorModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "field", + tests.models.DescriptorType.Field(), + ), + ], + ), + migrations.CreateModel( + name="Hand", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "cards", + django.contrib.postgres.fields.ArrayField( + base_field=tests.models.Card.Field(), + size=None, + ), + ), + ], + ), + migrations.CreateModel( + name="Item", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20)), + ("bounding_box", tests.models.Box.Field()), + ], + ), + migrations.CreateModel( + name="NamedDateRange", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.TextField()), + ( + "date_range", + tests.models.DateRange.Field(), + ), + ], + ), + postgres_composite_types.operations.CreateType( + name="OptionalBits", + fields=[ + ("required", models.CharField(max_length=32)), + ("optional", models.CharField(blank=True, max_length=32, null=True)), + ], + options={ + "db_table": "optional_type", + }, + ), + migrations.CreateModel( + name="OptionalModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "optional_field", + tests.models.OptionalBits.Field(blank=True, null=True), + ), + ], + ), + postgres_composite_types.operations.CreateType( + name="SimpleType", + fields=[ + ("a", models.IntegerField(verbose_name="A number")), + ("b", models.CharField(max_length=32, verbose_name="A name")), + ("c", models.DateTimeField(verbose_name="A date")), + ], + options={ + "db_table": "test_type", + }, + ), + migrations.CreateModel( + name="SimpleModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "test_field", + tests.models.SimpleType.Field(), + ), + ], + ), ] diff --git a/tests/migrations/0002_models.py b/tests/migrations/0002_models.py deleted file mode 100644 index ac89a9d..0000000 --- a/tests/migrations/0002_models.py +++ /dev/null @@ -1,118 +0,0 @@ -# Generated by Django 2.0.2 on 2018-03-04 23:12 - -import django.contrib.postgres.fields -from django.db import migrations, models - -import tests.models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ("tests", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="DescriptorModel", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("field", tests.models.DescriptorTypeField()), - ], - ), - migrations.CreateModel( - name="Hand", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "cards", - django.contrib.postgres.fields.ArrayField( - base_field=tests.models.CardField(), size=None - ), - ), - ], - ), - migrations.CreateModel( - name="Item", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("name", models.CharField(max_length=20)), - ("bounding_box", tests.models.BoxField()), - ], - ), - migrations.CreateModel( - name="NamedDateRange", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("name", models.TextField()), - ("date_range", tests.models.DateRangeField()), - ], - ), - migrations.CreateModel( - name="OptionalModel", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "optional_field", - tests.models.OptionalBitsField(blank=True, null=True), - ), - ], - ), - migrations.CreateModel( - name="SimpleModel", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("test_field", tests.models.SimpleTypeField()), - ], - ), - ] diff --git a/tests/models.py b/tests/models.py index 8b26992..ae088d9 100644 --- a/tests/models.py +++ b/tests/models.py @@ -11,7 +11,7 @@ class SimpleType(CompositeType): """A test type.""" class Meta: - db_type = "test_type" + db_table = "test_type" a = models.IntegerField(verbose_name="A number") b = models.CharField(verbose_name="A name", max_length=32) @@ -31,7 +31,7 @@ class OptionalBits(CompositeType): optional = models.CharField(max_length=32, null=True, blank=True) class Meta: - db_type = "optional_type" + db_table = "optional_type" class OptionalModel(models.Model): @@ -44,7 +44,7 @@ class Card(CompositeType): """A playing card.""" class Meta: - db_type = "card" + db_table = "card" suit = models.CharField(max_length=1) rank = models.CharField(max_length=2) @@ -60,7 +60,7 @@ class Point(CompositeType): """A point on the cartesian plane.""" class Meta: - db_type = "test_point" # Postgres already has a point type + db_table = "test_point" # Postgres already has a point type x = models.IntegerField() y = models.IntegerField() @@ -70,7 +70,7 @@ class Box(CompositeType): """An axis-aligned box on the cartesian plane.""" class Meta: - db_type = "test_box" # Postgres already has a box type + db_table = "test_box" # Postgres already has a box type top_left = Point.Field() bottom_right = Point.Field() @@ -97,7 +97,7 @@ class DateRange(CompositeType): """A date range with start and end.""" class Meta: - db_type = "test_date_range" + db_table = "test_date_range" start = models.DateTimeField() end = models.DateTimeField() # uses reserved keyword @@ -114,7 +114,7 @@ class DescriptorType(CompositeType): """Has a field implementing a custom descriptor""" class Meta: - db_type = "test_custom_descriptor" + db_table = "test_custom_descriptor" value = TriplingIntegerField() diff --git a/tests/test_field.py b/tests/test_field.py index 4ee6033..aa067b1 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -2,7 +2,6 @@ import datetime import json -from unittest import mock from django.core import serializers from django.core.exceptions import ValidationError @@ -11,8 +10,6 @@ from django.test import TestCase, TransactionTestCase from psycopg2.extensions import adapt -from postgres_composite_types.signals import composite_type_created - from .models import ( Box, DateRange, @@ -70,38 +67,20 @@ def test_migration(self): # The migrations have already been run, and the type already exists in # the database - self.assertTrue(does_type_exist(SimpleType._meta.db_type)) + self.assertTrue(does_type_exist(SimpleType._meta.db_table)) # Run the migration backwards to check the type is deleted migrate(self.migrate_from) # The type should now not exist - self.assertFalse(does_type_exist(SimpleType._meta.db_type)) - - # A signal is fired when the migration creates the type - signal_func = mock.Mock() - composite_type_created.connect(receiver=signal_func, sender=SimpleType) + self.assertFalse(does_type_exist(SimpleType._meta.db_table)) # Run the migration forwards to create the type again migrate(self.migrate_to) - self.assertTrue(does_type_exist(SimpleType._meta.db_type)) - - # The signal should have been sent - self.assertEqual(signal_func.call_count, 1) - self.assertEqual( - signal_func.call_args, - ( - (), - { - "sender": SimpleType, - "signal": composite_type_created, - "connection": connection, - }, - ), - ) + self.assertTrue(does_type_exist(SimpleType._meta.db_table)) # The type should now exist again - self.assertTrue(does_type_exist(SimpleType._meta.db_type)) + self.assertTrue(does_type_exist(SimpleType._meta.db_table)) def test_migration_quoting(self): """Test that migration SQL is generated with correct quoting""" @@ -109,7 +88,7 @@ def test_migration_quoting(self): # The migrations have already been run, and the type already exists in # the database migrate(self.migrate_to) - self.assertTrue(does_type_exist(DateRange._meta.db_type)) + self.assertTrue(does_type_exist(DateRange._meta.db_table)) class FieldTests(TestCase):