Skip to content

Commit

Permalink
feat: add custom merge
Browse files Browse the repository at this point in the history
  • Loading branch information
pquadri committed Jul 30, 2024
1 parent 21a68a1 commit ca6f4ce
Showing 1 changed file with 53 additions and 31 deletions.
84 changes: 53 additions & 31 deletions snowflake_utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def bulk_insert(
def _copy(
self,
query: str,
query_args: dict,
path: str,
file_format: InlineFileFormat | FileFormat,
storage_integration: str | None = None,
Expand All @@ -212,8 +211,7 @@ def _copy(
if sync_tags and self.table_structure:
self.sync_tags(cursor)
logging.info(f"Starting copy into `{self.fqn}` from path '{path}'")
query_args = query_args | {"file_format": file_format}
return execute(query.format(**query_args))
return execute(query.format(file_format=file_format))

def copy_into(
self,
Expand All @@ -235,7 +233,6 @@ def copy_into(
MATCH_BY_COLUMN_NAME={match_by_column_name.value}
{self._include_metadata()}
""",
{},
path,
file_format,
storage_integration,
Expand Down Expand Up @@ -274,38 +271,23 @@ def exists(self, cursor: SnowflakeCursor) -> bool:
).fetchall()
)

def merge(
def _merge(
self,
path: str,
file_format: InlineFileFormat | FileFormat,
copy_callable: callable,
primary_keys: list[str] = ["id"],
replication_keys: list[str] | None = None,
storage_integration: str | None = None,
match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE,
qualify: bool = False,
) -> None:
):
with connect() as connection:
cursor = connection.cursor()
if not self.exists(cursor):
self.copy_into(
path=path,
storage_integration=storage_integration,
file_format=file_format,
match_by_column_name=match_by_column_name,
sync_tags=True,
)
copy_callable(self, sync_tags=True)
if qualify:
self.qualify(cursor, primary_keys, replication_keys)
return None

temp_table = self.model_copy(update={"name": f"{self.name}_temp"})
temp_table.copy_into(
path=path,
storage_integration=storage_integration,
file_format=file_format,
match_by_column_name=match_by_column_name,
full_refresh=True,
)
copy_callable(temp_table, sync_tags=False)
if qualify:
with connect() as connection:
cursor = connection.cursor()
Expand All @@ -330,6 +312,27 @@ def merge(
self.sync_tags(cursor)
temp_table.drop(cursor)

def merge(
self,
path: str,
file_format: InlineFileFormat | FileFormat,
primary_keys: list[str] = ["id"],
replication_keys: list[str] | None = None,
storage_integration: str | None = None,
match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE,
qualify: bool = False,
) -> None:
def copy_callable(table: Table, sync_tags: bool) -> None:
return table.copy_into(
path=path,
storage_integration=storage_integration,
file_format=file_format,
match_by_column_name=match_by_column_name,
sync_tags=sync_tags,
)

return self._merge(copy_callable, primary_keys, replication_keys, qualify)

def qualify(
self,
cursor: SnowflakeCursor,
Expand Down Expand Up @@ -455,20 +458,39 @@ def copy_custom(
full_refresh: bool = False,
sync_tags: bool = False,
) -> None:
return self._copy(
query = (
f"""
COPY INTO {self.fqn} ({", ".join(column_definitions.keys())})
FROM @{self.temporary_stage}/
FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}')
{self._include_metadata()}
""",
{},
path,
file_format,
storage_integration,
full_refresh,
sync_tags,
)
return self._copy(
query, path, file_format, storage_integration, full_refresh, sync_tags
)

def merge_custom(
self,
column_definitions: dict[str, str],
path: str,
file_format: InlineFileFormat | FileFormat,
primary_keys: list[str] = ["id"],
replication_keys: list[str] | None = None,
storage_integration: str | None = None,
qualify: bool = False,
) -> None:
def copy_callable(table: Table, sync_tags: bool) -> None:
return table.copy_custom(
column_definitions,
path=path,
storage_integration=storage_integration,
file_format=file_format,
full_refresh=True,
sync_tags=sync_tags,
)

return self._merge(copy_callable, primary_keys, replication_keys, qualify)

def setup_connection(
self, path: str, storage_integration: str, cursor: SnowflakeCursor
Expand Down

0 comments on commit ca6f4ce

Please sign in to comment.