Skip to content

Commit

Permalink
fixing legacy unicode column names
Browse files Browse the repository at this point in the history
  • Loading branch information
Faisal Dosani committed Mar 18, 2024
1 parent 8a3d852 commit 3b8ddd0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
61 changes: 41 additions & 20 deletions datacompy/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality


warn(f"The module {__name__} is deprecated.", DeprecationWarning, stacklevel=2)
warn(
f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ",
DeprecationWarning,
stacklevel=2,
)


class MatchType(Enum):
Expand Down Expand Up @@ -387,7 +391,10 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
base_rows.createOrReplaceTempView("baseRows")
self.base_df.createOrReplaceTempView("baseTable")
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
[
"A.`" + name + "`<=>B.`" + name + "`"
for name in self._join_column_names
]
)
sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
join_condition
Expand All @@ -407,7 +414,10 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
compare_rows.createOrReplaceTempView("compareRows")
self.compare_df.createOrReplaceTempView("compareTable")
where_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
[
"A.`" + name + "`<=>B.`" + name + "`"
for name in self._join_column_names
]
)
sql_query = (
"select A.* from compareTable as A, compareRows as B where {}".format(
Expand Down Expand Up @@ -439,15 +449,23 @@ def _generate_select_statement(self, match_data: bool = True) -> str:
[self._create_select_statement(name=column_name)]
)
elif column_name in base_only:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(
["A.`" + column_name + "`"]
)

elif column_name in compare_only:
if match_data:
select_statement = select_statement + ",".join(["B." + column_name])
select_statement = select_statement + ",".join(
["B.`" + column_name + "`"]
)
else:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(
["A.`" + column_name + "`"]
)
elif column_name in self._join_column_names:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(
["A.`" + column_name + "`"]
)

if column_name != sorted_list[-1]:
select_statement = select_statement + " , "
Expand All @@ -469,7 +487,7 @@ def _merge_dataframes(self) -> None:
self._all_matched_rows.createOrReplaceTempView("matched_table")

where_cond = " OR ".join(
["A." + name + "_match= False" for name in self.columns_compared]
["A.`" + name + "_match`= False" for name in self.columns_compared]
)
mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
Expand All @@ -479,7 +497,10 @@ def _merge_dataframes(self) -> None:
def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
if self._joined_dataframe is None:
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
[
"A.`" + name + "`<=>B.`" + name + "`"
for name in self._join_column_names
]
)
select_statement = self._generate_select_statement(match_data=True)

Expand Down Expand Up @@ -570,22 +591,22 @@ def _create_select_statement(self, name: str) -> str:
match_type_comparison = ""
for k in MatchType:
match_type_comparison += (
" WHEN (A.{name}={match_value}) THEN '{match_name}'".format(
" WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
name=name, match_value=str(k.value), match_name=k.name
)
)
return "A.{name}_base, A.{name}_compare, (CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END) AS {name}_match, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS {name}_match_type ".format(
return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
name=name,
match_failure=MatchType.MISMATCH.value,
match_type_comparison=match_type_comparison,
)
else:
return "A.{name}_base, A.{name}_compare, CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END AS {name}_match ".format(
return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
name=name, match_failure=MatchType.MISMATCH.value
)

def _create_case_statement(self, name: str) -> str:
equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"]
equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
known_diff_comparisons = ["(FALSE)"]

base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
Expand All @@ -596,30 +617,30 @@ def _create_case_statement(self, name: str) -> str:
compare_dtype in NUMERIC_SPARK_TYPES
): # numeric tolerance comparison
equal_comparisons.append(
"((A.{name}=B.{name}) OR ((abs(A.{name}-B.{name}))<=("
"((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
+ str(self.abs_tol)
+ "+("
+ str(self.rel_tol)
+ "*abs(A.{name})))))"
+ "*abs(A.`{name}`)))))"
)
else: # non-numeric comparison
equal_comparisons.append("((A.{name}=B.{name}))")
equal_comparisons.append("((A.`{name}`=B.`{name}`))")

if self._known_differences:
new_input = "B.{name}"
new_input = "B.`{name}`"
for kd in self._known_differences:
if compare_dtype in kd["types"]:
if "flags" in kd and "nullcheck" in kd["flags"]:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") is null AND A.{name} is null)"
+ ") is null AND A.`{name}` is null)"
)
else:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") = A.{name})"
+ ") = A.`{name}`)"
)

case_string = (
Expand All @@ -628,7 +649,7 @@ def _create_case_statement(self, name: str) -> str:
+ ") THEN {match_success} WHEN ("
+ " OR ".join(known_diff_comparisons)
+ ") THEN {match_known_difference} ELSE {match_failure} END) "
+ "AS {name}, A.{name} AS {name}_base, B.{name} AS {name}_compare"
+ "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
)

return case_string.format(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_legacy_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2087,3 +2087,11 @@ def text_alignment_validator(

if not at_column_section and section_start in line:
at_column_section = True


def test_unicode_columns(spark_session):
    """Regression test: non-ASCII (unicode) column names must not break the
    legacy compare, which now backtick-quotes identifiers in generated SQL."""
    rows = [{"a": 1, "例": 2}, {"a": 1, "例": 3}]
    base_df = spark_session.createDataFrame(rows)
    compare_df = spark_session.createDataFrame(rows)
    comparison = LegacySparkCompare(
        spark_session, base_df, compare_df, join_columns=["例"]
    )
    # Rendering the report drives every generated SQL statement; it should
    # complete without raising on the unicode join column.
    comparison.report()

0 comments on commit 3b8ddd0

Please sign in to comment.