diff --git a/pyproject.toml b/pyproject.toml
index ad42ff18..8af68c4c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,14 +79,22 @@ target-version = "py38"
 
 [tool.ruff.lint]
 select = [
-    "F",    # Pyflakes
-    "W",    # pycodestyle warnings
-    "E",    # pycodestyle errors
-    "I",    # isort
-    "N",    # pep8-naming
-    "D",    # pydocsyle
-    "ICN",  # flake8-import-conventions
-    "RUF",  # ruff
+    "F",    # Pyflakes
+    "W",    # pycodestyle warnings
+    "E",    # pycodestyle errors
+    "I",    # isort
+    "N",    # pep8-naming
+    "D",    # pydocstyle
+    "UP",   # pyupgrade
+    "ICN",  # flake8-import-conventions
+    "RET",  # flake8-return
+    "SIM",  # flake8-simplify
+    "TCH",  # flake8-type-checking
+    "ERA",  # eradicate
+    "PGH",  # pygrep-hooks
+    "PL",   # Pylint
+    "PERF", # Perflint
+    "RUF",  # ruff
 ]
 
 [tool.ruff.lint.flake8-import-conventions]
diff --git a/target_postgres/connector.py b/target_postgres/connector.py
index b8d3c907..819eb507 100644
--- a/target_postgres/connector.py
+++ b/target_postgres/connector.py
@@ -1,10 +1,13 @@
 """Handles Postgres interactions."""
+
 from __future__ import annotations
 
 import atexit
 import io
+import itertools
 import signal
+import sys
 import typing as t
 from contextlib import contextmanager
 from os import chmod, path
@@ -79,7 +82,7 @@ def __init__(self, config: dict) -> None:
             sqlalchemy_url=url.render_as_string(hide_password=False),
         )
 
-    def prepare_table(  # type: ignore[override]
+    def prepare_table(  # type: ignore[override]  # noqa: PLR0913
         self,
         full_table_name: str,
         schema: dict,
@@ -105,7 +108,7 @@ def prepare_table(  # type: ignore[override]
         meta = sa.MetaData(schema=schema_name)
         table: sa.Table
         if not self.table_exists(full_table_name=full_table_name):
-            table = self.create_empty_table(
+            return self.create_empty_table(
                 table_name=table_name,
                 meta=meta,
                 schema=schema,
@@ -114,7 +117,6 @@ def prepare_table(  # type: ignore[override]
                 as_temp_table=as_temp_table,
                 connection=connection,
             )
-            return table
         meta.reflect(connection, only=[table_name])
         table = meta.tables[
             full_table_name
@@ -161,19 +163,19 @@ def copy_table_structure(
         _, schema_name, table_name = self.parse_full_table_name(full_table_name)
         meta = sa.MetaData(schema=schema_name)
         new_table: sa.Table
-        columns = []
         if self.table_exists(full_table_name=full_table_name):
             raise RuntimeError("Table already exists")
-        for column in from_table.columns:
-            columns.append(column._copy())
+
+        columns = [column._copy() for column in from_table.columns]
+
         if as_temp_table:
             new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
             new_table.create(bind=connection)
             return new_table
-        else:
-            new_table = sa.Table(table_name, meta, *columns)
-            new_table.create(bind=connection)
-            return new_table
+
+        new_table = sa.Table(table_name, meta, *columns)
+        new_table.create(bind=connection)
+        return new_table
 
     @contextmanager
     def _connect(self) -> t.Iterator[sa.engine.Connection]:
@@ -184,18 +186,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
         """Drop table data."""
         table.drop(bind=connection)
 
-    def clone_table(
+    def clone_table(  # noqa: PLR0913
         self, new_table_name, table, metadata, connection, temp_table
     ) -> sa.Table:
         """Clone a table."""
-        new_columns = []
-        for column in table.columns:
-            new_columns.append(
-                sa.Column(
-                    column.name,
-                    column.type,
-                )
+        new_columns = [
+            sa.Column(
+                column.name,
+                column.type,
             )
+            for column in table.columns
+        ]
         if temp_table is True:
             new_table = sa.Table(
                 new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
             )
@@ -275,9 +276,8 @@ def pick_individual_type(jsonschema_type: dict):
jsonschema_type.get("format") == "date-time": return TIMESTAMP() individual_type = th.to_sql_type(jsonschema_type) - if isinstance(individual_type, VARCHAR): - return TEXT() - return individual_type + + return TEXT() if isinstance(individual_type, VARCHAR) else individual_type @staticmethod def pick_best_sql_type(sql_type_array: list): @@ -304,13 +304,12 @@ def pick_best_sql_type(sql_type_array: list): NOTYPE, ] - for sql_type in precedence_order: - for obj in sql_type_array: - if isinstance(obj, sql_type): - return obj + for sql_type, obj in itertools.product(precedence_order, sql_type_array): + if isinstance(obj, sql_type): + return obj return TEXT() - def create_empty_table( # type: ignore[override] + def create_empty_table( # type: ignore[override] # noqa: PLR0913 self, table_name: str, meta: sa.MetaData, @@ -324,7 +323,7 @@ def create_empty_table( # type: ignore[override] Args: table_name: the target table name. - meta: the SQLAchemy metadata object. + meta: the SQLAlchemy metadata object. schema: the JSON schema for the new table. connection: the database connection. primary_keys: list of key properties. @@ -367,7 +366,7 @@ def create_empty_table( # type: ignore[override] new_table.create(bind=connection) return new_table - def prepare_column( + def prepare_column( # noqa: PLR0913 self, full_table_name: str, column_name: str, @@ -415,7 +414,7 @@ def prepare_column( column_object=column_object, ) - def _create_empty_column( # type: ignore[override] + def _create_empty_column( # type: ignore[override] # noqa: PLR0913 self, schema_name: str, table_name: str, @@ -480,7 +479,7 @@ def get_column_add_ddl( # type: ignore[override] }, ) - def _adapt_column_type( # type: ignore[override] + def _adapt_column_type( # type: ignore[override] # noqa: PLR0913 self, schema_name: str, table_name: str, @@ -523,7 +522,7 @@ def _adapt_column_type( # type: ignore[override] return # Not the same type, generic type or compatible types - # calling merge_sql_types for assistnace + # calling merge_sql_types for assistance compatible_sql_type = self.merge_sql_types([current_type, sql_type]) if str(compatible_sql_type) == str(current_type): @@ -593,17 +592,16 @@ def get_sqlalchemy_url(self, config: dict) -> str: if config.get("sqlalchemy_url"): return cast(str, config["sqlalchemy_url"]) - else: - sqlalchemy_url = URL.create( - drivername=config["dialect+driver"], - username=config["user"], - password=config["password"], - host=config["host"], - port=config["port"], - database=config["database"], - query=self.get_sqlalchemy_query(config), - ) - return cast(str, sqlalchemy_url) + sqlalchemy_url = URL.create( + drivername=config["dialect+driver"], + username=config["user"], + password=config["password"], + host=config["host"], + port=config["port"], + database=config["database"], + query=self.get_sqlalchemy_query(config), + ) + return cast(str, sqlalchemy_url) def get_sqlalchemy_query(self, config: dict) -> dict: """Get query values to be used for sqlalchemy URL creation. @@ -619,7 +617,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict: # ssl_enable is for verifying the server's identity to the client. 
if config["ssl_enable"]: ssl_mode = config["ssl_mode"] - query.update({"sslmode": ssl_mode}) + query["sslmode"] = ssl_mode query["sslrootcert"] = self.filepath_or_certificate( value=config["ssl_certificate_authority"], alternative_name=config["ssl_storage_directory"] + "/root.crt", @@ -665,12 +663,11 @@ def filepath_or_certificate( """ if path.isfile(value): return value - else: - with open(alternative_name, "wb") as alternative_file: - alternative_file.write(value.encode("utf-8")) - if restrict_permissions: - chmod(alternative_name, 0o600) - return alternative_name + with open(alternative_name, "wb") as alternative_file: + alternative_file.write(value.encode("utf-8")) + if restrict_permissions: + chmod(alternative_name, 0o600) + return alternative_name def guess_key_type(self, key_data: str) -> paramiko.PKey: """Guess the type of the private key. @@ -695,7 +692,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey: ): try: key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined] - except paramiko.SSHException: + except paramiko.SSHException: # noqa: PERF203 continue else: return key @@ -715,7 +712,7 @@ def catch_signal(self, signum, frame) -> None: signum: The signal number frame: The current stack frame """ - exit(1) # Calling this to be sure atexit is called, so clean_up gets called + sys.exit(1) # Calling this to be sure atexit is called, so clean_up gets called def _get_column_type( # type: ignore[override] self, diff --git a/target_postgres/sinks.py b/target_postgres/sinks.py index 4173c5a2..6f3714c6 100644 --- a/target_postgres/sinks.py +++ b/target_postgres/sinks.py @@ -47,10 +47,7 @@ def setup(self) -> None: This method is called on Sink creation, and creates the required Schema and Table entities in the target database. """ - if self.key_properties is None or self.key_properties == []: - self.append_only = True - else: - self.append_only = False + self.append_only = self.key_properties is None or self.key_properties == [] if self.schema_name: self.connector.prepare_schema(self.schema_name) with self.connector._connect() as connection, connection.begin(): @@ -109,14 +106,14 @@ def process_batch(self, context: dict) -> None: def generate_temp_table_name(self): """Uuid temp table name.""" - # sa.exc.IdentifierError: Identifier + # sa.exc.IdentifierError: Identifier # noqa: ERA001 # 'temp_test_optional_attributes_388470e9_fbd0_47b7_a52f_d32a2ee3f5f6' # exceeds maximum length of 63 characters # Is hit if we have a long table name, there is no limit on Temporary tables # in postgres, used a guid just in case we are using the same session return f"{str(uuid.uuid4()).replace('-', '_')}" - def bulk_insert_records( # type: ignore[override] + def bulk_insert_records( # type: ignore[override] # noqa: PLR0913 self, table: sa.Table, schema: dict, @@ -156,24 +153,24 @@ def bulk_insert_records( # type: ignore[override] if self.append_only is False: insert_records: Dict[str, Dict] = {} # pk : record for record in records: - insert_record = {} - for column in columns: - insert_record[column.name] = record.get(column.name) + insert_record = { + column.name: record.get(column.name) for column in columns + } # No need to check for a KeyError here because the SDK already - # guaruntees that all key properties exist in the record. + # guarantees that all key properties exist in the record. 
primary_key_value = "".join([str(record[key]) for key in primary_keys]) insert_records[primary_key_value] = insert_record data_to_insert = list(insert_records.values()) else: for record in records: - insert_record = {} - for column in columns: - insert_record[column.name] = record.get(column.name) + insert_record = { + column.name: record.get(column.name) for column in columns + } data_to_insert.append(insert_record) connection.execute(insert, data_to_insert) return True - def upsert( + def upsert( # noqa: PLR0913 self, from_table: sa.Table, to_table: sa.Table, @@ -232,7 +229,7 @@ def upsert( # Update where_condition = join_condition update_columns = {} - for column_name in self.schema["properties"].keys(): + for column_name in self.schema["properties"]: from_table_column: sa.Column = from_table.columns[column_name] to_table_column: sa.Column = to_table.columns[column_name] update_columns[to_table_column] = from_table_column @@ -249,14 +246,13 @@ def column_representation( schema: dict, ) -> List[sa.Column]: """Return a sqlalchemy table representation for the current schema.""" - columns: list[sa.Column] = [] - for property_name, property_jsonschema in schema["properties"].items(): - columns.append( - sa.Column( - property_name, - self.connector.to_sql_type(property_jsonschema), - ) + columns: list[sa.Column] = [ + sa.Column( + property_name, + self.connector.to_sql_type(property_jsonschema), ) + for property_name, property_jsonschema in schema["properties"].items() + ] return columns def generate_insert_statement( @@ -286,12 +282,12 @@ def schema_name(self) -> Optional[str]: """Return the schema name or `None` if using names with no schema part. Note that after the next SDK release (after 0.14.0) we can remove this - as it's already upstreamed. + as it's already up-streamed. Returns: The target schema name. """ - # Look for a default_target_scheme in the configuraion fle + # Look for a default_target_scheme in the configuration file default_target_schema: str = self.config.get("default_target_schema", None) parts = self.stream_name.split("-") @@ -302,14 +298,7 @@ def schema_name(self) -> Optional[str]: if default_target_schema: return default_target_schema - if len(parts) in {2, 3}: - # Stream name is a two-part or three-part identifier. - # Use the second-to-last part as the schema name. - stream_schema = self.conform_name(parts[-2], "schema") - return stream_schema - - # Schema name not detected. - return None + return self.conform_name(parts[-2], "schema") if len(parts) in {2, 3} else None def activate_version(self, new_version: int) -> None: """Bump the active version of the target table. diff --git a/target_postgres/target.py b/target_postgres/target.py index 6503ca3d..ed83d6fa 100644 --- a/target_postgres/target.py +++ b/target_postgres/target.py @@ -1,13 +1,16 @@ """Postgres target class.""" from __future__ import annotations -from pathlib import PurePath +import typing as t from singer_sdk import typing as th from singer_sdk.target_base import SQLTarget from target_postgres.sinks import PostgresSink +if t.TYPE_CHECKING: + from pathlib import PurePath + class TargetPostgres(SQLTarget): """Target for Postgres.""" @@ -172,7 +175,7 @@ def __init__( th.BooleanType, default=False, description=( - "When activate version is sent from a tap this specefies " + "When activate version is sent from a tap this specifies " + "if we should delete the records that don't match, or mark " + "them with a date in the `_sdc_deleted_at` column. 
This config " + "option is ignored if `activate_version` is set to false." diff --git a/target_postgres/tests/test_target_postgres.py b/target_postgres/tests/test_target_postgres.py index 93f56c30..23ba3133 100644 --- a/target_postgres/tests/test_target_postgres.py +++ b/target_postgres/tests/test_target_postgres.py @@ -77,11 +77,7 @@ def singer_file_to_target(file_name, target) -> None: def remove_metadata_columns(row: dict) -> dict: - new_row = {} - for column in row.keys(): - if not column.startswith("_sdc"): - new_row[column] = row[column] - return new_row + return {column: row[column] for column in row if not column.startswith("_sdc")} def verify_data( @@ -105,35 +101,35 @@ def verify_data( engine = create_engine(target) full_table_name = f"{target.config['default_target_schema']}.{table_name}" with engine.connect() as connection: - if primary_key is not None and check_data is not None: - if isinstance(check_data, dict): - result = connection.execute( - sqlalchemy.text( - f"SELECT * FROM {full_table_name} ORDER BY {primary_key}" - ) - ) - assert result.rowcount == number_of_rows - result_dict = remove_metadata_columns(result.first()._asdict()) - assert result_dict == check_data - elif isinstance(check_data, list): - result = connection.execute( - sqlalchemy.text( - f"SELECT * FROM {full_table_name} ORDER BY {primary_key}" - ) - ) - assert result.rowcount == number_of_rows - result_dict = [ - remove_metadata_columns(row._asdict()) for row in result.all() - ] - assert result_dict == check_data - else: - raise ValueError("Invalid check_data - not dict or list of dicts") - else: + if primary_key is None or check_data is None: result = connection.execute( sqlalchemy.text(f"SELECT COUNT(*) FROM {full_table_name}") ) assert result.first()[0] == number_of_rows + elif isinstance(check_data, dict): + result = connection.execute( + sqlalchemy.text( + f"SELECT * FROM {full_table_name} ORDER BY {primary_key}" + ) + ) + assert result.rowcount == number_of_rows + result_dict = remove_metadata_columns(result.first()._asdict()) + assert result_dict == check_data + elif isinstance(check_data, list): + result = connection.execute( + sqlalchemy.text( + f"SELECT * FROM {full_table_name} ORDER BY {primary_key}" + ) + ) + assert result.rowcount == number_of_rows + result_dict = [ + remove_metadata_columns(row._asdict()) for row in result.all() + ] + assert result_dict == check_data + else: + raise ValueError("Invalid check_data - not dict or list of dicts") + def test_sqlalchemy_url_config(postgres_config_no_ssl): """Be sure that passing a sqlalchemy_url works @@ -420,7 +416,7 @@ def test_encoded_string_data(postgres_target): https://www.postgresql.org/docs/current/functions-string.html#:~:text=chr(0)%20is%20disallowed%20because%20text%20data%20types%20cannot%20store%20that%20character. chr(0) is disallowed because text data types cannot store that character. - Note you will recieve a ValueError: A string literal cannot contain NUL (0x00) characters. Which seems like a reasonable error. + Note you will receive a ValueError: A string literal cannot contain NUL (0x00) characters. Which seems like a reasonable error. See issue https://github.com/MeltanoLabs/target-postgres/issues/60 for more details. """ @@ -480,18 +476,11 @@ def test_anyof(postgres_target): # Any of nullable array of strings or single string. # {"anyOf":[{"type":"array","items":{"type":["null","string"]}},{"type":"string"},{"type":"null"}]} - if column.name == "parent_ids": - assert isinstance(column.type, ARRAY) - - # Any of nullable string. 
- # {"anyOf":[{"type":"string"},{"type":"null"}]} - if column.name == "commit_message": + if column.name in ["commit_message", "legacy_id"]: assert isinstance(column.type, TEXT) - # Any of nullable string or integer. - # {"anyOf":[{"type":"string"},{"type":"integer"},{"type":"null"}]} - if column.name == "legacy_id": - assert isinstance(column.type, TEXT) + elif column.name == "parent_ids": + assert isinstance(column.type, ARRAY) def test_new_array_column(postgres_target): @@ -656,7 +645,7 @@ def test_activate_version_deletes_data_properly(postgres_target): def test_reserved_keywords(postgres_target): """Target should work regardless of column names - Postgres has a number of resereved keywords listed here https://www.postgresql.org/docs/current/sql-keywords-appendix.html. + Postgres has a number of reserved keywords listed here https://www.postgresql.org/docs/current/sql-keywords-appendix.html. """ file_name = "reserved_keywords.singer" singer_file_to_target(file_name, postgres_target)