Skip to content

Commit

Permalink
feat: add enum, refactor models file
Browse files Browse the repository at this point in the history
  • Loading branch information
pquadri committed Sep 5, 2024
1 parent 2180416 commit f306692
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 146 deletions.
11 changes: 6 additions & 5 deletions snowflake_utils/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import os

import typer
from typing_extensions import Annotated

from .models import FileFormat, InlineFileFormat, Table, Schema, Column
from .queries import connect
import logging
import os
from ..snowflake_utils.settings import SnowflakeSettings
from .models import Column, FileFormat, InlineFileFormat, Schema, Table

app = typer.Typer()

Expand Down Expand Up @@ -42,7 +43,7 @@ def mass_single_column_update(
new_column = Column(name=new_column, data_type=data_type)
log_level = os.getenv("LOG_LEVEL", "INFO")
logging.getLogger("snowflake-utils").setLevel(log_level)
with connect() as conn, conn.cursor() as cursor:
with SnowflakeSettings.connect() as conn, conn.cursor() as cursor:
tables = db_schema.get_tables(cursor=cursor)
for table in tables:
columns = table.get_columns(cursor=cursor)
Expand Down
17 changes: 17 additions & 0 deletions snowflake_utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .column import Column
from .enums import MatchByColumnName, TagLevel
from .file_format import FileFormat, InlineFileFormat
from .schema import Schema
from .table import Table
from .table_structure import TableStructure

__all__ = [
"Column",
"MatchByColumnName",
"TagLevel",
"Schema",
"Table",
"TableStructure",
"FileFormat",
"InlineFileFormat",
]
43 changes: 43 additions & 0 deletions snowflake_utils/models/column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from datetime import date, datetime

from pydantic import BaseModel, Field


class Column(BaseModel):
    """A single table column: its name, its Snowflake data type, and its tags."""

    # Column name as it appears in the table.
    name: str
    # Snowflake data type string (e.g. "VARIANT" — see _possibly_cast below).
    data_type: str
    # Tag name -> tag value pairs attached to this column.
    tags: dict[str, str] = Field(default_factory=dict)


def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str:
if old_column_type == "VARIANT" and new_column_type != "VARIANT":
return f"PARSE_JSON({s})"
return s


def _matched(columns: list[Column], old_columns: dict[str, str]):
    """Build the comma-joined assignment list for a MERGE "matched" branch:
    one ``dest."col" = <tmp expression>`` per column, casting VARIANT sources
    via _possibly_cast when needed."""

    def assignment(column: Column) -> str:
        source = _possibly_cast(
            f'tmp."{column.name}"', old_columns.get(column.name), column.data_type
        )
        return f'dest."{column.name}" = {source}'

    return ",".join(assignment(column) for column in columns)


def _inserts(columns: list[Column], old_columns: dict[str, str]) -> str:
    """Build the comma-joined value expressions for a MERGE insert branch,
    casting VARIANT sources via _possibly_cast when needed."""
    expressions: list[str] = []
    for column in columns:
        expressions.append(
            _possibly_cast(
                f'tmp."{column.name}"', old_columns.get(column.name), column.data_type
            )
        )
    return ",".join(expressions)


def _type_cast(s: any) -> any:
if isinstance(s, (int, float)):
return str(s)
elif isinstance(s, str):
return f"'{s}'"
elif isinstance(s, (datetime, date)):
return f"'{s.isoformat()}'"
else:
return f"'{s}'"
12 changes: 12 additions & 0 deletions snowflake_utils/models/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from enum import Enum


class MatchByColumnName(Enum):
    """Column-name matching modes (presumably mirroring Snowflake's
    MATCH_BY_COLUMN_NAME copy option — confirm against usage)."""

    CASE_SENSITIVE = "CASE_SENSITIVE"
    CASE_INSENSITIVE = "CASE_INSENSITIVE"
    NONE = "NONE"


class TagLevel(Enum):
    """Granularity at which a tag is applied; the lowercase values are
    interpolated into tag_references_all_columns level filters."""

    COLUMN = "column"
    TABLE = "table"
31 changes: 31 additions & 0 deletions snowflake_utils/models/file_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Self

from pydantic import BaseModel


class InlineFileFormat(BaseModel):
    """A file format supplied inline as its raw definition string, rather
    than referenced by name (see FileFormat)."""

    # Raw file-format definition text, used verbatim.
    definition: str


class FileFormat(BaseModel):
database: str | None = None
schema_: str | None = None
name: str

def __str__(self) -> str:
return ".".join(
s for s in [self.database, self.schema_, self.name] if s is not None
)

@classmethod
def from_string(cls, s: str) -> Self:
s = s.split(".")
match s:
case [database, schema, name]:
return cls(database=database, schema_=schema, name=name)
case [schema, name]:
return cls(schema_=schema, name=name)
case [name]:
return cls(name=name)
case _:
raise ValueError("Cannot parse file format")
26 changes: 26 additions & 0 deletions snowflake_utils/models/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pydantic import BaseModel
from snowflake.connector.cursor import SnowflakeCursor

from .table import Table


class Schema(BaseModel):
    """A database schema, optionally qualified by its database."""

    # Schema name.
    name: str
    # Parent database; when None the schema is referenced unqualified.
    database: str | None = None

    @property
    def fully_qualified_name(self) -> str:
        """"database.name" when a database is set, otherwise just the name."""
        if self.database:
            return f"{self.database}.{self.name}"
        else:
            return self.name

    def get_tables(self, cursor: SnowflakeCursor) -> list[Table]:
        """List the tables in this schema as Table models.

        Runs SHOW TABLES, then reads name/database_name/schema_name from the
        result scan of that query. Rows unpack as (name, database, schema, *_)
        to match the SELECT's column order.
        """
        cursor.execute(f"show tables in schema {self.fully_qualified_name};")
        data = cursor.execute(
            'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));'
        ).fetchall()
        return [
            Table(name=name, schema_=schema, database=database)
            for (name, database, schema, *_) in data
        ]
143 changes: 11 additions & 132 deletions snowflake_utils/models.py → snowflake_utils/models/table.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,16 @@
import logging
from collections import defaultdict
from datetime import date, datetime
from enum import Enum
from functools import partial

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from snowflake.connector.cursor import SnowflakeCursor
from typing_extensions import Self

from .queries import connect, execute_statement
from .settings import governance_settings


class MatchByColumnName(Enum):
CASE_SENSITIVE = "CASE_SENSITIVE"
CASE_INSENSITIVE = "CASE_INSENSITIVE"
NONE = "NONE"


class InlineFileFormat(BaseModel):
definition: str


class FileFormat(BaseModel):
database: str | None = None
schema_: str | None = None
name: str

def __str__(self) -> str:
return ".".join(
s for s in [self.database, self.schema_, self.name] if s is not None
)

@classmethod
def from_string(cls, s: str) -> Self:
s = s.split(".")
match s:
case [database, schema, name]:
return cls(database=database, schema_=schema, name=name)
case [schema, name]:
return cls(schema_=schema, name=name)
case [name]:
return cls(name=name)
case _:
raise ValueError("Cannot parse file format")


class Column(BaseModel):
name: str
data_type: str
tags: dict[str, str] = Field(default_factory=dict)


class TableStructure(BaseModel):
columns: dict = [str, Column]
tags: dict[str, str] = Field(default_factory=dict)

@property
def parsed_columns(self, replace_chars: bool = False) -> str:
if replace_chars:
return ", ".join(
f'"{str.upper(k).strip().replace("-","_")}" {v.data_type}'
for k, v in self.columns.items()
)
else:
return ", ".join(
f'"{str.upper(k).strip()}" {v.data_type}'
for k, v in self.columns.items()
)

def parse_from_json(self):
raise NotImplementedError("Not implemented yet")

@field_validator("columns")
@classmethod
def force_columns_to_casefold(cls, value) -> dict:
return {k.casefold(): v for k, v in value.items()}


class Schema(BaseModel):
name: str
database: str | None = None

@property
def fully_qualified_name(self):
if self.database:
return f"{self.database}.{self.name}"
else:
return self.name

def get_tables(self, cursor: SnowflakeCursor):
cursor.execute(f"show tables in schema {self.fully_qualified_name};")
data = cursor.execute(
'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));'
).fetchall()
return [
Table(name=name, schema_=schema, database=database)
for (name, database, schema, *_) in data
]
from ..queries import execute_statement
from ..settings import connect, governance_settings
from .column import Column, _inserts, _matched, _type_cast
from .enums import MatchByColumnName, TagLevel
from .file_format import FileFormat, InlineFileFormat
from .table_structure import TableStructure


class Table(BaseModel):
Expand Down Expand Up @@ -431,28 +344,28 @@ def single_column_update(
f"UPDATE {self.fqn} SET {target_column.name} = {new_column.name};"
)

def _current_tags(self, level: str) -> list[tuple[str, str, str]]:
def _current_tags(self, level: TagLevel) -> list[tuple[str, str, str]]:
with connect() as connection:
cursor = connection.cursor()
cursor.execute(
f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value
from table(information_schema.tag_references_all_columns('{self.fqn}', 'table'))
where lower(level) = '{level}'
where lower(level) = '{level.value}'
"""
)
return cursor.fetchall()

def current_column_tags(self) -> dict[str, dict[str, str]]:
tags = defaultdict(dict)

for column_name, tag_name, tag_value in self._current_tags("column"):
for column_name, tag_name, tag_value in self._current_tags(TagLevel.COLUMN):
tags[column_name][tag_name] = tag_value
return tags

def current_table_tags(self) -> dict[str, str]:
return {
tag_name.casefold(): tag_value
for _, tag_name, tag_value in self._current_tags("table")
for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE)
}

def sync_tags_table(self, cursor: SnowflakeCursor) -> None:
Expand Down Expand Up @@ -586,37 +499,3 @@ def setup_connection(
)

return _execute_statement


def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str:
if old_column_type == "VARIANT" and new_column_type != "VARIANT":
return f"PARSE_JSON({s})"
return s


def _matched(columns: list[Column], old_columns: dict[str, str]):
def tmp(x: str) -> str:
return f'tmp."{x}"'

return ",".join(
f'dest."{c.name}" = {_possibly_cast(tmp(c.name), old_columns.get(c.name), c.data_type)}'
for c in columns
)


def _inserts(columns: list[Column], old_columns: dict[str, str]) -> str:
return ",".join(
_possibly_cast(f'tmp."{c.name}"', old_columns.get(c.name), c.data_type)
for c in columns
)


def _type_cast(s: any) -> any:
if isinstance(s, (int, float)):
return str(s)
elif isinstance(s, str):
return f"'{s}'"
elif isinstance(s, (datetime, date)):
return f"'{s.isoformat()}'"
else:
return f"'{s}'"
29 changes: 29 additions & 0 deletions snowflake_utils/models/table_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pydantic import BaseModel, Field, field_validator

from .column import Column


class TableStructure(BaseModel):
    """Declarative structure of a table: its columns and table-level tags.

    Fixes two defects in the original:
    - ``columns: dict = [str, Column]`` set the field's *default value* to the
      two-element list ``[str, Column]`` instead of annotating the type; it is
      now properly typed ``dict[str, Column]`` with an empty-dict default.
    - ``parsed_columns`` was a ``@property`` declaring a ``replace_chars``
      parameter that a property can never receive, so the ``True`` branch was
      unreachable; the property keeps its historical (no-replacement) behavior
      and the replacement variant is exposed as a separate method.
    """

    # Column name -> Column definition.
    columns: dict[str, Column] = Field(default_factory=dict)
    # Table-level tag name -> tag value pairs.
    tags: dict[str, str] = Field(default_factory=dict)

    @property
    def parsed_columns(self) -> str:
        """Comma-separated ``"NAME" TYPE`` column definitions."""
        return self._render_columns(replace_chars=False)

    def parsed_columns_replaced(self) -> str:
        """Like ``parsed_columns`` but with "-" replaced by "_" in names."""
        return self._render_columns(replace_chars=True)

    def _render_columns(self, replace_chars: bool) -> str:
        # Upper-case and strip names; optionally replace "-" with "_".
        def clean(name: str) -> str:
            cleaned = name.upper().strip()
            if replace_chars:
                cleaned = cleaned.replace("-", "_")
            return cleaned

        return ", ".join(
            f'"{clean(k)}" {v.data_type}' for k, v in self.columns.items()
        )

    def parse_from_json(self):
        raise NotImplementedError("Not implemented yet")

    @field_validator("columns")
    @classmethod
    def force_columns_to_casefold(cls, value: dict) -> dict:
        # Normalize keys so column lookups are case-insensitive.
        return {k.casefold(): v for k, v in value.items()}
9 changes: 1 addition & 8 deletions snowflake_utils/queries.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import logging
from typing import no_type_check

from snowflake import connector

from .settings import SnowflakeSettings
import logging


def connect() -> connector.SnowflakeConnection:
settings = SnowflakeSettings()
return connector.connect(**settings.creds())


@no_type_check
def execute_statement(
Expand Down
Loading

0 comments on commit f306692

Please sign in to comment.