[AP-1254] Refactored partialsync to use merge (#1010)
* refactored partialsync to use merge

* fixed unify

* fix unity

* fixed pep8
amofakhar authored Sep 7, 2022
1 parent fd7f9bb commit 355ed0c
Showing 22 changed files with 1,198 additions and 206 deletions.
2 changes: 1 addition & 1 deletion docs/user_guide/metadata_columns.rst
@@ -43,7 +43,7 @@ at the end of the table:
target database as well. Please also note that only the :ref:`log_based` replication method
detects delete row events.

To turn off **Hard Delete** mode add ``hard_delete: False`` to the target :ref:`targets_list`
To turn off **Hard Delete** mode add ``hard_delete: false`` to the target :ref:`targets_list`
YAML config file. In this case, when a deleted row is captured in the source, the
``_SDC_DELETED_AT`` column will only be flagged and the row will not be deleted in the target.
Please also note that only the :ref:`log_based` replication method detects delete row events.
31 changes: 14 additions & 17 deletions pipelinewise/fastsync/commons/tap_mysql.py
@@ -8,7 +8,7 @@
import pymysql.cursors

from argparse import Namespace
from typing import Tuple, Dict, Callable, Union
from typing import Tuple, Dict, Callable
from pymysql import InterfaceError, OperationalError, Connection

from ...utils import safe_column_name
@@ -280,11 +280,11 @@ def fetch_current_incremental_key_pos(self, table, replication_key):
Get the actual incremental key position in the table
"""
result = self.query(
'SELECT MAX({}) AS key_value FROM {}'.format(replication_key, table)
f'SELECT MAX({replication_key}) AS key_value FROM {table}'
)
if not result:
raise Exception(
'Cannot get replication key value for table: {}'.format(table)
f'Cannot get replication key value for table: {table}'
)

mysql_key_value = result[0].get('key_value')
@@ -311,9 +311,8 @@ def get_primary_keys(self, table_name):
Get the primary key of a table
"""
table_dict = utils.tablename_to_dict(table_name)
sql = "SHOW KEYS FROM `{}`.`{}` WHERE Key_name = 'PRIMARY'".format(
table_dict['schema_name'], table_dict['table_name']
)
sql = f"SHOW KEYS FROM `{table_dict['schema_name']}`.`{table_dict['table_name']}` WHERE Key_name = 'PRIMARY'"

pk_specs = self.query(sql)
if len(pk_specs) > 0:
return [
@@ -417,7 +416,7 @@ def copy_table(
split_file_chunk_size_mb=1000,
split_file_max_chunks=20,
compress=True,
where_clause_setting=None
where_clause_sql='',
):
"""
Export data from table to a zipped csv
@@ -437,16 +436,14 @@ def copy_table(
raise Exception('{} table not found.'.format(table_name))

table_dict = utils.tablename_to_dict(table_name)
where_clause_sql = ''
if where_clause_setting:
where_clause_sql = f' WHERE {where_clause_setting["column"]} >= \'{where_clause_setting["start_value"]}\''
if where_clause_setting['end_value']:
where_clause_sql += f' AND {where_clause_setting["column"]} <= \'{where_clause_setting["end_value"]}\''

column_safe_sql_values = column_safe_sql_values + [
"CONVERT_TZ( NOW(),@@session.time_zone,'+00:00') AS `_SDC_EXTRACTED_AT`",
"CONVERT_TZ( NOW(),@@session.time_zone,'+00:00') AS `_SDC_BATCHED_AT`",
'null AS `_SDC_DELETED_AT`'
]

sql = """SELECT {}
,CONVERT_TZ( NOW(),@@session.time_zone,'+00:00') AS _SDC_EXTRACTED_AT
,CONVERT_TZ( NOW(),@@session.time_zone,'+00:00') AS _SDC_BATCHED_AT
,null AS _SDC_DELETED_AT
FROM `{}`.`{}` {}
""".format(
','.join(column_safe_sql_values),
@@ -500,7 +497,7 @@ def copy_table(
)

def export_source_table_data(
self, args: Namespace, tap_id: str, where_clause_setting: Union[Dict, None] = None) -> list:
self, args: Namespace, tap_id: str, where_clause_sql: str = '') -> list:
"""Export source table data"""
filename = utils.gen_export_filename(tap_id=tap_id, table=args.table, sync_type='partialsync')
filepath = os.path.join(args.temp_dir, filename)
@@ -513,7 +510,7 @@ def copy_table(
split_large_files=args.target.get('split_large_files'),
split_file_chunk_size_mb=args.target.get('split_file_chunk_size_mb'),
split_file_max_chunks=args.target.get('split_file_max_chunks'),
where_clause_setting=where_clause_setting
where_clause_sql=where_clause_sql,
)
file_parts = glob.glob(f'{filepath}*')
return file_parts
39 changes: 15 additions & 24 deletions pipelinewise/fastsync/commons/tap_postgres.py
@@ -337,7 +337,7 @@ def fetch_current_incremental_key_pos(self, table, replication_key):
)
if not result:
raise Exception(
'Cannot get replication key value for table: {}'.format(table)
f'Cannot get replication key value for table: {table}'
)

postgres_key_value = result[0].get('key_value')
@@ -477,7 +477,7 @@ def copy_table(
split_file_chunk_size_mb=1000,
split_file_max_chunks=20,
compress=True,
where_clause_setting=None
where_clause_sql='',
):
"""
Export data from table to a zipped csv
@@ -494,24 +494,20 @@

# If self.get_table_columns returns zero rows then the table does not exist
if len(column_safe_sql_values) == 0:
raise Exception('{} table not found.'.format(table_name))
raise Exception(f'{table_name} table not found.')

schema_name, table_name = table_name.split('.')

where_clause_sql = ''
if where_clause_setting:
where_clause_sql = f' WHERE {where_clause_setting["column"]} >= \'{where_clause_setting["start_value"]}\''
if where_clause_setting['end_value']:
where_clause_sql += f' AND {where_clause_setting["column"]} <= \'{where_clause_setting["end_value"]}\''

sql = """COPY (SELECT {}
,now() AT TIME ZONE 'UTC'
,now() AT TIME ZONE 'UTC'
,null
FROM {}."{}"{}) TO STDOUT with CSV DELIMITER ','
""".format(
','.join(column_safe_sql_values), schema_name, table_name, where_clause_sql
)
column_safe_sql_values = column_safe_sql_values + [
"now() AT TIME ZONE 'UTC' AS _SDC_EXTRACTED_AT",
"now() AT TIME ZONE 'UTC' AS _SDC_BATCHED_AT",
'null _SDC_DELETED_AT'
]

sql = f"""COPY (SELECT {','.join(column_safe_sql_values)}
FROM {schema_name}."{table_name}"{where_clause_sql}) TO STDOUT with CSV DELIMITER ','
"""

LOGGER.info('Exporting data: %s', sql)

gzip_splitter = split_gzip.open(
@@ -526,23 +522,18 @@
self.curr.copy_expert(sql, split_gzip_files, size=131072)

def export_source_table_data(
self, args: Namespace, tap_id: str) -> list:
self, args: Namespace, tap_id: str, where_clause_sql: str = '') -> list:
"""Exporting data from the source table"""
filename = utils.gen_export_filename(tap_id=tap_id, table=args.table, sync_type='partialsync')
filepath = os.path.join(args.temp_dir, filename)

where_clause_setting = {
'column': args.column,
'start_value': args.start_value,
'end_value': args.end_value
}
self.copy_table(
args.table,
filepath,
split_large_files=args.target.get('split_large_files'),
split_file_chunk_size_mb=args.target.get('split_file_chunk_size_mb'),
split_file_max_chunks=args.target.get('split_file_max_chunks'),
where_clause_setting=where_clause_setting
where_clause_sql=where_clause_sql
)
file_parts = glob.glob(f'{filepath}*')
return file_parts
29 changes: 29 additions & 0 deletions pipelinewise/fastsync/commons/target_snowflake.py
@@ -417,6 +417,28 @@ def obfuscate_columns(self, target_schema: str, table_name: str):

LOGGER.info('Obfuscation rules applied.')

def merge_tables(self, schema, source_table, target_table, columns, primary_keys):
"""Merge rows from the source (temp) table into the target table, matching on the primary keys"""
on_clause = ' AND '.join(
[f'"{source_table.upper()}".{p.upper()} = "{target_table.upper()}".{p.upper()}' for p in primary_keys]
)
update_clause = ', '.join(
[f'"{target_table.upper()}".{c.upper()} = "{source_table.upper()}".{c.upper()}' for c in columns]
)
columns_for_insert = ', '.join([f'{c.upper()}' for c in columns])
values = ', '.join([f'"{source_table.upper()}".{c.upper()}' for c in columns])

query = f'MERGE INTO {schema}."{target_table.upper()}" USING {schema}."{source_table.upper()}"' \
f' ON {on_clause}' \
f' WHEN MATCHED THEN UPDATE SET {update_clause}' \
f' WHEN NOT MATCHED THEN INSERT ({columns_for_insert})' \
f' VALUES ({values})'
self.query(query)
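A rough sketch of the statement merge_tables builds, assuming a hypothetical schema ANALYTICS, target table FRUITS with temp table FRUITS_TEMP, columns ID, NAME and _SDC_DELETED_AT, and primary key ID (wrapped over several lines here; the code sends a single line):

MERGE INTO ANALYTICS."FRUITS" USING ANALYTICS."FRUITS_TEMP"
ON "FRUITS_TEMP".ID = "FRUITS".ID
WHEN MATCHED THEN UPDATE SET "FRUITS".ID = "FRUITS_TEMP".ID, "FRUITS".NAME = "FRUITS_TEMP".NAME, "FRUITS"._SDC_DELETED_AT = "FRUITS_TEMP"._SDC_DELETED_AT
WHEN NOT MATCHED THEN INSERT (ID, NAME, _SDC_DELETED_AT)
VALUES ("FRUITS_TEMP".ID, "FRUITS_TEMP".NAME, "FRUITS_TEMP"._SDC_DELETED_AT)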

def partial_hard_delete(self, schema, table, where_clause_sql):
"""Hard delete rows in the partially synced range that are flagged as deleted"""
self.query(
f'DELETE FROM {schema}."{table.upper()}"{where_clause_sql} AND _SDC_DELETED_AT IS NOT NULL'
)
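Continuing the same hypothetical example, with a sync range on ID from 100 to 200 the hard-delete step would issue roughly:

DELETE FROM ANALYTICS."FRUITS" WHERE ID >= '100' AND ID <= '200' AND _SDC_DELETED_AT IS NOT NULL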

def swap_tables(self, schema, table_name) -> None:
"""
Swaps given target table with its temp version and drops the latter
@@ -440,6 +462,13 @@ def swap_tables(self, schema, table_name) -> None:
query_tag_props={'schema': schema, 'table': temp_table},
)

def add_columns(self, schema: str, table_name: str, adding_columns: dict) -> None:
"""Add the given columns (name to type mapping) to a table in the target schema"""
if adding_columns:
add_columns_list = [f'{column_name} {column_type}' for column_name, column_type in adding_columns.items()]
add_clause = ', '.join(add_columns_list)
query = f'ALTER TABLE {schema}."{table_name.upper()}" ADD {add_clause}'
self.query(query)
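For instance (hypothetical names and type again), calling add_columns with adding_columns={'COLOR': 'varchar'} against the FRUITS table would run approximately:

ALTER TABLE ANALYTICS."FRUITS" ADD COLOR varchar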

def __apply_transformations(
self, transformations: List[Dict], target_schema: str, table_name: str
) -> None:
52 changes: 42 additions & 10 deletions pipelinewise/fastsync/partialsync/mysql_to_snowflake.py
@@ -13,11 +13,13 @@
from pipelinewise.fastsync.partialsync import utils

from pipelinewise.fastsync.mysql_to_snowflake import REQUIRED_CONFIG_KEYS, tap_type_to_target_type
from pipelinewise.fastsync.partialsync.utils import load_into_snowflake, upload_to_s3, update_state_file
from pipelinewise.fastsync.partialsync.utils import (
upload_to_s3, update_state_file, diff_source_target_columns, load_into_snowflake)

LOGGER = Logger().get_logger(__name__)


# pylint: disable=too-many-locals
def partial_sync_table(args: Namespace) -> Union[bool, str]:
"""Partial sync table for MySQL to Snowflake"""
snowflake = FastSyncTargetSnowflake(args.target, args.transform)
@@ -26,22 +28,52 @@ def partial_sync_table(args: Namespace) -> Union[bool, str]:
try:
mysql = FastSyncTapMySql(args.tap, tap_type_to_target_type)

# Get bookmark - Binlog position or Incremental Key value
mysql.open_connections()
bookmark = common_utils.get_bookmark_for_table(args.table, args.properties, mysql)

where_clause_setting = {
'column': args.column,
'start_value': args.start_value,
'end_value': args.end_value
# Get column differences
target_schema = common_utils.get_target_schema(args.target, args.table)
table_dict = common_utils.tablename_to_dict(args.table)
target_table = table_dict.get('table_name')

target_sf = {
'sf_object': snowflake,
'schema': target_schema,
'table': target_table,
'temp': table_dict.get('temp_table_name')
}

file_parts = mysql.export_source_table_data(args, tap_id, where_clause_setting)
snowflake_types = mysql.map_column_types_to_target(args.table)
source_columns = snowflake_types.get('columns', [])
columns_diff = diff_source_target_columns(target_sf, source_columns=source_columns)

# Get bookmark - Binlog position or Incremental Key value
bookmark = common_utils.get_bookmark_for_table(args.table, args.properties, mysql)

where_clause_sql = f' WHERE {args.column} >= \'{args.start_value}\''
if args.end_value:
where_clause_sql += f' AND {args.column} <= \'{args.end_value}\''

# export data from source
file_parts = mysql.export_source_table_data(args, tap_id, where_clause_sql)

# mark partial data as deleted in the target
snowflake.query(f'UPDATE {target_schema}."{target_table.upper()}"'
f' SET _SDC_DELETED_AT = CURRENT_TIMESTAMP(){where_clause_sql} AND _SDC_DELETED_AT IS NULL')
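To make the new flow concrete (values here are illustrative, not from the commit): with args.column='ID', args.start_value='100' and args.end_value='200', where_clause_sql becomes " WHERE ID >= '100' AND ID <= '200'", so for a target schema ANALYTICS and target table FRUITS the soft-delete statement above expands to roughly:

UPDATE ANALYTICS."FRUITS" SET _SDC_DELETED_AT = CURRENT_TIMESTAMP() WHERE ID >= '100' AND ID <= '200' AND _SDC_DELETED_AT IS NULL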

# Creating temp table in Snowflake
primary_keys = snowflake_types.get('primary_key')
snowflake.create_schema(target_schema)
snowflake.create_table(
target_schema, args.table, source_columns, primary_keys, is_temporary=True
)

mysql.close_connections()

size_bytes = sum([os.path.getsize(file_part) for file_part in file_parts])
s3_keys, s3_key_pattern = upload_to_s3(snowflake, file_parts, args.temp_dir)
load_into_snowflake(snowflake, args, s3_keys, s3_key_pattern, size_bytes)
_, s3_key_pattern = upload_to_s3(snowflake, file_parts, args.temp_dir)

load_into_snowflake(target_sf, args, columns_diff, primary_keys, s3_key_pattern, size_bytes, where_clause_sql)

update_state_file(args, bookmark)

return True
46 changes: 42 additions & 4 deletions pipelinewise/fastsync/partialsync/postgres_to_snowflake.py
@@ -10,12 +10,13 @@
from pipelinewise.fastsync.postgres_to_snowflake import REQUIRED_CONFIG_KEYS, tap_type_to_target_type
from pipelinewise.fastsync.commons import utils as common_utils
from pipelinewise.fastsync.partialsync.utils import (
load_into_snowflake, upload_to_s3, update_state_file, parse_args_for_partial_sync)
upload_to_s3, update_state_file, parse_args_for_partial_sync, diff_source_target_columns, load_into_snowflake)
from pipelinewise.logger import Logger

LOGGER = Logger().get_logger(__name__)


# pylint: disable=too-many-locals
def partial_sync_table(args: Namespace) -> Union[bool, str]:
"""Partial sync table for Postgres to Snowflake"""
snowflake = FastSyncTargetSnowflake(args.target, args.transform)
@@ -26,13 +27,50 @@ def partial_sync_table(args: Namespace) -> Union[bool, str]:

# Get bookmark - Binlog position or Incremental Key value
postgres.open_connection()

# Get column differences
target_schema = common_utils.get_target_schema(args.target, args.table)
table_dict = common_utils.tablename_to_dict(args.table)
target_table = table_dict.get('table_name')

target_sf = {
'sf_object': snowflake,
'schema': target_schema,
'table': target_table,
'temp': table_dict.get('temp_table_name')
}

snowflake_types = postgres.map_column_types_to_target(args.table)
source_columns = snowflake_types.get('columns', [])
columns_diff = diff_source_target_columns(target_sf, source_columns=source_columns)

bookmark = common_utils.get_bookmark_for_table(args.table, args.properties, postgres, dbname=dbname)

file_parts = postgres.export_source_table_data(args, tap_id)
where_clause_sql = f' WHERE {args.column} >= \'{args.start_value}\''
if args.end_value:
where_clause_sql += f' AND {args.column} <= \'{args.end_value}\''

file_parts = postgres.export_source_table_data(args, tap_id, where_clause_sql)

# mark partial data as deleted in the target
snowflake.query(
f'UPDATE {target_schema}."{target_table.upper()}"'
f' SET _SDC_DELETED_AT = CURRENT_TIMESTAMP(){where_clause_sql} AND _SDC_DELETED_AT IS NULL')

# Creating temp table in Snowflake
primary_keys = snowflake_types.get('primary_key')
snowflake.create_schema(target_schema)
snowflake.create_table(
target_schema, args.table, source_columns, primary_keys, is_temporary=True
)

postgres.close_connection()

size_bytes = sum([os.path.getsize(file_part) for file_part in file_parts])
s3_keys, s3_key_pattern = upload_to_s3(snowflake, file_parts, args.temp_dir)
load_into_snowflake(snowflake, args, s3_keys, s3_key_pattern, size_bytes)
_, s3_key_pattern = upload_to_s3(snowflake, file_parts, args.temp_dir)

load_into_snowflake(target_sf, args, columns_diff, primary_keys, s3_key_pattern, size_bytes, where_clause_sql)

update_state_file(args, bookmark)

return True