🐛 Ensure compatibility with Python 3.8
Co-authored-by: Yoann Isaac <yoann.isaac@vidal.fr>
aguiddir and Jin66 committed Dec 7, 2023
1 parent 0b33f07 commit 166f7b5
Showing 4 changed files with 62 additions and 58 deletions.
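The changes below all follow one pattern: type hints written with PEP 604 unions (str | None, new in Python 3.10) and PEP 585 builtin generics (list[str], dict[str, Any], new in Python 3.9) are swapped for their typing-module equivalents (Optional, Union, List, Dict), which evaluate without error on Python 3.8. A minimal sketch of the pattern, using a hypothetical function rather than datacompy code:

from typing import Any, Dict, List, Optional, Union

# Not valid at runtime on Python 3.8:
#   def load(columns: list[str] | str | None = None) -> dict[str, Any]: ...
def load(columns: Optional[Union[List[str], str]] = None) -> Dict[str, Any]:
    """Hypothetical helper showing the Python 3.8 compatible spelling."""
    return {"columns": columns}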
18 changes: 9 additions & 9 deletions datacompy/core.py
@@ -22,7 +22,7 @@
"""
import logging
import os
from typing import cast, Any
from typing import cast, Any, List, Dict, Union, Optional

import numpy as np
import pandas as pd
@@ -81,7 +81,7 @@ def __init__(
self,
df1: pd.DataFrame,
df2: pd.DataFrame,
join_columns: list[str] | str | None = None,
join_columns: Optional[Union[List[str], str]] = None,
on_index: bool = False,
abs_tol: float = 0,
rel_tol: float = 0,
@@ -107,7 +107,7 @@ def __init__(
else:
self.join_columns = [
str(col).lower() if self.cast_column_names_lower else str(col)
for col in cast(list[str], join_columns)
for col in cast(List[str], join_columns)
]
self.on_index = False

@@ -123,7 +123,7 @@ def __init__(
self.df1_unq_rows: pd.DataFrame
self.df2_unq_rows: pd.DataFrame
self.intersect_rows: pd.DataFrame
self.column_stats: list[dict[str, Any]] = []
self.column_stats: List[Dict[str, Any]] = []
self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case)

@property
@@ -243,7 +243,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
If ``on_index`` is True, this will join on index values, otherwise it
will join on the ``join_columns``.
"""
params: dict[str, Any]
params: Dict[str, Any]
index_column: str
LOG.debug("Outer joining")
if self._any_dupes:
@@ -533,7 +533,7 @@ def report(
self,
sample_count: int = 10,
column_count: int = 10,
html_file: str | None = None,
html_file: Optional[str] = None,
) -> str:
"""Returns a string representation of a report. The representation can
then be printed or saved to a file.
@@ -696,7 +696,7 @@ def df_to_str(pdf: pd.DataFrame) -> str:
return report


def render(filename: str, *fields: int | float | str) -> str:
def render(filename: str, *fields: Union[int, float, str]) -> str:
"""Renders out an individual template. This basically just reads in a
template file, and applies ``.format()`` on the fields.
Expand Down Expand Up @@ -844,7 +844,7 @@ def compare_string_and_date_columns(

def get_merged_columns(
original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str
) -> list[str]:
) -> List[str]:
"""Gets the columns from an original dataframe, in the new merged dataframe
Parameters
Expand Down Expand Up @@ -915,7 +915,7 @@ def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> floa


def generate_id_within_group(
dataframe: pd.DataFrame, join_columns: list[str]
dataframe: pd.DataFrame, join_columns: List[str]
) -> "pd.Series[int]":
"""Generate an ID column that can be used to deduplicate identical rows. The series generated
is the order within a unique group, and it handles nulls.
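One case in core.py worth calling out: the first argument to typing.cast() is an ordinary runtime expression, not an annotation, so it cannot be deferred with a "from __future__ import annotations" import; on Python 3.8, subscripting the builtin list there raises TypeError immediately, which is why the comprehension now uses cast(List[str], ...). A small illustrative sketch (the variable names here are hypothetical, not datacompy's):

from typing import List, cast

join_columns = ("account_id", "txn_date")  # hypothetical input

# On Python 3.8 the commented-out call fails with
# "TypeError: 'type' object is not subscriptable", because list[str]
# subscripts the builtin list at runtime (allowed only from 3.9 on):
#   columns = cast(list[str], list(join_columns))
columns = cast(List[str], list(join_columns))  # 3.8-compatible
print(columns)  # ['account_id', 'txn_date']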
34 changes: 17 additions & 17 deletions datacompy/fugue.py
@@ -20,7 +20,7 @@
import logging
import pickle
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, cast, Union, Optional
from typing import Any, Callable, Dict, Iterable, List, cast, Union, Optional, Tuple

import fugue.api as fa
import pandas as pd
@@ -99,15 +99,15 @@ def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame) -> bool:
def is_match(
df1: AnyDataFrame,
df2: AnyDataFrame,
join_columns: str | list[str],
join_columns: Union[str, List[str]],
abs_tol: float = 0,
rel_tol: float = 0,
df1_name: str = "df1",
df2_name: str = "df2",
ignore_spaces: bool = False,
ignore_case: bool = False,
cast_column_names_lower: bool = True,
parallelism: int | None = None,
parallelism: Optional[int] = None,
strict_schema: bool = False,
) -> bool:
"""Check whether two dataframes match.
@@ -198,15 +198,15 @@ def is_match(
def all_rows_overlap(
df1: AnyDataFrame,
df2: AnyDataFrame,
join_columns: str | list[str],
join_columns: Union[str, List[str]],
abs_tol: float = 0,
rel_tol: float = 0,
df1_name: str = "df1",
df2_name: str = "df2",
ignore_spaces: bool = False,
ignore_case: bool = False,
cast_column_names_lower: bool = True,
parallelism: int | None = None,
parallelism: Optional[int] = None,
strict_schema: bool = False,
) -> bool:
"""Check if the rows are all present in both dataframes
@@ -294,7 +294,7 @@ def all_rows_overlap(
def report(
df1: AnyDataFrame,
df2: AnyDataFrame,
join_columns: str | list[str],
join_columns: Union[str, List[str]],
abs_tol: float = 0,
rel_tol: float = 0,
df1_name: str = "df1",
@@ -304,8 +304,8 @@ def report(
cast_column_names_lower: bool = True,
sample_count: int = 10,
column_count: int = 10,
html_file: str | None = None,
parallelism: int | None = None,
html_file: Optional[str] = None,
parallelism: Optional[int] = None,
) -> str:
"""Returns a string representation of a report. The representation can
then be printed or saved to a file.
@@ -320,7 +320,7 @@ def report(
First dataframe to check
df2 : ``AnyDataFrame``
Second dataframe to check
join_columns : list or str, optional
join_columns : list or str
Column(s) to join dataframes on. If a string is passed in, that one
column will be used.
abs_tol : float, optional
@@ -454,8 +454,8 @@ def _any(col: str) -> int:
"Yes" if _any("_any_dupes") else "No",
)

column_stats: list[dict[str, Any]]
match_sample: list[pd.DataFrame]
column_stats: List[Dict[str, Any]]
match_sample: List[pd.DataFrame]
column_stats, match_sample = _aggregate_stats(res, sample_count=sample_count)
any_mismatch = len(match_sample) > 0

@@ -671,7 +671,7 @@ def _serialize(dfs: Iterable[pd.DataFrame], left: bool) -> Iterable[Dict[str, An
)

def _deserialize(
df: list[dict[str, Any]], left: bool, schema: Schema
df: List[Dict[str, Any]], left: bool, schema: Schema
) -> pd.DataFrame:
arr = [pickle.loads(r["data"]) for r in df if r["left"] == left]
if len(arr) > 0:
@@ -682,7 +682,7 @@ def _deserialize(
# The following is how to construct an empty pandas dataframe with
# the correct schema, it avoids pandas schema inference which is wrong.
# This is not needed when upgrading to Fugue >= 0.8.7
sample_row: list[Any] = []
sample_row: List[Any] = []
for field in schema.fields:
if pa.types.is_string(field.type):
sample_row.append("x")
@@ -728,7 +728,7 @@ def _comp(df: List[Dict[str, Any]]) -> List[List[Any]]:

def _get_compare_result(
compare: Compare, sample_count: int, column_count: int
) -> dict[str, Any]:
) -> Dict[str, Any]:
mismatch_samples: Dict[str, pd.DataFrame] = {}
for column in compare.column_stats:
if not column["all_match"]:
@@ -777,8 +777,8 @@ def _get_compare_result(


def _aggregate_stats(
compares: list[Any], sample_count: int
) -> tuple[list[dict[str, Any]], list[pd.DataFrame]]:
compares: List[Any], sample_count: int
) -> Tuple[List[Dict[str, Any]], List[pd.DataFrame]]:
samples = defaultdict(list)
stats = []
for compare in compares:
@@ -804,7 +804,7 @@ def _aggregate_stats(
.reset_index(drop=False)
)
return cast(
tuple[list[dict[str, Any]], list[pd.DataFrame]],
Tuple[List[Dict[str, Any]], List[pd.DataFrame]],
(
df.to_dict(orient="records"),
[
64 changes: 32 additions & 32 deletions datacompy/spark.py
@@ -16,7 +16,7 @@
import sys
from enum import Enum
from itertools import chain
from typing import Any, TextIO, TypeAlias
from typing import Any, TextIO, NewType, List, Union, Tuple, Optional, Dict, Set

import pyspark

@@ -25,7 +25,7 @@
except ImportError:
pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality

SparkDataFrame: TypeAlias = pyspark.sql.DataFrame
SparkDataFrame = pyspark.sql.DataFrame


class MatchType(Enum):
@@ -146,12 +146,12 @@ class SparkCompare:
def __init__(
self,
spark_session: pyspark.sql.SparkSession,
base_df: pyspark.sql.DataFrame,
compare_df: pyspark.sql.DataFrame,
join_columns: list[str | tuple[str, str]],
column_mapping: list[tuple[str, str]] | None = None,
base_df: SparkDataFrame,
compare_df: SparkDataFrame,
join_columns: List[Union[str, Tuple[str, str]]],
column_mapping: Optional[List[Tuple[str, str]]] = None,
cache_intermediates: bool = False,
known_differences: list[dict[str, Any]] | None = None,
known_differences: Optional[List[Dict[str, Any]]] = None,
rel_tol: float = 0,
abs_tol: float = 0,
show_all_columns: bool = False,
@@ -186,15 +186,15 @@ def __init__(

self.spark = spark_session
self.base_unq_rows = self.compare_unq_rows = None
self._base_row_count: int | None = None
self._compare_row_count: int | None = None
self._common_row_count: int | None = None
self._joined_dataframe: SparkDataFrame | None = None
self._rows_only_base: SparkDataFrame | None = None
self._rows_only_compare: SparkDataFrame | None = None
self._all_matched_rows: SparkDataFrame | None = None
self._all_rows_mismatched: SparkDataFrame | None = None
self.columns_match_dict: dict[str, Any] = {}
self._base_row_count: Optional[int] = None
self._compare_row_count: Optional[int] = None
self._common_row_count: Optional[int] = None
self._joined_dataframe: Optional[SparkDataFrame] = None
self._rows_only_base: Optional[SparkDataFrame] = None
self._rows_only_compare: Optional[SparkDataFrame] = None
self._all_matched_rows: Optional[SparkDataFrame] = None
self._all_rows_mismatched: Optional[SparkDataFrame] = None
self.columns_match_dict: Dict[str, Any] = {}

# drop the duplicates before actual comparison made.
self.base_df = base_df.dropDuplicates(self._join_column_names)
@@ -207,9 +207,9 @@ def __init__(
self._compare_row_count = self.compare_df.count()

def _tuplizer(
self, input_list: list[str | tuple[str, str]]
) -> list[tuple[str, str]]:
join_columns: list[tuple[str, str]] = []
self, input_list: List[Union[str, Tuple[str, str]]]
) -> List[Tuple[str, str]]:
join_columns: List[Tuple[str, str]] = []
for val in input_list:
if isinstance(val, str):
join_columns.append((val, val))
@@ -219,12 +219,12 @@ def _tuplizer(
return join_columns

@property
def columns_in_both(self) -> set[str]:
def columns_in_both(self) -> Set[str]:
"""set[str]: Get columns in both dataframes"""
return set(self.base_df.columns) & set(self.compare_df.columns)

@property
def columns_compared(self) -> list[str]:
def columns_compared(self) -> List[str]:
"""list[str]: Get columns to be compared in both dataframes (all
columns in both excluding the join key(s)"""
return [
@@ -234,12 +234,12 @@ def columns_compared(self) -> list[str]:
]

@property
def columns_only_base(self) -> set[str]:
def columns_only_base(self) -> Set[str]:
"""set[str]: Get columns that are unique to the base dataframe"""
return set(self.base_df.columns) - set(self.compare_df.columns)

@property
def columns_only_compare(self) -> set[str]:
def columns_only_compare(self) -> Set[str]:
"""set[str]: Get columns that are unique to the compare dataframe"""
return set(self.compare_df.columns) - set(self.base_df.columns)

@@ -268,13 +268,13 @@ def common_row_count(self) -> int:

return self._common_row_count

def _get_unq_base_rows(self) -> pyspark.sql.DataFrame:
def _get_unq_base_rows(self) -> SparkDataFrame:
"""Get the rows only from base data frame"""
return self.base_df.select(self._join_column_names).subtract(
self.compare_df.select(self._join_column_names)
)

def _get_compare_rows(self) -> pyspark.sql.DataFrame:
def _get_compare_rows(self) -> SparkDataFrame:
"""Get the rows only from compare data frame"""
return self.compare_df.select(self._join_column_names).subtract(
self.base_df.select(self._join_column_names)
@@ -329,7 +329,7 @@ def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
col_type = df.select(column).dtypes[0][1]
print((format_pattern + " {:13s}").format(column, col_type), file=myfile)

def _columns_with_matching_schema(self) -> dict[str, str]:
def _columns_with_matching_schema(self) -> Dict[str, str]:
"""This function will identify the columns which has matching schema"""
col_schema_match = {}
base_columns_dict = dict(self.base_df.dtypes)
@@ -343,7 +343,7 @@ def _columns_with_matching_schema(self) -> dict[str, str]:

return col_schema_match

def _columns_with_schemadiff(self) -> dict[str, dict[str, str]]:
def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
"""This function will identify the columns which has different schema"""
col_schema_diff = {}
base_columns_dict = dict(self.base_df.dtypes)
@@ -363,23 +363,23 @@ def _columns_with_schemadiff(self) -> dict[str, dict[str, str]]:
return col_schema_diff

@property
def rows_both_mismatch(self) -> SparkDataFrame | None:
def rows_both_mismatch(self) -> Optional[SparkDataFrame]:
"""pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
if self._all_rows_mismatched is None:
self._merge_dataframes()

return self._all_rows_mismatched

@property
def rows_both_all(self) -> SparkDataFrame | None:
def rows_both_all(self) -> Optional[SparkDataFrame]:
"""pyspark.sql.DataFrame: Returns all rows in both dataframes"""
if self._all_matched_rows is None:
self._merge_dataframes()

return self._all_matched_rows

@property
def rows_only_base(self) -> SparkDataFrame | None:
def rows_only_base(self) -> SparkDataFrame:
"""pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
if not self._rows_only_base:
base_rows = self._get_unq_base_rows()
@@ -399,7 +399,7 @@ def rows_only_base(self) -> SparkDataFrame | None:
return self._rows_only_base

@property
def rows_only_compare(self) -> SparkDataFrame | None:
def rows_only_compare(self) -> Optional[SparkDataFrame]:
"""pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
if not self._rows_only_compare:
compare_rows = self._get_compare_rows()
@@ -475,7 +475,7 @@ def _merge_dataframes(self) -> None:
self._join_column_names # type: ignore[arg-type]
)

def _get_or_create_joined_dataframe(self) -> pyspark.sql.DataFrame:
def _get_or_create_joined_dataframe(self) -> SparkDataFrame:
if self._joined_dataframe is None:
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
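In spark.py the fix also touches the module-level alias: typing.TypeAlias only exists from Python 3.10 (PEP 613), so even importing it fails on 3.8, and a plain assignment produces an equivalent alias. A short sketch of that change; the row_count helper below is hypothetical, not part of datacompy:

import pyspark.sql  # make pyspark.sql.DataFrame available

# Python 3.10+ spelling removed by this commit:
#   from typing import TypeAlias
#   SparkDataFrame: TypeAlias = pyspark.sql.DataFrame
SparkDataFrame = pyspark.sql.DataFrame  # plain assignment works on Python 3.8


def row_count(df: SparkDataFrame) -> int:
    """Hypothetical helper annotated with the alias."""
    return df.count()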
[Diff for the remaining changed file not loaded in this view.]
