feat: add tags at table level (#12)
* feat: add tags at column level

* chore: remove set

* feat: add enum, refactor models file

* fix: use retrocompatible typing_extensions

* feat: add settings.yml
pquadri authored Sep 5, 2024
1 parent a841431 commit 1f1be19
Showing 20 changed files with 9,228 additions and 4,544 deletions.
12 changes: 12 additions & 0 deletions .github/settings.yml
@@ -0,0 +1,12 @@
branches:
  - name: main
    protection:
      required_status_checks:
        strict: true
        contexts: []
      required_pull_request_reviews:
        dismiss_stale_reviews: true
        require_code_owner_reviews: true
        required_approving_review_count: 1
      restrictions: null
      enforce_admins: true
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -5,7 +5,7 @@ on:
tags-ignore:
  - "*.*.*"

name: tests
name: ci

concurrency:
  group: tests
1 change: 1 addition & 0 deletions .gitignore
@@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.envrc
245 changes: 215 additions & 30 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
snowflake-connector-python = "^3.7.1"
snowflake-connector-python = {extras = ["secure-local-storage"], version = "^3.12.1"}
pydantic-settings = "^2.2.1"
typer = "^0.12.0"

11 changes: 6 additions & 5 deletions snowflake_utils/__main__.py
@@ -1,10 +1,11 @@
import logging
import os

import typer
from typing_extensions import Annotated

from .models import FileFormat, InlineFileFormat, Table, Schema, Column
from .queries import connect
import logging
import os
from ..snowflake_utils.settings import SnowflakeSettings
from .models import Column, FileFormat, InlineFileFormat, Schema, Table

app = typer.Typer()

@@ -42,7 +43,7 @@ def mass_single_column_update(
    new_column = Column(name=new_column, data_type=data_type)
    log_level = os.getenv("LOG_LEVEL", "INFO")
    logging.getLogger("snowflake-utils").setLevel(log_level)
    with connect() as conn, conn.cursor() as cursor:
    with SnowflakeSettings.connect() as conn, conn.cursor() as cursor:
        tables = db_schema.get_tables(cursor=cursor)
        for table in tables:
            columns = table.get_columns(cursor=cursor)
17 changes: 17 additions & 0 deletions snowflake_utils/models/__init__.py
@@ -0,0 +1,17 @@
from .column import Column
from .enums import MatchByColumnName, TagLevel
from .file_format import FileFormat, InlineFileFormat
from .schema import Schema
from .table import Table
from .table_structure import TableStructure

__all__ = [
    "Column",
    "MatchByColumnName",
    "TagLevel",
    "Schema",
    "Table",
    "TableStructure",
    "FileFormat",
    "InlineFileFormat",
]
43 changes: 43 additions & 0 deletions snowflake_utils/models/column.py
@@ -0,0 +1,43 @@
from datetime import date, datetime

from pydantic import BaseModel, Field


class Column(BaseModel):
    name: str
    data_type: str
    tags: dict[str, str] = Field(default_factory=dict)


def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str:
    if old_column_type == "VARIANT" and new_column_type != "VARIANT":
        return f"PARSE_JSON({s})"
    return s


def _matched(columns: list[Column], old_columns: dict[str, str]):
    def tmp(x: str) -> str:
        return f'tmp."{x}"'

    return ",".join(
        f'dest."{c.name}" = {_possibly_cast(tmp(c.name), old_columns.get(c.name), c.data_type)}'
        for c in columns
    )


def _inserts(columns: list[Column], old_columns: dict[str, str]) -> str:
    return ",".join(
        _possibly_cast(f'tmp."{c.name}"', old_columns.get(c.name), c.data_type)
        for c in columns
    )


def _type_cast(s: any) -> any:
    if isinstance(s, (int, float)):
        return str(s)
    elif isinstance(s, str):
        return f"'{s}'"
    elif isinstance(s, (datetime, date)):
        return f"'{s.isoformat()}'"
    else:
        return f"'{s}'"
12 changes: 12 additions & 0 deletions snowflake_utils/models/enums.py
@@ -0,0 +1,12 @@
from enum import Enum


class MatchByColumnName(Enum):
    CASE_SENSITIVE = "CASE_SENSITIVE"
    CASE_INSENSITIVE = "CASE_INSENSITIVE"
    NONE = "NONE"


class TagLevel(Enum):
    COLUMN = "column"
    TABLE = "table"
30 changes: 30 additions & 0 deletions snowflake_utils/models/file_format.py
@@ -0,0 +1,30 @@
from pydantic import BaseModel
from typing_extensions import Self


class InlineFileFormat(BaseModel):
    definition: str


class FileFormat(BaseModel):
    database: str | None = None
    schema_: str | None = None
    name: str

    def __str__(self) -> str:
        return ".".join(
            s for s in [self.database, self.schema_, self.name] if s is not None
        )

    @classmethod
    def from_string(cls, s: str) -> Self:
        s = s.split(".")
        match s:
            case [database, schema, name]:
                return cls(database=database, schema_=schema, name=name)
            case [schema, name]:
                return cls(schema_=schema, name=name)
            case [name]:
                return cls(name=name)
            case _:
                raise ValueError("Cannot parse file format")
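
A quick illustration of the parsing behavior added here; the format names are hypothetical.

# Illustrative only: the format names are made up.
from snowflake_utils.models.file_format import FileFormat

fmt = FileFormat.from_string("RAW.PUBLIC.MY_CSV_FORMAT")
print(fmt.database, fmt.schema_, fmt.name)  # RAW PUBLIC MY_CSV_FORMAT
print(str(fmt))  # RAW.PUBLIC.MY_CSV_FORMAT
print(str(FileFormat.from_string("MY_CSV_FORMAT")))  # MY_CSV_FORMAT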
26 changes: 26 additions & 0 deletions snowflake_utils/models/schema.py
@@ -0,0 +1,26 @@
from pydantic import BaseModel
from snowflake.connector.cursor import SnowflakeCursor

from .table import Table


class Schema(BaseModel):
    name: str
    database: str | None = None

    @property
    def fully_qualified_name(self):
        if self.database:
            return f"{self.database}.{self.name}"
        else:
            return self.name

    def get_tables(self, cursor: SnowflakeCursor):
        cursor.execute(f"show tables in schema {self.fully_qualified_name};")
        data = cursor.execute(
            'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));'
        ).fetchall()
        return [
            Table(name=name, schema_=schema, database=database)
            for (name, database, schema, *_) in data
        ]
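
A minimal usage sketch for Schema.get_tables, assuming an already configured Snowflake connection; the connection parameters below are placeholders.

# Illustrative only: connection parameters are placeholders.
import snowflake.connector

from snowflake_utils.models import Schema

conn = snowflake.connector.connect(
    account="my_account", user="my_user", password="***", database="RAW"
)
with conn.cursor() as cursor:
    for table in Schema(name="PUBLIC", database="RAW").get_tables(cursor=cursor):
        print(table.name)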