diff --git a/datacompy/spark.py b/datacompy/spark.py
index 014897e7..035aa876 100644
--- a/datacompy/spark.py
+++ b/datacompy/spark.py
@@ -21,10 +21,6 @@
 try:
     import pyspark
     from pyspark.sql import functions as F
-
-    SparkSessionType = pyspark.sql.SparkSession
-    SparkDataFrame = pyspark.sql.DataFrame
-
 except ImportError:
     pass  # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
 
@@ -146,9 +142,9 @@ class SparkCompare:
 
     def __init__(
         self,
-        spark_session: "SparkSessionType",
-        base_df: "SparkDataFrame",
-        compare_df: "SparkDataFrame",
+        spark_session: "pyspark.sql.SparkSession",
+        base_df: "pyspark.sql.DataFrame",
+        compare_df: "pyspark.sql.DataFrame",
         join_columns: List[Union[str, Tuple[str, str]]],
         column_mapping: Optional[List[Tuple[str, str]]] = None,
        cache_intermediates: bool = False,
@@ -190,11 +186,11 @@ def __init__(
         self._base_row_count: Optional[int] = None
         self._compare_row_count: Optional[int] = None
         self._common_row_count: Optional[int] = None
-        self._joined_dataframe: Optional["SparkDataFrame"] = None
-        self._rows_only_base: Optional["SparkDataFrame"] = None
-        self._rows_only_compare: Optional["SparkDataFrame"] = None
-        self._all_matched_rows: Optional["SparkDataFrame"] = None
-        self._all_rows_mismatched: Optional["SparkDataFrame"] = None
+        self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None
+        self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None
+        self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None
         self.columns_match_dict: Dict[str, Any] = {}
 
         # drop the duplicates before actual comparison made.
@@ -269,13 +265,13 @@ def common_row_count(self) -> int:
 
         return self._common_row_count
 
-    def _get_unq_base_rows(self) -> "SparkDataFrame":
+    def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from base data frame"""
         return self.base_df.select(self._join_column_names).subtract(
             self.compare_df.select(self._join_column_names)
         )
 
-    def _get_compare_rows(self) -> "SparkDataFrame":
+    def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from compare data frame"""
         return self.compare_df.select(self._join_column_names).subtract(
             self.base_df.select(self._join_column_names)
@@ -364,7 +360,7 @@ def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
         return col_schema_diff
 
     @property
-    def rows_both_mismatch(self) -> Optional["SparkDataFrame"]:
+    def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
         if self._all_rows_mismatched is None:
             self._merge_dataframes()
@@ -372,7 +368,7 @@ def rows_both_mismatch(self) -> Optional["SparkDataFrame"]:
         return self._all_rows_mismatched
 
     @property
-    def rows_both_all(self) -> Optional["SparkDataFrame"]:
+    def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
         if self._all_matched_rows is None:
             self._merge_dataframes()
@@ -380,7 +376,7 @@ def rows_both_all(self) -> Optional["SparkDataFrame"]:
         return self._all_matched_rows
 
     @property
-    def rows_only_base(self) -> "SparkDataFrame":
+    def rows_only_base(self) -> "pyspark.sql.DataFrame":
         """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
         if not self._rows_only_base:
             base_rows = self._get_unq_base_rows()
@@ -400,7 +396,7 @@ def rows_only_base(self) -> "SparkDataFrame":
         return self._rows_only_base
 
     @property
-    def rows_only_compare(self) -> Optional["SparkDataFrame"]:
+    def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
         if not self._rows_only_compare:
             compare_rows = self._get_compare_rows()
@@ -476,7 +472,7 @@ def _merge_dataframes(self) -> None:
             self._join_column_names  # type: ignore[arg-type]
         )
 
-    def _get_or_create_joined_dataframe(self) -> "SparkDataFrame":
+    def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
         if self._joined_dataframe is None:
             join_condition = " AND ".join(
                 ["A." + name + "<=>B." + name for name in self._join_column_names]
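
A minimal sketch of the optional-import pattern the diff above settles on, assuming only that pyspark may be absent at import time; row_count is a hypothetical helper (not part of datacompy) used to show why fully qualified string annotations make the SparkSessionType/SparkDataFrame aliases unnecessary:

try:
    import pyspark  # noqa: F401  # resolved lazily by the annotations below
except ImportError:
    pass  # pandas-only users can still import the module

def row_count(df: "pyspark.sql.DataFrame") -> int:
    # A string (forward-reference) annotation is stored, not evaluated, at
    # import time, so no module-level alias has to exist when pyspark is
    # missing; the name is only looked up if something introspects the
    # hints, e.g. via typing.get_type_hints().
    return df.count()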