From 166f7b59fba2ce7fc0e88f19380e37b3b79c441d Mon Sep 17 00:00:00 2001 From: aguiddir Date: Thu, 7 Dec 2023 14:36:53 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Ensure=20compatibility=20with=20?= =?UTF-8?q?Python=203.8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yoann Isaac --- datacompy/core.py | 18 ++++++------- datacompy/fugue.py | 34 ++++++++++++------------ datacompy/spark.py | 64 +++++++++++++++++++++++----------------------- pyproject.toml | 4 +++ 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/datacompy/core.py b/datacompy/core.py index ca613a1e..3f1f47c0 100644 --- a/datacompy/core.py +++ b/datacompy/core.py @@ -22,7 +22,7 @@ """ import logging import os -from typing import cast, Any +from typing import cast, Any, List, Dict, Union, Optional import numpy as np import pandas as pd @@ -81,7 +81,7 @@ def __init__( self, df1: pd.DataFrame, df2: pd.DataFrame, - join_columns: list[str] | str | None = None, + join_columns: Optional[Union[List[str], str]] = None, on_index: bool = False, abs_tol: float = 0, rel_tol: float = 0, @@ -107,7 +107,7 @@ def __init__( else: self.join_columns = [ str(col).lower() if self.cast_column_names_lower else str(col) - for col in cast(list[str], join_columns) + for col in cast(List[str], join_columns) ] self.on_index = False @@ -123,7 +123,7 @@ def __init__( self.df1_unq_rows: pd.DataFrame self.df2_unq_rows: pd.DataFrame self.intersect_rows: pd.DataFrame - self.column_stats: list[dict[str, Any]] = [] + self.column_stats: List[Dict[str, Any]] = [] self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case) @property @@ -243,7 +243,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: If ``on_index`` is True, this will join on index values, otherwise it will join on the ``join_columns``. """ - params: dict[str, Any] + params: Dict[str, Any] index_column: str LOG.debug("Outer joining") if self._any_dupes: @@ -533,7 +533,7 @@ def report( self, sample_count: int = 10, column_count: int = 10, - html_file: str | None = None, + html_file: Optional[str] = None, ) -> str: """Returns a string representation of a report. The representation can then be printed or saved to a file. @@ -696,7 +696,7 @@ def df_to_str(pdf: pd.DataFrame) -> str: return report -def render(filename: str, *fields: int | float | str) -> str: +def render(filename: str, *fields: Union[int, float, str]) -> str: """Renders out an individual template. This basically just reads in a template file, and applies ``.format()`` on the fields. @@ -844,7 +844,7 @@ def compare_string_and_date_columns( def get_merged_columns( original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str -) -> list[str]: +) -> List[str]: """Gets the columns from an original dataframe, in the new merged dataframe Parameters @@ -915,7 +915,7 @@ def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> floa def generate_id_within_group( - dataframe: pd.DataFrame, join_columns: list[str] + dataframe: pd.DataFrame, join_columns: List[str] ) -> "pd.Series[int]": """Generate an ID column that can be used to deduplicate identical rows. The series generated is the order within a unique group, and it handles nulls. 
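Every hunk in this patch applies the same mechanical substitution: PEP 604 unions ("X | Y") and PEP 585 builtin generics ("list[str]", "dict[str, Any]") are replaced by their typing-module equivalents, which also work on Python 3.8. The sketch below is illustrative only (the normalize function is hypothetical, not datacompy code): on 3.8 the old spellings raise TypeError whenever the annotation is actually evaluated at runtime, for example as the first argument to typing.cast(...), while the typing-module forms import and evaluate cleanly.

# Illustrative sketch, not datacompy code: the annotation style this patch
# moves away from vs. the Python 3.8-compatible style it moves to.
from typing import Any, Dict, List, Optional, Union, cast

def normalize(join_columns: Optional[Union[List[str], str]] = None) -> Dict[str, Any]:
    # 3.10+-only spelling this replaces: join_columns: list[str] | str | None = None
    if join_columns is None:
        return {}
    if isinstance(join_columns, str):
        return {"columns": [join_columns]}
    # cast(list[str], ...) fails at runtime on 3.8; cast(List[str], ...) does not
    return {"columns": cast(List[str], join_columns)}
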
diff --git a/datacompy/fugue.py b/datacompy/fugue.py index 86f2da50..9a0109f4 100644 --- a/datacompy/fugue.py +++ b/datacompy/fugue.py @@ -20,7 +20,7 @@ import logging import pickle from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, cast, Union, Optional +from typing import Any, Callable, Dict, Iterable, List, cast, Union, Optional, Tuple import fugue.api as fa import pandas as pd @@ -99,7 +99,7 @@ def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame) -> bool: def is_match( df1: AnyDataFrame, df2: AnyDataFrame, - join_columns: str | list[str], + join_columns: Union[str, List[str]], abs_tol: float = 0, rel_tol: float = 0, df1_name: str = "df1", @@ -107,7 +107,7 @@ def is_match( ignore_spaces: bool = False, ignore_case: bool = False, cast_column_names_lower: bool = True, - parallelism: int | None = None, + parallelism: Optional[int] = None, strict_schema: bool = False, ) -> bool: """Check whether two dataframes match. @@ -198,7 +198,7 @@ def is_match( def all_rows_overlap( df1: AnyDataFrame, df2: AnyDataFrame, - join_columns: str | list[str], + join_columns: Union[str, List[str]], abs_tol: float = 0, rel_tol: float = 0, df1_name: str = "df1", @@ -206,7 +206,7 @@ def all_rows_overlap( ignore_spaces: bool = False, ignore_case: bool = False, cast_column_names_lower: bool = True, - parallelism: int | None = None, + parallelism: Optional[int] = None, strict_schema: bool = False, ) -> bool: """Check if the rows are all present in both dataframes @@ -294,7 +294,7 @@ def all_rows_overlap( def report( df1: AnyDataFrame, df2: AnyDataFrame, - join_columns: str | list[str], + join_columns: Union[str, List[str]], abs_tol: float = 0, rel_tol: float = 0, df1_name: str = "df1", @@ -304,8 +304,8 @@ def report( cast_column_names_lower: bool = True, sample_count: int = 10, column_count: int = 10, - html_file: str | None = None, - parallelism: int | None = None, + html_file: Optional[str] = None, + parallelism: Optional[int] = None, ) -> str: """Returns a string representation of a report. The representation can then be printed or saved to a file. @@ -320,7 +320,7 @@ def report( First dataframe to check df2 : ``AnyDataFrame`` Second dataframe to check - join_columns : list or str, optional + join_columns : list or str Column(s) to join dataframes on. If a string is passed in, that one column will be used. abs_tol : float, optional @@ -454,8 +454,8 @@ def _any(col: str) -> int: "Yes" if _any("_any_dupes") else "No", ) - column_stats: list[dict[str, Any]] - match_sample: list[pd.DataFrame] + column_stats: List[Dict[str, Any]] + match_sample: List[pd.DataFrame] column_stats, match_sample = _aggregate_stats(res, sample_count=sample_count) any_mismatch = len(match_sample) > 0 @@ -671,7 +671,7 @@ def _serialize(dfs: Iterable[pd.DataFrame], left: bool) -> Iterable[Dict[str, An ) def _deserialize( - df: list[dict[str, Any]], left: bool, schema: Schema + df: List[Dict[str, Any]], left: bool, schema: Schema ) -> pd.DataFrame: arr = [pickle.loads(r["data"]) for r in df if r["left"] == left] if len(arr) > 0: @@ -682,7 +682,7 @@ def _deserialize( # The following is how to construct an empty pandas dataframe with # the correct schema, it avoids pandas schema inference which is wrong. 
# This is not needed when upgrading to Fugue >= 0.8.7 - sample_row: list[Any] = [] + sample_row: List[Any] = [] for field in schema.fields: if pa.types.is_string(field.type): sample_row.append("x") @@ -728,7 +728,7 @@ def _comp(df: List[Dict[str, Any]]) -> List[List[Any]]: def _get_compare_result( compare: Compare, sample_count: int, column_count: int -) -> dict[str, Any]: +) -> Dict[str, Any]: mismatch_samples: Dict[str, pd.DataFrame] = {} for column in compare.column_stats: if not column["all_match"]: @@ -777,8 +777,8 @@ def _get_compare_result( def _aggregate_stats( - compares: list[Any], sample_count: int -) -> tuple[list[dict[str, Any]], list[pd.DataFrame]]: + compares: List[Any], sample_count: int +) -> Tuple[List[Dict[str, Any]], List[pd.DataFrame]]: samples = defaultdict(list) stats = [] for compare in compares: @@ -804,7 +804,7 @@ def _aggregate_stats( .reset_index(drop=False) ) return cast( - tuple[list[dict[str, Any]], list[pd.DataFrame]], + Tuple[List[Dict[str, Any]], List[pd.DataFrame]], ( df.to_dict(orient="records"), [ diff --git a/datacompy/spark.py b/datacompy/spark.py index 9a33097d..1086c8d0 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -16,7 +16,7 @@ import sys from enum import Enum from itertools import chain -from typing import Any, TextIO, TypeAlias +from typing import Any, TextIO, NewType, List, Union, Tuple, Optional, Dict, Set import pyspark @@ -25,7 +25,7 @@ except ImportError: pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality -SparkDataFrame: TypeAlias = pyspark.sql.DataFrame +SparkDataFrame = pyspark.sql.DataFrame class MatchType(Enum): @@ -146,12 +146,12 @@ class SparkCompare: def __init__( self, spark_session: pyspark.sql.SparkSession, - base_df: pyspark.sql.DataFrame, - compare_df: pyspark.sql.DataFrame, - join_columns: list[str | tuple[str, str]], - column_mapping: list[tuple[str, str]] | None = None, + base_df: SparkDataFrame, + compare_df: SparkDataFrame, + join_columns: List[Union[str, Tuple[str, str]]], + column_mapping: Optional[List[Tuple[str, str]]] = None, cache_intermediates: bool = False, - known_differences: list[dict[str, Any]] | None = None, + known_differences: Optional[List[Dict[str, Any]]] = None, rel_tol: float = 0, abs_tol: float = 0, show_all_columns: bool = False, @@ -186,15 +186,15 @@ def __init__( self.spark = spark_session self.base_unq_rows = self.compare_unq_rows = None - self._base_row_count: int | None = None - self._compare_row_count: int | None = None - self._common_row_count: int | None = None - self._joined_dataframe: SparkDataFrame | None = None - self._rows_only_base: SparkDataFrame | None = None - self._rows_only_compare: SparkDataFrame | None = None - self._all_matched_rows: SparkDataFrame | None = None - self._all_rows_mismatched: SparkDataFrame | None = None - self.columns_match_dict: dict[str, Any] = {} + self._base_row_count: Optional[int] = None + self._compare_row_count: Optional[int] = None + self._common_row_count: Optional[int] = None + self._joined_dataframe: Optional[SparkDataFrame] = None + self._rows_only_base: Optional[SparkDataFrame] = None + self._rows_only_compare: Optional[SparkDataFrame] = None + self._all_matched_rows: Optional[SparkDataFrame] = None + self._all_rows_mismatched: Optional[SparkDataFrame] = None + self.columns_match_dict: Dict[str, Any] = {} # drop the duplicates before actual comparison made. 
self.base_df = base_df.dropDuplicates(self._join_column_names) @@ -207,9 +207,9 @@ def __init__( self._compare_row_count = self.compare_df.count() def _tuplizer( - self, input_list: list[str | tuple[str, str]] - ) -> list[tuple[str, str]]: - join_columns: list[tuple[str, str]] = [] + self, input_list: List[Union[str, Tuple[str, str]]] + ) -> List[Tuple[str, str]]: + join_columns: List[Tuple[str, str]] = [] for val in input_list: if isinstance(val, str): join_columns.append((val, val)) @@ -219,12 +219,12 @@ def _tuplizer( return join_columns @property - def columns_in_both(self) -> set[str]: + def columns_in_both(self) -> Set[str]: """set[str]: Get columns in both dataframes""" return set(self.base_df.columns) & set(self.compare_df.columns) @property - def columns_compared(self) -> list[str]: + def columns_compared(self) -> List[str]: """list[str]: Get columns to be compared in both dataframes (all columns in both excluding the join key(s)""" return [ @@ -234,12 +234,12 @@ def columns_compared(self) -> list[str]: ] @property - def columns_only_base(self) -> set[str]: + def columns_only_base(self) -> Set[str]: """set[str]: Get columns that are unique to the base dataframe""" return set(self.base_df.columns) - set(self.compare_df.columns) @property - def columns_only_compare(self) -> set[str]: + def columns_only_compare(self) -> Set[str]: """set[str]: Get columns that are unique to the compare dataframe""" return set(self.compare_df.columns) - set(self.base_df.columns) @@ -268,13 +268,13 @@ def common_row_count(self) -> int: return self._common_row_count - def _get_unq_base_rows(self) -> pyspark.sql.DataFrame: + def _get_unq_base_rows(self) -> SparkDataFrame: """Get the rows only from base data frame""" return self.base_df.select(self._join_column_names).subtract( self.compare_df.select(self._join_column_names) ) - def _get_compare_rows(self) -> pyspark.sql.DataFrame: + def _get_compare_rows(self) -> SparkDataFrame: """Get the rows only from compare data frame""" return self.compare_df.select(self._join_column_names).subtract( self.base_df.select(self._join_column_names) @@ -329,7 +329,7 @@ def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None: col_type = df.select(column).dtypes[0][1] print((format_pattern + " {:13s}").format(column, col_type), file=myfile) - def _columns_with_matching_schema(self) -> dict[str, str]: + def _columns_with_matching_schema(self) -> Dict[str, str]: """This function will identify the columns which has matching schema""" col_schema_match = {} base_columns_dict = dict(self.base_df.dtypes) @@ -343,7 +343,7 @@ def _columns_with_matching_schema(self) -> dict[str, str]: return col_schema_match - def _columns_with_schemadiff(self) -> dict[str, dict[str, str]]: + def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]: """This function will identify the columns which has different schema""" col_schema_diff = {} base_columns_dict = dict(self.base_df.dtypes) @@ -363,7 +363,7 @@ def _columns_with_schemadiff(self) -> dict[str, dict[str, str]]: return col_schema_diff @property - def rows_both_mismatch(self) -> SparkDataFrame | None: + def rows_both_mismatch(self) -> Optional[SparkDataFrame]: """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches""" if self._all_rows_mismatched is None: self._merge_dataframes() @@ -371,7 +371,7 @@ def rows_both_mismatch(self) -> SparkDataFrame | None: return self._all_rows_mismatched @property - def rows_both_all(self) -> SparkDataFrame | None: + def rows_both_all(self) -> 
Optional[SparkDataFrame]: """pyspark.sql.DataFrame: Returns all rows in both dataframes""" if self._all_matched_rows is None: self._merge_dataframes() @@ -379,7 +379,7 @@ def rows_both_all(self) -> SparkDataFrame | None: return self._all_matched_rows @property - def rows_only_base(self) -> SparkDataFrame | None: + def rows_only_base(self) -> SparkDataFrame: """pyspark.sql.DataFrame: Returns rows only in the base dataframe""" if not self._rows_only_base: base_rows = self._get_unq_base_rows() @@ -399,7 +399,7 @@ def rows_only_base(self) -> SparkDataFrame | None: return self._rows_only_base @property - def rows_only_compare(self) -> SparkDataFrame | None: + def rows_only_compare(self) -> Optional[SparkDataFrame]: """pyspark.sql.DataFrame: Returns rows only in the compare dataframe""" if not self._rows_only_compare: compare_rows = self._get_compare_rows() @@ -475,7 +475,7 @@ def _merge_dataframes(self) -> None: self._join_column_names # type: ignore[arg-type] ) - def _get_or_create_joined_dataframe(self) -> pyspark.sql.DataFrame: + def _get_or_create_joined_dataframe(self) -> SparkDataFrame: if self._joined_dataframe is None: join_condition = " AND ".join( ["A." + name + "<=>B." + name for name in self._join_column_names] diff --git a/pyproject.toml b/pyproject.toml index 2cbf8c3d..9c9572e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,10 @@ dev = [ "datacompy[build]", ] +mypy = [ + "pandas-stubs", +] + [tool.isort] multi_line_output = 3 include_trailing_comma = true
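
The spark.py hunk above also drops typing.TypeAlias, which only exists from Python 3.10 (PEP 613); a plain assignment is still treated as an implicit type alias by type checkers and imports cleanly on Python 3.8. A minimal sketch of that change (head_row is a hypothetical helper, not datacompy code):

# Sketch of the alias change in datacompy/spark.py; head_row is a made-up helper.
import pyspark.sql

# 3.10+-only form removed by this patch:
#     SparkDataFrame: TypeAlias = pyspark.sql.DataFrame
SparkDataFrame = pyspark.sql.DataFrame  # implicit alias, works on Python 3.8

def head_row(df: SparkDataFrame):
    # the alias remains usable in annotations exactly as before
    return df.first()

The new optional-dependency group added to pyproject.toml can be installed with the standard extras syntax, e.g. pip install 'datacompy[mypy]', so pandas-stubs is only pulled in when type checking.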