diff --git a/snowflake_utils/models.py b/snowflake_utils/models.py index d991d1e..14c975c 100644 --- a/snowflake_utils/models.py +++ b/snowflake_utils/models.py @@ -230,23 +230,44 @@ def copy_into( full_refresh: bool = False, target_columns: list[str] | None = None, sync_tags: bool = False, + primary_keys: list[str] = ["id"], + replication_keys: list[str] | None = None, + qualify: bool = False, ) -> None: col_str = f"({', '.join(target_columns)})" if target_columns else "" - return self._copy( - f""" + copy_query = f""" COPY INTO {self.fqn} {col_str} FROM {path} {f"STORAGE_INTEGRATION = {storage_integration}" if storage_integration else ''} FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}') MATCH_BY_COLUMN_NAME={match_by_column_name.value} {self._include_metadata()} - """, - path, - file_format, - storage_integration, - full_refresh, - sync_tags, - ) + """ + if qualify: + self._copy( + copy_query, + path, + file_format, + storage_integration, + full_refresh, + sync_tags, + ) + with connect() as connection: + cursor = connection.cursor() + self.qualify( + cursor=cursor, + primary_keys=primary_keys, + replication_keys=replication_keys, + ) + else: + return self._copy( + copy_query, + path, + file_format, + storage_integration, + full_refresh, + sync_tags, + ) def create_table(self, full_refresh: bool, execute_statement: callable) -> None: execute_statement(self.get_create_table_statement(full_refresh))