diff --git a/snowflake_utils/models.py b/snowflake_utils/models.py index eabacb3..a2b8010 100644 --- a/snowflake_utils/models.py +++ b/snowflake_utils/models.py @@ -196,7 +196,6 @@ def bulk_insert( def _copy( self, query: str, - query_args: dict, path: str, file_format: InlineFileFormat | FileFormat, storage_integration: str | None = None, @@ -212,8 +211,7 @@ def _copy( if sync_tags and self.table_structure: self.sync_tags(cursor) logging.info(f"Starting copy into `{self.fqn}` from path '{path}'") - query_args = query_args | {"file_format": file_format} - return execute(query.format(**query_args)) + return execute(query.format(file_format=file_format)) def copy_into( self, @@ -235,7 +233,6 @@ def copy_into( MATCH_BY_COLUMN_NAME={match_by_column_name.value} {self._include_metadata()} """, - {}, path, file_format, storage_integration, @@ -274,38 +271,23 @@ def exists(self, cursor: SnowflakeCursor) -> bool: ).fetchall() ) - def merge( + def _merge( self, - path: str, - file_format: InlineFileFormat | FileFormat, + copy_callable: callable, primary_keys: list[str] = ["id"], replication_keys: list[str] | None = None, - storage_integration: str | None = None, - match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE, qualify: bool = False, - ) -> None: + ): with connect() as connection: cursor = connection.cursor() if not self.exists(cursor): - self.copy_into( - path=path, - storage_integration=storage_integration, - file_format=file_format, - match_by_column_name=match_by_column_name, - sync_tags=True, - ) + copy_callable(self, sync_tags=True) if qualify: self.qualify(cursor, primary_keys, replication_keys) return None temp_table = self.model_copy(update={"name": f"{self.name}_temp"}) - temp_table.copy_into( - path=path, - storage_integration=storage_integration, - file_format=file_format, - match_by_column_name=match_by_column_name, - full_refresh=True, - ) + copy_callable(temp_table, sync_tags=False) if qualify: with connect() as connection: cursor = connection.cursor() @@ -330,6 +312,27 @@ def merge( self.sync_tags(cursor) temp_table.drop(cursor) + def merge( + self, + path: str, + file_format: InlineFileFormat | FileFormat, + primary_keys: list[str] = ["id"], + replication_keys: list[str] | None = None, + storage_integration: str | None = None, + match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE, + qualify: bool = False, + ) -> None: + def copy_callable(table: Table, sync_tags: bool) -> None: + return table.copy_into( + path=path, + storage_integration=storage_integration, + file_format=file_format, + match_by_column_name=match_by_column_name, + sync_tags=sync_tags, + ) + + return self._merge(copy_callable, primary_keys, replication_keys, qualify) + def qualify( self, cursor: SnowflakeCursor, @@ -455,20 +458,39 @@ def copy_custom( full_refresh: bool = False, sync_tags: bool = False, ) -> None: - return self._copy( + query = ( f""" COPY INTO {self.fqn} ({", ".join(column_definitions.keys())}) FROM @{self.temporary_stage}/ FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}') {self._include_metadata()} """, - {}, - path, - file_format, - storage_integration, - full_refresh, - sync_tags, ) + return self._copy( + query, path, file_format, storage_integration, full_refresh, sync_tags + ) + + def merge_custom( + self, + column_definitions: dict[str, str], + path: str, + file_format: InlineFileFormat | FileFormat, + primary_keys: list[str] = ["id"], + replication_keys: list[str] | None = None, + storage_integration: str | None = None, + qualify: bool = False, + ) -> None: + def copy_callable(table: Table, sync_tags: bool) -> None: + return table.copy_custom( + column_definitions, + path=path, + storage_integration=storage_integration, + file_format=file_format, + full_refresh=True, + sync_tags=sync_tags, + ) + + return self._merge(copy_callable, primary_keys, replication_keys, qualify) def setup_connection( self, path: str, storage_integration: str, cursor: SnowflakeCursor