fixup! ✨ typing aliases
aguiddir committed Dec 7, 2023
1 parent 7bf5b0a commit 5b07c52
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions datacompy/spark.py
@@ -21,10 +21,6 @@
 try:
     import pyspark
     from pyspark.sql import functions as F
-
-    SparkSessionType = pyspark.sql.SparkSession
-    SparkDataFrame = pyspark.sql.DataFrame
-
 except ImportError:
     pass  # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality

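Aside on why the deletion above is safe (an editorial note, not part of the commit): the type hints in this file are string literals, which Python stores without evaluating, so the module still imports when pyspark is missing and the two aliases add nothing at runtime. A minimal sketch of the pattern, with a hypothetical head() helper that is not part of datacompy:

    try:
        import pyspark
    except ImportError:
        pass  # the name pyspark is undefined here, and that is fine

    def head(df: "pyspark.sql.DataFrame") -> None:
        # The interpreter keeps the annotation as a plain string; only a
        # type checker (with pyspark installed) resolves it to the class.
        print(df.take(1))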
@@ -146,9 +142,9 @@ class SparkCompare:

     def __init__(
         self,
-        spark_session: "SparkSessionType",
-        base_df: "SparkDataFrame",
-        compare_df: "SparkDataFrame",
+        spark_session: "pyspark.sql.SparkSession",
+        base_df: "pyspark.sql.DataFrame",
+        compare_df: "pyspark.sql.DataFrame",
         join_columns: List[Union[str, Tuple[str, str]]],
         column_mapping: Optional[List[Tuple[str, str]]] = None,
         cache_intermediates: bool = False,
@@ -190,11 +186,11 @@ def __init__(
         self._base_row_count: Optional[int] = None
         self._compare_row_count: Optional[int] = None
         self._common_row_count: Optional[int] = None
-        self._joined_dataframe: Optional["SparkDataFrame"] = None
-        self._rows_only_base: Optional["SparkDataFrame"] = None
-        self._rows_only_compare: Optional["SparkDataFrame"] = None
-        self._all_matched_rows: Optional["SparkDataFrame"] = None
-        self._all_rows_mismatched: Optional["SparkDataFrame"] = None
+        self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None
+        self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None
+        self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None
         self.columns_match_dict: Dict[str, Any] = {}

         # drop the duplicates before actual comparison made.
@@ -269,13 +265,13 @@ def common_row_count(self) -> int:

         return self._common_row_count

-    def _get_unq_base_rows(self) -> "SparkDataFrame":
+    def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from base data frame"""
         return self.base_df.select(self._join_column_names).subtract(
             self.compare_df.select(self._join_column_names)
         )

-    def _get_compare_rows(self) -> "SparkDataFrame":
+    def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from compare data frame"""
         return self.compare_df.select(self._join_column_names).subtract(
             self.base_df.select(self._join_column_names)
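For context, DataFrame.subtract used above behaves like SQL EXCEPT DISTINCT: it keeps rows present in the left frame but absent from the right. A toy sketch (the "id" column is invented for illustration):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    base = spark.createDataFrame([(1,), (2,), (3,)], ["id"])
    compare = spark.createDataFrame([(2,), (3,), (4,)], ["id"])

    # Rows only in base: id=1 (mirrors _get_unq_base_rows)
    base.select("id").subtract(compare.select("id")).show()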
@@ -364,23 +360,23 @@ def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
         return col_schema_diff

     @property
-    def rows_both_mismatch(self) -> Optional["SparkDataFrame"]:
+    def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
         if self._all_rows_mismatched is None:
             self._merge_dataframes()

         return self._all_rows_mismatched

     @property
-    def rows_both_all(self) -> Optional["SparkDataFrame"]:
+    def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
         if self._all_matched_rows is None:
             self._merge_dataframes()

         return self._all_matched_rows

     @property
-    def rows_only_base(self) -> "SparkDataFrame":
+    def rows_only_base(self) -> "pyspark.sql.DataFrame":
         """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
         if not self._rows_only_base:
             base_rows = self._get_unq_base_rows()
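The properties above share one memoized pattern: the first access runs _merge_dataframes() and caches the frame; later accesses return the cache. A stripped-down illustration with invented names:

    from typing import Optional

    class LazyReport:
        def __init__(self) -> None:
            self._summary: Optional[str] = None

        @property
        def summary(self) -> str:
            if self._summary is None:  # compute only on first access
                self._summary = "expensive result"
            return self._summary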
@@ -400,7 +396,7 @@ def rows_only_base(self) -> "SparkDataFrame":
         return self._rows_only_base

     @property
-    def rows_only_compare(self) -> Optional["SparkDataFrame"]:
+    def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
         if not self._rows_only_compare:
             compare_rows = self._get_compare_rows()
@@ -476,7 +472,7 @@ def _merge_dataframes(self) -> None:
             self._join_column_names  # type: ignore[arg-type]
         )

-    def _get_or_create_joined_dataframe(self) -> "SparkDataFrame":
+    def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
         if self._joined_dataframe is None:
             join_condition = " AND ".join(
                 ["A." + name + "<=>B." + name for name in self._join_column_names]

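A note on the join condition built in the last hunk: <=> is Spark SQL's null-safe equality operator, which treats NULL <=> NULL as true, so rows whose join keys are NULL still pair up across the two frames (plain = would silently drop them). A self-contained example, with a made-up "id" key:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.createDataFrame([(None, "x"), (1, "y")], "id int, val string").createOrReplaceTempView("A")
    spark.createDataFrame([(None, "X"), (1, "Y")], "id int, val string").createOrReplaceTempView("B")

    # Both rows match, including the NULL keys
    spark.sql("SELECT A.id, A.val, B.val FROM A JOIN B ON A.id <=> B.id").show()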