Skip to content

Commit

Permalink
feat: add support for setting tags
Browse files Browse the repository at this point in the history
  • Loading branch information
pquadri committed Jul 15, 2024
1 parent 717b9e0 commit 71e6dc6
Show file tree
Hide file tree
Showing 6 changed files with 3,461 additions and 1,161 deletions.
130 changes: 98 additions & 32 deletions snowflake_utils/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
from collections import defaultdict
from datetime import date, datetime
from enum import Enum
from functools import partial
from pydantic import BaseModel

from pydantic import BaseModel, Field
from snowflake.connector.cursor import SnowflakeCursor
from typing_extensions import Self
from datetime import datetime, date
from .queries import connect, execute_statement
import logging

logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
from .queries import connect, execute_statement
from .settings import governance_settings


class MatchByColumnName(Enum):
Expand Down Expand Up @@ -44,25 +46,26 @@ def from_string(cls, s: str) -> Self:
raise ValueError("Cannot parse file format")


class Column(BaseModel):
    """A single table column: its name, Snowflake data type, and governance tags."""

    # Column name as it appears (or will appear) in the table definition.
    name: str
    # Snowflake data type, e.g. "VARCHAR", "NUMBER(38,0)".
    data_type: str
    # Column-level tag assignments, tag name -> tag value; empty when untagged.
    tags: dict[str, str] = Field(default_factory=dict)


class TableStructure(BaseModel):
    """Declarative description of a table's columns, keyed by column name."""

    # BUG FIX: was `columns: dict = [str, Column]`, which annotates the field as a
    # plain `dict` and sets its *default value* to the list `[str, Column]`.
    # The intent is a mapping of column name -> Column; use a proper generic
    # annotation and a per-instance empty-dict default (mutable defaults must
    # come from a factory).
    columns: dict[str, Column] = Field(default_factory=dict)

    @property
    def parsed_columns(self) -> str:
        """Render the columns as a comma-separated DDL fragment.

        Each name is upper-cased, stripped, dash-normalised to underscores,
        double-quoted, and followed by its data type.
        """
        return ", ".join(
            f'"{str.upper(k).strip().replace("-","_")}" {v.data_type}'
            for k, v in self.columns.items()
        )

    def parse_from_json(self):
        """Placeholder for building a TableStructure from a JSON document."""
        raise NotImplementedError("Not implemented yet")


class Column(BaseModel):
name: str
data_type: str


class Schema(BaseModel):
name: str
database: str | None = None
Expand Down Expand Up @@ -92,6 +95,12 @@ class Table(BaseModel):
role: str | None = None
database: str | None = None

@property
def fqn(self) -> str:
    """Fully qualified table name: ``[database.]schema.name``.

    The database segment is included only when ``self.database`` is set.
    """
    parts = [self.schema_, self.name]
    if self.database:
        parts.insert(0, self.database)
    return ".".join(parts)

@property
def temporary_stage(self) -> str:
    """Upper-cased name of the throwaway external stage used to load this table."""
    stage_name = f"tmp_external_stage_{self.schema_}_{self.name}"
    return stage_name.upper()
Expand Down Expand Up @@ -130,12 +139,12 @@ def get_create_table_statement(
self,
full_refresh: bool = False,
):
logging.debug(f"Creating table: {self.schema_}.{self.name}")
logging.debug(f"Creating table: {self.fqn}")
if self.table_structure:
return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.schema_}.{self.name} ({self.table_structure.parsed_columns})"
return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn} ({self.table_structure.parsed_columns})"
else:
return f"""
{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.schema_}.{self.name}
{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn}
USING TEMPLATE (
SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
FROM TABLE(
Expand Down Expand Up @@ -163,7 +172,7 @@ def bulk_insert(
vals = ", ".join([_type_cast(v) for v in records[k].values()])
_execute_statement(
f"""
INSERT INTO {self.schema_}.{self.name}({cols})
INSERT INTO {self.fqn}({cols})
VALUES ({vals})
"""
)
Expand All @@ -176,7 +185,8 @@ def copy_into(
storage_integration: str | None = None,
match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE,
full_refresh: bool = False,
target_columns: list[str] = [],
target_columns: list[str] | None = None,
sync_tags: bool = False,
) -> None:
"""Copy files into Snowflake"""
with connect() as connection:
Expand Down Expand Up @@ -207,12 +217,15 @@ def copy_into(
file_format = self.temporary_file_format

_execute_statement(self.get_create_table_statement(full_refresh))
logging.info(
f"Starting copy into `{self.schema_}.{self.name}` from path '{path}'"
)

if sync_tags and self.table_structure:
self.sync_tags(cursor)

logging.info(f"Starting copy into `{self.fqn}` from path '{path}'")
col_str = f"({', '.join(target_columns)})" if target_columns else ""
return _execute_statement(
f"""
COPY INTO {self.schema_}.{self.name} {f"({', '.join(target_columns)})" if len(target_columns) > 0 else ''}
COPY INTO {self.fqn} {col_str}
FROM {path}
{f"STORAGE_INTEGRATION = {storage_integration}" if storage_integration else ''}
FILE_FORMAT = ( FORMAT_NAME ='{file_format}')
Expand All @@ -221,14 +234,14 @@ def copy_into(
)

def get_columns(self, cursor: SnowflakeCursor) -> list[Column]:
    """Describe the table in Snowflake and return its columns as Column models.

    Only the first two fields of each DESC row (name, type) are used; any
    remaining fields are ignored.
    """
    description = cursor.execute(f"desc table {self.fqn}").fetchall()
    return [Column(name=row[0], data_type=row[1]) for row in description]

def add_column(self, cursor: SnowflakeCursor, column: Column) -> None:
    """ALTER the table to append *column* with its declared data type."""
    statement = (
        f"alter table {self.fqn} "
        f"add column {column.name} {column.data_type}"
    )
    cursor.execute(statement)

def exists(self, cursor: SnowflakeCursor) -> bool:
Expand Down Expand Up @@ -256,6 +269,7 @@ def merge(
storage_integration=storage_integration,
file_format=file_format,
match_by_column_name=match_by_column_name,
sync_tags=True,
)
if qualify:
self.qualify(cursor, primary_keys, replication_keys)
Expand Down Expand Up @@ -289,6 +303,8 @@ def merge(
temp_table, new_columns, old_columns, primary_keys
)
)
if self.table_structure:
self.sync_tags(cursor)
temp_table.drop(cursor)

def qualify(
Expand All @@ -302,12 +318,12 @@ def qualify(
f"{c} desc" for c in (replication_keys or primary_keys)
)
logging.debug(
f"Adding QUALIFY to table {self.schema_}.{self.name} on PARTITION {qualify_partition} ORDERED BY {qualify_order}"
f"Adding QUALIFY to table {self.fqn} on PARTITION {qualify_partition} ORDERED BY {qualify_order}"
)
return cursor.execute(
f"""
create or replace table {self.schema_}.{self.name} as (
select * from {self.schema_}.{self.name}
create or replace table {self.fqn} as (
select * from {self.fqn}
qualify row_number() over (partition by {qualify_partition} order by {qualify_order}) = 1
)
"""
Expand All @@ -328,20 +344,22 @@ def _merge_statement(
inserts = _inserts(columns, old_columns)

logging.info(
f"Running merge statement on table: {self.schema_}.{self.name} using {temp_table.schema_}.{temp_table.name}"
f"Running merge statement on table: {self.fqn} using {temp_table.schema_}.{temp_table.name}"
)
logging.debug(f"Primary keys: {pkes}")
return f"""
merge into {self.schema_}.{self.name} as dest
merge into {self.fqn} as dest
using {temp_table.schema_}.{temp_table.name} tmp
ON {pkes}
when matched then update set {matched}
when not matched then insert ({column_names}) VALUES ({inserts})
"""

def drop(self, cursor: SnowflakeCursor) -> None:
logging.debug(f"Dropping table:{self.schema_}.{self.name}")
cursor.execute(f"drop table {self.schema_}.{self.name}")
def drop(self, cursor: SnowflakeCursor | None = None) -> None:
    """Drop this table.

    When *cursor* is provided it is used as-is (the caller owns its
    connection). When omitted, a fresh connection is opened for the
    statement.

    BUG FIX: the previous fallback `cursor = connect().cursor()` created a
    connection that was never closed, leaking it on every no-cursor call.
    The fallback connection is now managed by a `with` block.
    """
    logging.debug(f"Dropping table:{self.fqn}")
    if cursor is not None:
        cursor.execute(f"drop table {self.fqn}")
        return
    with connect() as connection:
        connection.cursor().execute(f"drop table {self.fqn}")

def single_column_update(
self, cursor: SnowflakeCursor, target_column: Column, new_column: Column
Expand All @@ -351,7 +369,55 @@ def single_column_update(
f"Updating the value of {target_column.name} with {new_column.name} in the table {self.name}"
)
cursor.execute(
f"UPDATE {self.schema_}.{self.name} SET {target_column.name} = {new_column.name};"
f"UPDATE {self.fqn} SET {target_column.name} = {new_column.name};"
)

def current_tags(self) -> dict[str, dict[str, str]]:
    """Fetch the column tags currently applied to this table in Snowflake.

    Returns a mapping of lower-cased column name -> {lower-cased tag name:
    tag value}, built from information_schema.tag_references_all_columns.
    Columns with no tags are absent from the result.
    """
    result: dict[str, dict[str, str]] = defaultdict(dict)
    with connect() as connection:
        cursor = connection.cursor()
        cursor.execute(f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value
from table(information_schema.tag_references_all_columns('{self.fqn}', 'table'))""")
        rows = cursor.fetchall()
    for column_name, tag_name, tag_value in rows:
        result[column_name][tag_name] = tag_value
    return result

def sync_tags(self, cursor: SnowflakeCursor) -> None:
    """Reconcile column tags in Snowflake with the tags declared in
    ``self.table_structure``.

    Tags present in Snowflake but not declared are unset; declared tags not
    present (or present with a different value) are set. Comparison keys are
    ``column.tag.value`` casefolded, so a value change appears as one unset
    plus one set.

    BUG FIX: the previous code called ``self._unset_tag(cursor,
    *existing_tags[tag])``, unpacking a ``(column, tag_name, tag_value)``
    3-tuple into ``_unset_tag``, which only accepts ``(cursor, column,
    tag)`` — a guaranteed TypeError whenever a stale tag had to be removed.
    Only the column and tag name are passed now.
    """
    current = self.current_tags()

    existing_tags: dict[str, tuple[str, str, str]] = {}
    for column, column_tags in current.items():
        for tag_name, tag_value in column_tags.items():
            key = f"{column}.{tag_name}.{tag_value}".casefold()
            existing_tags[key] = (column, tag_name, tag_value)

    desired_tags: dict[str, tuple[str, str, str]] = {}
    for column, spec in self.table_structure.columns.items():
        for tag_name, tag_value in spec.tags.items():
            key = f"{column}.{tag_name}.{tag_value}".casefold()
            desired_tags[key] = (column, tag_name, tag_value)

    # Remove tags that are applied but no longer declared.
    for key, (column, tag_name, _tag_value) in existing_tags.items():
        if key not in desired_tags:
            self._unset_tag(cursor, column, tag_name)

    # Apply declared tags that are missing.
    for key, triple in desired_tags.items():
        if key not in existing_tags:
            self._set_tag(cursor, *triple)

def _set_tag(
    self, cursor: SnowflakeCursor, column: str, tag_name: str, tag_value: str
) -> None:
    """Set one governance tag value on a column of this table."""
    tag_fqn = governance_settings.fqn(tag_name)
    statement = f"ALTER TABLE {self.fqn} MODIFY COLUMN {column} SET TAG {tag_fqn} = '{tag_value}'"
    cursor.execute(statement)

def _unset_tag(self, cursor: SnowflakeCursor, column: str, tag: str):
    """Remove one governance tag from a column of this table."""
    statement = f"ALTER TABLE {self.fqn} MODIFY COLUMN {column} UNSET TAG {governance_settings.fqn(tag)}"
    cursor.execute(statement)


Expand Down
11 changes: 11 additions & 0 deletions snowflake_utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ def creds(self) -> dict[str, str | None]:
"role": self.role,
"warehouse": self.warehouse,
}


class GovernanceSettings(BaseSettings):
    """Location of the database/schema where governance tag objects live."""

    # Defaults point at governance.public; override via settings/env as usual.
    governance_database: str = "governance"
    governance_schema: str = "public"

    def fqn(self, object_name: str) -> str:
        """Return the fully qualified name of a governance object."""
        prefix = f"{self.governance_database}.{self.governance_schema}"
        return f"{prefix}.{object_name}"


# Shared module-level instance used by tag set/unset statements.
governance_settings = GovernanceSettings()
Loading

0 comments on commit 71e6dc6

Please sign in to comment.