From 2ed5d515997c8c2760f23a819c2c5779df385569 Mon Sep 17 00:00:00 2001 From: amazigh <76942612+aguiddir@users.noreply.github.com> Date: Fri, 8 Dec 2023 02:05:27 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20mypy=20for=20static=20type=20?= =?UTF-8?q?checking=20and=20type=20annotations=20(#249)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datacompy/core.py | 156 +++++++++++++++++++++++++++------------------ datacompy/fugue.py | 42 +++++++----- datacompy/py.typed | 0 datacompy/spark.py | 142 ++++++++++++++++++++++------------------- pyproject.toml | 13 ++++ 5 files changed, 211 insertions(+), 142 deletions(-) create mode 100644 datacompy/py.typed diff --git a/datacompy/core.py b/datacompy/core.py index 9213c0e7..fda507fb 100644 --- a/datacompy/core.py +++ b/datacompy/core.py @@ -20,9 +20,9 @@ PROC COMPARE in SAS - i.e. human-readable reporting on the difference between two dataframes. """ - import logging import os +from typing import cast, Any, List, Dict, Union, Optional import numpy as np import pandas as pd @@ -79,18 +79,18 @@ class Compare: def __init__( self, - df1, - df2, - join_columns=None, - on_index=False, - abs_tol=0, - rel_tol=0, - df1_name="df1", - df2_name="df2", - ignore_spaces=False, - ignore_case=False, - cast_column_names_lower=True, - ): + df1: pd.DataFrame, + df2: pd.DataFrame, + join_columns: Optional[Union[List[str], str]] = None, + on_index: bool = False, + abs_tol: float = 0, + rel_tol: float = 0, + df1_name: str = "df1", + df2_name: str = "df2", + ignore_spaces: bool = False, + ignore_case: bool = False, + cast_column_names_lower: bool = True, + ) -> None: self.cast_column_names_lower = cast_column_names_lower if on_index and join_columns is not None: raise Exception("Only provide on_index or join_columns") @@ -107,11 +107,11 @@ def __init__( else: self.join_columns = [ str(col).lower() if self.cast_column_names_lower else str(col) - for col in join_columns + for col in cast(List[str], join_columns) ] self.on_index = False - self._any_dupes = False + self._any_dupes: bool = False self.df1 = df1 self.df2 = df2 self.df1_name = df1_name @@ -120,16 +120,18 @@ def __init__( self.rel_tol = rel_tol self.ignore_spaces = ignore_spaces self.ignore_case = ignore_case - self.df1_unq_rows = self.df2_unq_rows = self.intersect_rows = None - self.column_stats = [] - self._compare(ignore_spaces, ignore_case) + self.df1_unq_rows: pd.DataFrame + self.df2_unq_rows: pd.DataFrame + self.intersect_rows: pd.DataFrame + self.column_stats: List[Dict[str, Any]] = [] + self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case) @property - def df1(self): + def df1(self) -> pd.DataFrame: return self._df1 @df1.setter - def df1(self, df1): + def df1(self, df1: pd.DataFrame) -> None: """Check that it is a dataframe and has the join columns""" self._df1 = df1 self._validate_dataframe( @@ -137,18 +139,20 @@ def df1(self, df1): ) @property - def df2(self): + def df2(self) -> pd.DataFrame: return self._df2 @df2.setter - def df2(self, df2): + def df2(self, df2: pd.DataFrame) -> None: """Check that it is a dataframe and has the join columns""" self._df2 = df2 self._validate_dataframe( "df2", cast_column_names_lower=self.cast_column_names_lower ) - def _validate_dataframe(self, index, cast_column_names_lower=True): + def _validate_dataframe( + self, index: str, cast_column_names_lower: bool = True + ) -> None: """Check that it is a dataframe and has the join columns Parameters @@ -163,9 +167,11 @@ def _validate_dataframe(self, index, 
cast_column_names_lower=True): raise TypeError(f"{index} must be a pandas DataFrame") if cast_column_names_lower: - dataframe.columns = [str(col).lower() for col in dataframe.columns] + dataframe.columns = pd.Index( + [str(col).lower() for col in dataframe.columns] + ) else: - dataframe.columns = [str(col) for col in dataframe.columns] + dataframe.columns = pd.Index([str(col) for col in dataframe.columns]) # Check if join_columns are present in the dataframe if not set(self.join_columns).issubset(set(dataframe.columns)): raise ValueError(f"{index} must have all columns from join_columns") @@ -182,7 +188,7 @@ def _validate_dataframe(self, index, cast_column_names_lower=True): ): self._any_dupes = True - def _compare(self, ignore_spaces, ignore_case): + def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: """Actually run the comparison. This tries to run df1.equals(df2) first so that if they're truly equal we can tell. @@ -214,26 +220,31 @@ def _compare(self, ignore_spaces, ignore_case): else: LOG.info("df1 does not match df2") - def df1_unq_columns(self): + def df1_unq_columns(self) -> OrderedSet[str]: """Get columns that are unique to df1""" - return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) + return cast( + OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) + ) - def df2_unq_columns(self): + def df2_unq_columns(self) -> OrderedSet[str]: """Get columns that are unique to df2""" - return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) + return cast( + OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) + ) - def intersect_columns(self): + def intersect_columns(self) -> OrderedSet[str]: """Get columns that are shared between the two dataframes""" return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) - def _dataframe_merge(self, ignore_spaces): + def _dataframe_merge(self, ignore_spaces: bool) -> None: """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 and df1 & df2 If ``on_index`` is True, this will join on index values, otherwise it will join on the ``join_columns``. 
""" - + params: Dict[str, Any] + index_column: str LOG.debug("Outer joining") if self._any_dupes: LOG.debug("Duplicate rows found, deduping by order of remaining fields") @@ -275,11 +286,10 @@ def _dataframe_merge(self, ignore_spaces): # Clean up temp columns for duplicate row matching if self._any_dupes: if self.on_index: - outer_join.index = outer_join[index_column] - outer_join.drop(index_column, axis=1, inplace=True) + outer_join.set_index(keys=index_column, drop=True, inplace=True) self.df1.drop(index_column, axis=1, inplace=True) self.df2.drop(index_column, axis=1, inplace=True) - outer_join.drop(order_column, axis=1, inplace=True) + outer_join.drop(labels=order_column, axis=1, inplace=True) self.df1.drop(order_column, axis=1, inplace=True) self.df2.drop(order_column, axis=1, inplace=True) @@ -306,7 +316,7 @@ def _dataframe_merge(self, ignore_spaces): f"Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}" ) - def _intersect_compare(self, ignore_spaces, ignore_case): + def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: """Run the comparison on the intersect dataframe This loops through all columns that are shared between df1 and df2, and @@ -319,7 +329,7 @@ def _intersect_compare(self, ignore_spaces, ignore_case): if column in self.join_columns: match_cnt = row_cnt col_match = "" - max_diff = 0 + max_diff = 0.0 null_diff = 0 else: col_1 = column + "_df1" @@ -367,11 +377,11 @@ def _intersect_compare(self, ignore_spaces, ignore_case): } ) - def all_columns_match(self): + def all_columns_match(self) -> bool: """Whether the columns all match in the dataframes""" return self.df1_unq_columns() == self.df2_unq_columns() == set() - def all_rows_overlap(self): + def all_rows_overlap(self) -> bool: """Whether the rows are all present in both dataframes Returns @@ -382,7 +392,7 @@ def all_rows_overlap(self): """ return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0 - def count_matching_rows(self): + def count_matching_rows(self) -> int: """Count the number of rows match (on overlapping fields) Returns @@ -396,12 +406,12 @@ def count_matching_rows(self): match_columns.append(column + "_match") return self.intersect_rows[match_columns].all(axis=1).sum() - def intersect_rows_match(self): + def intersect_rows_match(self) -> bool: """Check whether the intersect rows all match""" actual_length = self.intersect_rows.shape[0] return self.count_matching_rows() == actual_length - def matches(self, ignore_extra_columns=False): + def matches(self, ignore_extra_columns: bool = False) -> bool: """Return True or False if the dataframes match. Parameters @@ -418,7 +428,7 @@ def matches(self, ignore_extra_columns=False): else: return True - def subset(self): + def subset(self) -> bool: """Return True if dataframe 2 is a subset of dataframe 1. Dataframe 2 is considered a subset if all of its columns are in @@ -434,7 +444,9 @@ def subset(self): else: return True - def sample_mismatch(self, column, sample_count=10, for_display=False): + def sample_mismatch( + self, column: str, sample_count: int = 10, for_display: bool = False + ) -> pd.DataFrame: """Returns a sample sub-dataframe which contains the identifying columns, and df1 and df2 versions of the column. 
@@ -463,13 +475,16 @@ def sample_mismatch(self, column, sample_count=10, for_display=False): return_cols = self.join_columns + [column + "_df1", column + "_df2"] to_return = sample[return_cols] if for_display: - to_return.columns = self.join_columns + [ - column + " (" + self.df1_name + ")", - column + " (" + self.df2_name + ")", - ] + to_return.columns = pd.Index( + self.join_columns + + [ + column + " (" + self.df1_name + ")", + column + " (" + self.df2_name + ")", + ] + ) return to_return - def all_mismatch(self, ignore_matching_cols=False): + def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame: """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join columns. @@ -512,7 +527,12 @@ def all_mismatch(self, ignore_matching_cols=False): mm_bool = self.intersect_rows[match_list].all(axis="columns") return self.intersect_rows[~mm_bool][self.join_columns + return_list] - def report(self, sample_count=10, column_count=10, html_file=None): + def report( + self, + sample_count: int = 10, + column_count: int = 10, + html_file: Optional[str] = None, + ) -> str: """Returns a string representation of a report. The representation can then be printed or saved to a file. @@ -533,7 +553,7 @@ def report(self, sample_count=10, column_count=10, html_file=None): The report, formatted kinda nicely. """ - def df_to_str(pdf): + def df_to_str(pdf: pd.DataFrame) -> str: if not self.on_index: pdf = pdf.reset_index(drop=True) return pdf.to_string() @@ -674,7 +694,7 @@ def df_to_str(pdf): return report -def render(filename, *fields): +def render(filename: str, *fields: Union[int, float, str]) -> str: """Renders out an individual template. This basically just reads in a template file, and applies ``.format()`` on the fields. @@ -697,8 +717,13 @@ def render(filename, *fields): def columns_equal( - col_1, col_2, rel_tol=0, abs_tol=0, ignore_spaces=False, ignore_case=False -): + col_1: "pd.Series[Any]", + col_2: "pd.Series[Any]", + rel_tol: float = 0, + abs_tol: float = 0, + ignore_spaces: bool = False, + ignore_case: bool = False, +) -> "pd.Series[bool]": """Compares two columns from a dataframe, returning a True/False series, with the same index as column 1. @@ -731,6 +756,7 @@ def columns_equal( A series of Boolean values. True == the values match, False == the values don't match. """ + compare: pd.Series[bool] try: compare = pd.Series( np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True) @@ -773,7 +799,9 @@ def columns_equal( return compare -def compare_string_and_date_columns(col_1, col_2): +def compare_string_and_date_columns( + col_1: "pd.Series[Any]", col_2: "pd.Series[Any]" +) -> "pd.Series[bool]": """Compare a string column and date column, value-wise. This tries to convert a string column to a date column and compare that way. 
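
The series annotations in `columns_equal` and `compare_string_and_date_columns` are deliberately written as string literals (`"pd.Series[Any]"`, `"pd.Series[bool]"`): a quoted annotation is never evaluated at runtime, so subscripting `pd.Series` only has to be understood by mypy (via `pandas-stubs`, added to the `qa` extras below), not by the installed pandas. A minimal sketch of the pattern; the function here is illustrative, not part of the patch:

    from typing import Any

    import pandas as pd

    def non_null(col: "pd.Series[Any]") -> "pd.Series[bool]":
        # notna() returns a boolean series aligned to col's index. The
        # quoted annotations are plain strings at runtime, interpreted
        # only by a type checker with pandas-stubs installed.
        return col.notna()

    print(non_null(pd.Series([1, None, 3])))
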
@@ -812,7 +840,9 @@ def compare_string_and_date_columns(col_1, col_2): return pd.Series(False, index=col_1.index) -def get_merged_columns(original_df, merged_df, suffix): +def get_merged_columns( + original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str +) -> List[str]: """Gets the columns from an original dataframe, in the new merged dataframe Parameters @@ -836,7 +866,7 @@ def get_merged_columns(original_df, merged_df, suffix): return columns -def temp_column_name(*dataframes): +def temp_column_name(*dataframes: pd.DataFrame) -> str: """Gets a temp column name that isn't included in columns of any dataframes Parameters @@ -861,7 +891,7 @@ def temp_column_name(*dataframes): return temp_column -def calculate_max_diff(col_1, col_2): +def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> float: """Get a maximum difference between two columns Parameters @@ -877,12 +907,14 @@ def calculate_max_diff(col_1, col_2): Numeric field, or zero. """ try: - return (col_1.astype(float) - col_2.astype(float)).abs().max() + return cast(float, (col_1.astype(float) - col_2.astype(float)).abs().max()) except: - return 0 + return 0.0 -def generate_id_within_group(dataframe, join_columns): +def generate_id_within_group( + dataframe: pd.DataFrame, join_columns: List[str] +) -> "pd.Series[int]": """Generate an ID column that can be used to deduplicate identical rows. The series generated is the order within a unique group, and it handles nulls. diff --git a/datacompy/fugue.py b/datacompy/fugue.py index 80038aa2..9a0109f4 100644 --- a/datacompy/fugue.py +++ b/datacompy/fugue.py @@ -20,7 +20,7 @@ import logging import pickle from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, cast, Union, Optional, Tuple import fugue.api as fa import pandas as pd @@ -35,7 +35,7 @@ HASH_COL = "__datacompy__hash__" -def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame): +def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: """Get columns that are unique to df1 Parameters @@ -53,10 +53,10 @@ def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame): """ col1 = fa.get_column_names(df1) col2 = fa.get_column_names(df2) - return OrderedSet(col1) - OrderedSet(col2) + return cast(OrderedSet[str], OrderedSet(col1) - OrderedSet(col2)) -def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame): +def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: """Get columns that are shared between the two dataframes Parameters @@ -77,7 +77,7 @@ def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame): return OrderedSet(col1) & OrderedSet(col2) -def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame): +def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame) -> bool: """Whether the columns all match in the dataframes Parameters @@ -302,9 +302,9 @@ def report( ignore_spaces: bool = False, ignore_case: bool = False, cast_column_names_lower: bool = True, - sample_count=10, - column_count=10, - html_file=None, + sample_count: int = 10, + column_count: int = 10, + html_file: Optional[str] = None, parallelism: Optional[int] = None, ) -> str: """Returns a string representation of a report. The representation can @@ -320,7 +320,7 @@ def report( First dataframe to check df2 : ``AnyDataFrame`` Second dataframe to check - join_columns : list or str, optional + join_columns : list or str Column(s) to join dataframes on. 
If a string is passed in, that one column will be used. abs_tol : float, optional @@ -406,7 +406,7 @@ def report( def shape0(col: str) -> int: return sum(x[col][0] for x in res) - def shape1(col: str) -> int: + def shape1(col: str) -> Any: return first[col][1] def _sum(col: str) -> int: @@ -454,6 +454,8 @@ def _any(col: str) -> int: "Yes" if _any("_any_dupes") else "No", ) + column_stats: List[Dict[str, Any]] + match_sample: List[pd.DataFrame] column_stats, match_sample = _aggregate_stats(res, sample_count=sample_count) any_mismatch = len(match_sample) > 0 @@ -673,7 +675,10 @@ def _deserialize( ) -> pd.DataFrame: arr = [pickle.loads(r["data"]) for r in df if r["left"] == left] if len(arr) > 0: - return pd.concat(arr).sort_values(schema.names).reset_index(drop=True) + return cast( + pd.DataFrame, + pd.concat(arr).sort_values(schema.names).reset_index(drop=True), + ) # The following is how to construct an empty pandas dataframe with # the correct schema, it avoids pandas schema inference which is wrong. # This is not needed when upgrading to Fugue >= 0.8.7 @@ -772,7 +777,7 @@ def _get_compare_result( def _aggregate_stats( - compares, sample_count + compares: List[Any], sample_count: int ) -> Tuple[List[Dict[str, Any]], List[pd.DataFrame]]: samples = defaultdict(list) stats = [] @@ -798,9 +803,16 @@ def _aggregate_stats( ) .reset_index(drop=False) ) - return df.to_dict(orient="records"), [ - _sample(pd.concat(v), sample_count=sample_count) for v in samples.values() - ] + return cast( + Tuple[List[Dict[str, Any]], List[pd.DataFrame]], + ( + df.to_dict(orient="records"), + [ + _sample(pd.concat(v), sample_count=sample_count) + for v in samples.values() + ], + ), + ) def _sample(df: pd.DataFrame, sample_count: int) -> pd.DataFrame: diff --git a/datacompy/py.typed b/datacompy/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/datacompy/spark.py b/datacompy/spark.py index a285036b..53599bd0 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. - import sys from enum import Enum from itertools import chain +from typing import Any, TextIO, List, Union, Tuple, Optional, Dict, Set try: + import pyspark from pyspark.sql import functions as F except ImportError: pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality @@ -29,9 +30,9 @@ class MatchType(Enum): # Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". -def decimal_comparator(): +def decimal_comparator() -> str: class DecimalComparator(str): - def __eq__(self, other): + def __eq__(self, other: str) -> bool: # type: ignore[override] return len(other) >= 7 and other[0:7] == "decimal" return DecimalComparator("decimal") @@ -48,7 +49,7 @@ def __eq__(self, other): ] -def _is_comparable(type1, type2): +def _is_comparable(type1: str, type2: str) -> bool: """Checks if two Spark data types can be safely compared. Two data types are considered comparable if any of the following apply: 1. 
Both data types are the same @@ -141,17 +142,17 @@ class SparkCompare: def __init__( self, - spark_session, - base_df, - compare_df, - join_columns, - column_mapping=None, - cache_intermediates=False, - known_differences=None, - rel_tol=0, - abs_tol=0, - show_all_columns=False, - match_rates=False, + spark_session: "pyspark.sql.SparkSession", + base_df: "pyspark.sql.DataFrame", + compare_df: "pyspark.sql.DataFrame", + join_columns: List[Union[str, Tuple[str, str]]], + column_mapping: Optional[List[Tuple[str, str]]] = None, + cache_intermediates: bool = False, + known_differences: Optional[List[Dict[str, Any]]] = None, + rel_tol: float = 0, + abs_tol: float = 0, + show_all_columns: bool = False, + match_rates: bool = False, ): self.rel_tol = rel_tol self.abs_tol = abs_tol @@ -164,7 +165,7 @@ def __init__( self._original_compare_df = compare_df self.cache_intermediates = cache_intermediates - self.join_columns = self._tuplizer(join_columns) + self.join_columns = self._tuplizer(input_list=join_columns) self._join_column_names = [name[0] for name in self.join_columns] self._known_differences = known_differences @@ -182,13 +183,15 @@ def __init__( self.spark = spark_session self.base_unq_rows = self.compare_unq_rows = None - self._base_row_count = self._compare_row_count = self._common_row_count = None - self._joined_dataframe = None - self._rows_only_base = None - self._rows_only_compare = None - self._all_matched_rows = None - self._all_rows_mismatched = None - self.columns_match_dict = {} + self._base_row_count: Optional[int] = None + self._compare_row_count: Optional[int] = None + self._common_row_count: Optional[int] = None + self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None + self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None + self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None + self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None + self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None + self.columns_match_dict: Dict[str, Any] = {} # drop the duplicates before actual comparison made. 
self.base_df = base_df.dropDuplicates(self._join_column_names) @@ -200,8 +203,10 @@ def __init__( self.compare_df.cache() self._compare_row_count = self.compare_df.count() - def _tuplizer(self, input_list): - join_columns = [] + def _tuplizer( + self, input_list: List[Union[str, Tuple[str, str]]] + ) -> List[Tuple[str, str]]: + join_columns: List[Tuple[str, str]] = [] for val in input_list: if isinstance(val, str): join_columns.append((val, val)) @@ -211,12 +216,12 @@ def _tuplizer(self, input_list): return join_columns @property - def columns_in_both(self): + def columns_in_both(self) -> Set[str]: """set[str]: Get columns in both dataframes""" return set(self.base_df.columns) & set(self.compare_df.columns) @property - def columns_compared(self): + def columns_compared(self) -> List[str]: """list[str]: Get columns to be compared in both dataframes (all columns in both excluding the join key(s)""" return [ @@ -226,17 +231,17 @@ def columns_compared(self): ] @property - def columns_only_base(self): + def columns_only_base(self) -> Set[str]: """set[str]: Get columns that are unique to the base dataframe""" return set(self.base_df.columns) - set(self.compare_df.columns) @property - def columns_only_compare(self): + def columns_only_compare(self) -> Set[str]: """set[str]: Get columns that are unique to the compare dataframe""" return set(self.compare_df.columns) - set(self.base_df.columns) @property - def base_row_count(self): + def base_row_count(self) -> int: """int: Get the count of rows in the de-duped base dataframe""" if self._base_row_count is None: self._base_row_count = self.base_df.count() @@ -244,7 +249,7 @@ def base_row_count(self): return self._base_row_count @property - def compare_row_count(self): + def compare_row_count(self) -> int: """int: Get the count of rows in the de-duped compare dataframe""" if self._compare_row_count is None: self._compare_row_count = self.compare_df.count() @@ -252,7 +257,7 @@ def compare_row_count(self): return self._compare_row_count @property - def common_row_count(self): + def common_row_count(self) -> int: """int: Get the count of rows in common between base and compare dataframes""" if self._common_row_count is None: common_rows = self._get_or_create_joined_dataframe() @@ -260,19 +265,19 @@ def common_row_count(self): return self._common_row_count - def _get_unq_base_rows(self): + def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame": """Get the rows only from base data frame""" return self.base_df.select(self._join_column_names).subtract( self.compare_df.select(self._join_column_names) ) - def _get_compare_rows(self): + def _get_compare_rows(self) -> "pyspark.sql.DataFrame": """Get the rows only from compare data frame""" return self.compare_df.select(self._join_column_names).subtract( self.base_df.select(self._join_column_names) ) - def _print_columns_summary(self, myfile): + def _print_columns_summary(self, myfile: TextIO) -> None: """Prints the column summary details""" print("\n****** Column Summary ******", file=myfile) print( @@ -292,7 +297,7 @@ def _print_columns_summary(self, myfile): file=myfile, ) - def _print_only_columns(self, base_or_compare, myfile): + def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None: """Prints the columns and data types only in either the base or compare datasets""" if base_or_compare.upper() == "BASE": @@ -321,7 +326,7 @@ def _print_only_columns(self, base_or_compare, myfile): col_type = df.select(column).dtypes[0][1] print((format_pattern + " {:13s}").format(column, col_type), 
file=myfile) - def _columns_with_matching_schema(self): + def _columns_with_matching_schema(self) -> Dict[str, str]: """This function will identify the columns which has matching schema""" col_schema_match = {} base_columns_dict = dict(self.base_df.dtypes) @@ -329,12 +334,13 @@ def _columns_with_matching_schema(self): for base_row, base_type in base_columns_dict.items(): if base_row in compare_columns_dict: - if base_type in compare_columns_dict.get(base_row): - col_schema_match[base_row] = compare_columns_dict.get(base_row) + compare_column_type = compare_columns_dict.get(base_row) + if compare_column_type is not None and base_type in compare_column_type: + col_schema_match[base_row] = compare_column_type return col_schema_match - def _columns_with_schemadiff(self): + def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]: """This function will identify the columns which has different schema""" col_schema_diff = {} base_columns_dict = dict(self.base_df.dtypes) @@ -342,15 +348,19 @@ def _columns_with_schemadiff(self): for base_row, base_type in base_columns_dict.items(): if base_row in compare_columns_dict: - if base_type not in compare_columns_dict.get(base_row): + compare_column_type = compare_columns_dict.get(base_row) + if ( + compare_column_type is not None + and base_type not in compare_column_type + ): col_schema_diff[base_row] = dict( base_type=base_type, - compare_type=compare_columns_dict.get(base_row), + compare_type=compare_column_type, ) return col_schema_diff @property - def rows_both_mismatch(self): + def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]: """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches""" if self._all_rows_mismatched is None: self._merge_dataframes() @@ -358,7 +368,7 @@ def rows_both_mismatch(self): return self._all_rows_mismatched @property - def rows_both_all(self): + def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]: """pyspark.sql.DataFrame: Returns all rows in both dataframes""" if self._all_matched_rows is None: self._merge_dataframes() @@ -366,7 +376,7 @@ def rows_both_all(self): return self._all_matched_rows @property - def rows_only_base(self): + def rows_only_base(self) -> "pyspark.sql.DataFrame": """pyspark.sql.DataFrame: Returns rows only in the base dataframe""" if not self._rows_only_base: base_rows = self._get_unq_base_rows() @@ -386,7 +396,7 @@ def rows_only_base(self): return self._rows_only_base @property - def rows_only_compare(self): + def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]: """pyspark.sql.DataFrame: Returns rows only in the compare dataframe""" if not self._rows_only_compare: compare_rows = self._get_compare_rows() @@ -407,7 +417,7 @@ def rows_only_compare(self): return self._rows_only_compare - def _generate_select_statement(self, match_data=True): + def _generate_select_statement(self, match_data: bool = True) -> str: """This function is to generate the select statement to be used later in the query.""" base_only = list(set(self.base_df.columns) - set(self.compare_df.columns)) compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns)) @@ -440,7 +450,7 @@ def _generate_select_statement(self, match_data=True): return select_statement - def _merge_dataframes(self): + def _merge_dataframes(self) -> None: """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched.""" full_joined_dataframe = self._get_or_create_joined_dataframe() full_joined_dataframe.createOrReplaceTempView("full_matched_table") 
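
Because `spark.py` guards its pyspark import with try/except (so pandas-only installs keep working), the Spark annotations above are all string forward references such as `"pyspark.sql.DataFrame"`. Those strings are never evaluated at runtime, so the module still imports when pyspark is absent, while mypy resolves them during checking. A minimal sketch of the same pattern, with an illustrative function name:

    try:
        import pyspark
    except ImportError:
        pass  # pandas-only users never evaluate the annotation below

    def deduped_count(df: "pyspark.sql.DataFrame") -> int:
        # The annotation is just a string at runtime; mypy resolves it
        # against pyspark when the package (or its stubs) is available.
        return df.dropDuplicates().count()
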
@@ -449,9 +459,8 @@ def _merge_dataframes(self): select_query = """SELECT {} FROM full_matched_table A""".format( select_statement ) - self._all_matched_rows = self.spark.sql(select_query).orderBy( - self._join_column_names + self._join_column_names # type: ignore[arg-type] ) self._all_matched_rows.createOrReplaceTempView("matched_table") @@ -460,10 +469,10 @@ def _merge_dataframes(self): ) mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond) self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy( - self._join_column_names + self._join_column_names # type: ignore[arg-type] ) - def _get_or_create_joined_dataframe(self): + def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame": if self._joined_dataframe is None: join_condition = " AND ".join( ["A." + name + "<=>B." + name for name in self._join_column_names] @@ -488,7 +497,7 @@ def _get_or_create_joined_dataframe(self): return self._joined_dataframe - def _print_num_of_rows_with_column_equality(self, myfile): + def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None: # match_dataframe contains columns from both dataframes with flag to indicate if columns matched match_dataframe = self._get_or_create_joined_dataframe().select( *self.columns_compared @@ -507,7 +516,10 @@ def _print_num_of_rows_with_column_equality(self, myfile): ) ) all_rows_matched = self.spark.sql(match_query) - matched_rows = all_rows_matched.head()[0] + all_rows_matched_head = all_rows_matched.head() + matched_rows = ( + all_rows_matched_head[0] if all_rows_matched_head is not None else 0 + ) print("\n****** Row Comparison ******", file=myfile) print( @@ -516,7 +528,7 @@ def _print_num_of_rows_with_column_equality(self, myfile): ) print(f"Number of rows with all columns equal: {matched_rows}", file=myfile) - def _populate_columns_match_dict(self): + def _populate_columns_match_dict(self) -> None: """ side effects: columns_match_dict assigned to { column -> match_type_counts } @@ -531,7 +543,7 @@ def _populate_columns_match_dict(self): *self.columns_compared ) - def helper(c): + def helper(c: str) -> "pyspark.sql.Column": # Create a predicate for each match type, comparing column values to the match type value predicates = [F.col(c) == k.value for k in MatchType] # Create a tuple(number of match types found for each match type in this column) @@ -541,15 +553,15 @@ def helper(c): # For each column, create a single tuple. 
This tuple's values correspond to the number of times # each match type appears in that column - match_data = match_dataframe.agg( + match_data_agg = match_dataframe.agg( *[helper(col) for col in self.columns_compared] ).collect() - match_data = match_data[0] + match_data = match_data_agg[0] for c in self.columns_compared: self.columns_match_dict[c] = match_data[c] - def _create_select_statement(self, name): + def _create_select_statement(self, name: str) -> str: if self._known_differences: match_type_comparison = "" for k in MatchType: @@ -568,7 +580,7 @@ def _create_select_statement(self, name): name=name, match_failure=MatchType.MISMATCH.value ) - def _create_case_statement(self, name): + def _create_case_statement(self, name: str) -> str: equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"] known_diff_comparisons = ["(FALSE)"] @@ -622,7 +634,7 @@ def _create_case_statement(self, name): match_failure=MatchType.MISMATCH.value, ) - def _print_row_summary(self, myfile): + def _print_row_summary(self, myfile: TextIO) -> None: base_df_cnt = self.base_df.count() compare_df_cnt = self.compare_df.count() base_df_with_dup_cnt = self._original_base_df.count() @@ -647,7 +659,7 @@ def _print_row_summary(self, myfile): file=myfile, ) - def _print_schema_diff_details(self, myfile): + def _print_schema_diff_details(self, myfile: TextIO) -> None: schema_diff_dict = self._columns_with_schemadiff() if not schema_diff_dict: # If there are no differences, don't print the section @@ -691,7 +703,7 @@ def _print_schema_diff_details(self, myfile): file=myfile, ) - def _base_to_compare_name(self, base_name): + def _base_to_compare_name(self, base_name: str) -> str: """Translates a column name in the base dataframe to its counterpart in the compare dataframe, if they are different.""" @@ -703,7 +715,7 @@ def _base_to_compare_name(self, base_name): return name[1] return base_name - def _print_row_matches_by_column(self, myfile): + def _print_row_matches_by_column(self, myfile: TextIO) -> None: self._populate_columns_match_dict() columns_with_mismatches = { key: self.columns_match_dict[key] @@ -852,7 +864,7 @@ def _print_row_matches_by_column(self, myfile): print(format_pattern.format(*output_row), file=myfile) # noinspection PyUnresolvedReferences - def report(self, file=sys.stdout): + def report(self, file: TextIO = sys.stdout) -> None: """Creates a comparison report and prints it to the file specified (stdout by default). diff --git a/pyproject.toml b/pyproject.toml index ae11c20c..415849a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,8 @@ qa = [ "pre-commit", "black", "isort", + "mypy", + "pandas-stubs", ] build = [ "build", @@ -102,6 +104,17 @@ use_parentheses = true line_length = 88 profile = "black" +[tool.mypy] +strict = true + +[[tool.mypy.overrides]] +module = ["fugue.*","triad.*"] +implicit_reexport = true + +[[tool.mypy.overrides]] +module = "pyarrow" +ignore_missing_imports = true + [edgetest.envs.core] python_version = "3.9" conda_install = ["openjdk=8"]
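
Two packaging details round out the change: the empty `datacompy/py.typed` marker (the `| 0` entry in the diffstat) declares the package as typed under PEP 561, so downstream type checkers consume these annotations, and `[tool.mypy] strict = true` turns on the full strict flag set for the project itself, with per-module overrides relaxing it for `fugue`/`triad` (implicit re-exports) and `pyarrow` (no stubs). The patch does not show how mypy is invoked, so the following downstream snippet is only an illustration of what strict checking can now verify; the dataframes and column names are made up:

    import pandas as pd

    import datacompy

    df1 = pd.DataFrame({"id": [1, 2], "value": [1.0, 2.0]})
    df2 = pd.DataFrame({"id": [1, 2], "value": [1.0, 2.1]})

    # With py.typed shipped, mypy checks these calls against the new
    # annotations: join_columns accepts a str, matches() returns bool,
    # and report() returns str.
    compare = datacompy.Compare(df1, df2, join_columns="id")
    ok: bool = compare.matches()
    print(compare.report())
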