Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Mar 12, 2024
1 parent b7a811d commit 73785bb
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 6 deletions.
61 changes: 58 additions & 3 deletions datacompy/fsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,25 @@ class SchemaCompareResult:
left_only: List[str]
right_only: List[str]

def are_equal(self, check_column_order: bool = True) -> bool:
    """Return True when both schemas hold exactly the same columns.

    :param check_column_order: when True, the columns must also appear
        in the same order on both sides.
    """
    # Any one-sided column means the schemas cannot match.
    has_extra = self.left_only or self.right_only
    if has_extra:
        return False
    if not check_column_order:
        return True
    # Same column set; equality now hinges on ordering only.
    return self.schema1.names == self.schema2.names

def get_stats(self) -> Dict[str, Any]:
    """Summarize the schema comparison as a flat, serializable dict.

    Column lists are rendered as comma-joined strings; ``common_col_count``
    counts the shared value (non-join) columns.
    """
    joined = ",".join
    stats: Dict[str, Any] = {}
    stats["left_schema"] = str(self.schema1)
    stats["right_schema"] = str(self.schema2)
    stats["intersect_cols"] = joined(self.intersect_cols)
    stats["join_cols"] = joined(self.join_cols)
    stats["value_cols"] = joined(self.value_cols)
    stats["common_col_count"] = len(self.value_cols)
    stats["left_only_cols"] = joined(self.left_only)
    stats["right_only_cols"] = joined(self.right_only)
    return stats

def is_floating(self, name: str) -> bool:
tp = self.schema1[name].type
tp2 = self.schema2[name].type
Expand Down Expand Up @@ -234,11 +253,38 @@ class CompareResult:
df1_samples: Optional[pd.DataFrame]
df2_samples: Optional[pd.DataFrame]

def are_equal(self, check_column_order: bool = True) -> bool:
    """Return True when schemas match and no common row has any diff.

    :param check_column_order: forwarded to the schema comparison.
    """
    schemas_match = self.schema_compare.are_equal(
        check_column_order=check_column_order
    )
    if not schemas_match:
        return False
    summary = self.get_diff_summary()
    # Equal only if there is neither a value diff nor a null-mismatch diff.
    no_value_diff = summary["diff"].sum() == 0
    no_null_diff = summary["null_diff"].sum() == 0
    return no_value_diff and no_null_diff

def get_stats(self) -> Dict[str, Any]:
    """Combine row-level counts with the schema stats into one flat dict.

    Side flags in ``get_row_counts``: 1 = left only, 2 = right only,
    3 = present on both sides.
    """
    schema_stats = self.schema_compare.get_stats()
    counts = self.get_row_counts()
    common_total = counts.get(3, 0)
    common_diff = self.get_common_rows_diff_count()
    stats: Dict[str, Any] = {
        "left_only_row_count": counts.get(1, 0),
        "right_only_row_count": counts.get(2, 0),
        "common_row_count": common_total,
        "common_row_diff_count": common_diff,
        # Rows present on both sides minus the ones that differ.
        "common_row_equal_count": common_total - common_diff,
    }
    stats.update(schema_stats)
    return stats

def get_row_counts(self) -> Dict[int, int]:
return (
self.raw_diff_summary.groupby(_SIDE_FLAG)[_TOTAL_COUNT_COL].sum().to_dict()
)

def get_common_rows_diff_count(self) -> int:
    """Count rows present on both sides (flag 3) whose values differ."""
    on_both = self.raw_diff_summary[self.raw_diff_summary[_SIDE_FLAG] == 3]
    # No common rows at all -> nothing can differ.
    return 0 if on_both.empty else int(on_both[_ROW_DIFF_FLAG].sum())

def get_diff_summary(self) -> pd.DataFrame:
res: List[Dict[str, Any]] = []
df = self.raw_diff_summary[self.raw_diff_summary[_SIDE_FLAG] == 3]
Expand Down Expand Up @@ -289,7 +335,7 @@ def __init__(
abs_tol: float = 0,
rel_tol: float = 0,
) -> None:
assert rel_tol == 0, "Relative tolerance is not supported"
assert rel_tol >= 0, "Relative tolerance must be non-negative"
assert abs_tol >= 0, "Absolute tolerance must be non-negative"
self.abs_tol = abs_tol
self.rel_tol = rel_tol
Expand Down Expand Up @@ -429,8 +475,17 @@ def _gen_col_eq(self, name: str) -> Iterable[str]:
_no_null = f"({_fa} IS NOT NULL AND {_fb} IS NOT NULL)"
if pa.types.is_string(tp):
c = f"{_fa} = {_fb}"
elif self.abs_tol > 0 and is_floating:
c = f"({_fa} - {_fb} <= {self.abs_tol} AND {_fa} - {_fb} >= -1 * {self.abs_tol})"
elif (self.abs_tol > 0 or self.rel_tol > 0) and is_floating:
# https://numpy.org/doc/stable/reference/generated/numpy.isclose.html
# absolute(a - b) <= (atol + rtol * absolute(b))
# c = f"ABS({_fa}-{_fb}) <= {self.abs_tol}+{_fb}*ABS({self.rel_tol})"
a_b_sq = f"({_fa}-{_fb})*({_fa}-{_fb})"
diff_pos = f"({self.abs_tol}+{_fb}*{self.rel_tol})"
diff_neg = f"({self.abs_tol}-{_fb}*{self.rel_tol})"
c = (
f"CASE WHEN {_fb}>0 THEN {a_b_sq} < {diff_pos}*{diff_pos} "
f"ELSE {a_b_sq} < {diff_neg}*{diff_neg} END"
)
else:
c = f"{_fa} = {_fb}"
diff_col = _quote_name(_DIFF_PREFIX + name)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ docs = [
tests = [
"pytest",
"pytest-cov",
"pytest-benchmark",
"fugue[cpp_sql_parser]==0.9.0.dev2",
]

tests-spark = [
Expand Down
3 changes: 2 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ spark_options =
spark.default.parallelism: 4
spark.executor.cores: 4
spark.sql.execution.arrow.pyspark.enabled: true
spark.sql.adaptive.enabled: false
spark.sql.adaptive.enabled: false
addopts = --ignore-glob=tests/test_perf*
16 changes: 14 additions & 2 deletions tests/test_fugue/test_fsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def test_compare_schemas() -> None:
s1 = Schema("a:int,b:str")
s2 = Schema("b:str,a:int")
comp = compare_schemas(s1, s2, "a")
assert not comp.are_equal()
assert comp.are_equal(check_column_order=False)
assert comp.intersect_cols == ["a", "b"]
assert comp.left_only == []
assert comp.right_only == []
Expand All @@ -84,6 +86,8 @@ def test_compare_schemas() -> None:
s1 = Schema("a:long,b:int,d:double")
s2 = Schema("c:str,a:int16,b:double")
comp = compare_schemas(s1, s2, "a")
assert not comp.are_equal()
assert not comp.are_equal(check_column_order=False)
assert comp.intersect_cols == ["a", "b"]
assert comp.left_only == ["d"]
assert comp.right_only == ["c"]
Expand Down Expand Up @@ -152,6 +156,7 @@ def test_same_data(self) -> None:
assert len(res.get_unique_samples(1)) == 0
assert len(res.get_unique_samples(2)) == 0
assert len(res.get_unequal_samples()[0]) == 0
assert res.are_equal()

def test_overlap(self) -> None:
df1 = self.to_df(
Expand Down Expand Up @@ -195,8 +200,8 @@ def test_overlap_with_close_numbers(self) -> None:
],
"id:int,a:double",
)
for abs_tol in [0, 0.01]:
res = compare(df1, df2, "id", abs_tol=abs_tol)
for rel_tol, abs_tol in [(0, 0), (0.005, 0.005)]:
res = compare(df1, df2, "id", abs_tol=abs_tol, rel_tol=rel_tol)
assert res.get_row_counts() == {3: 2}
assert len(res.get_unequal_samples()[0]) == 2
assert len(res.get_unequal_samples()[1]) == 2
Expand All @@ -208,6 +213,12 @@ def test_overlap_with_close_numbers(self) -> None:
assert len(res.get_unequal_samples()[1]) == 0
diff = res.get_diff_summary().groupby("column")["max_diff"].sum().to_dict()
assert abs(diff["a"] - 0.02) < 1e-5
res = compare(df1, df2, "id", rel_tol=0.03)
assert res.get_row_counts() == {3: 2}
assert len(res.get_unequal_samples()[0]) == 0
assert len(res.get_unequal_samples()[1]) == 0
diff = res.get_diff_summary().groupby("column")["max_diff"].sum().to_dict()
assert abs(diff["a"] - 0.02) < 1e-5


@ft.fugue_test_suite("pandas", mark_test=True)
Expand All @@ -226,5 +237,6 @@ class DuckDBCompareTests(CompareTests):
@ft.fugue_test_suite("ray", mark_test=True)
class RayCompareTests(CompareTests):
pass

except ImportError:
pass
Empty file added tests/test_perf/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/test_perf/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pandas as pd
import numpy as np
import os
from typing import Tuple


def generate_dfs(size: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
np.random.seed(0)

base = pd.DataFrame(
dict(
id=np.arange(0, size),
a=np.random.randint(0, 10, size),
b=np.random.rand(size),
c=np.random.choice(["aaa", "bbb", "ccc"], size),
d=np.random.randint(0, 10, size),
e=np.random.rand(size),
f=np.random.choice(["aaa", "bbb", "ccc"], size),
g=np.random.randint(0, 10, size),
h=np.random.rand(size),
i=np.random.choice(["aaa", "bbb", "ccc"], size),
)
)

compare = pd.DataFrame(
dict(
id=np.arange(0, size),
d=np.random.randint(0, 10, size),
e=np.random.rand(size),
f=np.random.choice(["aaa", "bbb", "ccc"], size),
g=np.random.randint(0, 10, size),
h=np.random.rand(size),
i=np.random.choice(["aaa", "bbb", "ccc"], size),
j=np.random.randint(0, 10, size),
k=np.random.rand(size),
l=np.random.choice(["aaa", "bbb", "ccc"], size),
)
)

return base, compare


def generate_files(size: int, folder: str) -> Tuple[str, str]:
    """Materialize the generated pair as parquet files inside *folder*.

    :param size: number of rows per frame, forwarded to ``generate_dfs``.
    :param folder: existing directory to write into.
    :return: ``(base_path, compare_path)`` of the written parquet files.
    """
    frames = generate_dfs(size)
    paths = tuple(
        os.path.join(folder, name) for name in ("base.parquet", "compare.parquet")
    )
    for frame, path in zip(frames, paths):
        frame.to_parquet(path, index=False)
    return paths
42 changes: 42 additions & 0 deletions tests/test_perf/test_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from ._utils import generate_dfs
from datacompy import Compare
from datacompy.fsql import compare as fsql_compare
import pytest


@pytest.fixture(params=[1000, 100000])
def paired_dfs(benchmark, request):
    """Provide a (base, compare) DataFrame pair sized by the fixture param.

    Also labels the benchmark with the row count so reports distinguish runs.
    """
    size = request.param
    benchmark.name = "test - %s" % size
    return generate_dfs(size)


def v1_df_pandas_perf(base, compare):
    """Benchmark target: datacompy v1 ``Compare`` joined on ``id``, full report."""
    comparison = Compare(base, compare, ["id"])
    return comparison.report()


def v2_df_fsql_perf(base, compare):
    """Benchmark target: fsql-based compare joined on ``id``."""
    return fsql_compare(base, compare, "id")


def _test_v1_pandas_perf(benchmark, paired_dfs):
    """Benchmark the v1 pandas compare path.

    NOTE(review): the leading underscore keeps pytest from collecting this —
    presumably disabled on purpose (v1 is slow on the large param); confirm.
    """
    left, right = paired_dfs
    benchmark.pedantic(
        v1_df_pandas_perf,
        args=(left, right),
        iterations=1,
        rounds=10,
        warmup_rounds=2,
    )


def test_v2_df_fsql_perf(benchmark, paired_dfs):
    """Benchmark the v2 fsql compare path (2 warmup + 10 measured rounds)."""
    left, right = paired_dfs
    benchmark.pedantic(
        v2_df_fsql_perf,
        args=(left, right),
        iterations=1,
        rounds=10,
        warmup_rounds=2,
    )

0 comments on commit 73785bb

Please sign in to comment.