Skip to content

Commit

Permalink
Add drop expression support
Browse files Browse the repository at this point in the history
  • Loading branch information
djrobstep committed Sep 18, 2022
1 parent a08bbcc commit 6db03eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
15 changes: 14 additions & 1 deletion schemainspect/inspected.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
is_identity_always=False,
is_generated=False,
is_inherited=False,
can_drop_generated=False,
):
self.name = name or ""
self.dbtype = dbtype
Expand All @@ -66,6 +67,7 @@ def __init__(
self.is_identity_always = is_identity_always
self.is_generated = is_generated
self.is_inherited = is_inherited
self.can_drop_generated = can_drop_generated

def __eq__(self, other):
return (
Expand Down Expand Up @@ -109,7 +111,7 @@ def alter_clauses(self, other):
)

if default_changed:
clauses.append(self.alter_default_clause)
clauses.append(self.alter_default_clause_or_generated(other))

if notnull_added:
clauses.append(self.alter_not_null_clause)
Expand Down Expand Up @@ -218,6 +220,17 @@ def alter_default_clause(self):
alter = "alter column {} drop default".format(self.quoted_name)
return alter

def alter_default_clause_or_generated(self, other):
if self.default:
alter = "alter column {} set default {}".format(
self.quoted_name, self.default
)
elif other.is_generated and not self.is_generated:
alter = "alter column {} drop expression".format(self.quoted_name)
else:
alter = "alter column {} drop default".format(self.quoted_name)
return alter

def alter_identity_clause(self, other):
if self.is_identity:
identity_type = "always" if self.is_identity_always else "by default"
Expand Down
1 change: 1 addition & 0 deletions schemainspect/pg/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ def get_enum(name, schema):
is_identity=c.is_identity,
is_identity_always=c.is_identity_always,
is_generated=c.is_generated,
can_drop_generated=self.pg_version >= 13,
)
for c in clist
if c.position_number
Expand Down

0 comments on commit 6db03eb

Please sign in to comment.