Skip to content

Commit

Permalink
✨ Add mypy for static type checking and type annotations (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
aguiddir authored Dec 8, 2023
1 parent 3b078e6 commit 2ed5d51
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 142 deletions.
156 changes: 94 additions & 62 deletions datacompy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
two dataframes.
"""

import logging
import os
from typing import cast, Any, List, Dict, Union, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -79,18 +79,18 @@ class Compare:

def __init__(
self,
df1,
df2,
join_columns=None,
on_index=False,
abs_tol=0,
rel_tol=0,
df1_name="df1",
df2_name="df2",
ignore_spaces=False,
ignore_case=False,
cast_column_names_lower=True,
):
df1: pd.DataFrame,
df2: pd.DataFrame,
join_columns: Optional[Union[List[str], str]] = None,
on_index: bool = False,
abs_tol: float = 0,
rel_tol: float = 0,
df1_name: str = "df1",
df2_name: str = "df2",
ignore_spaces: bool = False,
ignore_case: bool = False,
cast_column_names_lower: bool = True,
) -> None:
self.cast_column_names_lower = cast_column_names_lower
if on_index and join_columns is not None:
raise Exception("Only provide on_index or join_columns")
Expand All @@ -107,11 +107,11 @@ def __init__(
else:
self.join_columns = [
str(col).lower() if self.cast_column_names_lower else str(col)
for col in join_columns
for col in cast(List[str], join_columns)
]
self.on_index = False

self._any_dupes = False
self._any_dupes: bool = False
self.df1 = df1
self.df2 = df2
self.df1_name = df1_name
Expand All @@ -120,35 +120,39 @@ def __init__(
self.rel_tol = rel_tol
self.ignore_spaces = ignore_spaces
self.ignore_case = ignore_case
self.df1_unq_rows = self.df2_unq_rows = self.intersect_rows = None
self.column_stats = []
self._compare(ignore_spaces, ignore_case)
self.df1_unq_rows: pd.DataFrame
self.df2_unq_rows: pd.DataFrame
self.intersect_rows: pd.DataFrame
self.column_stats: List[Dict[str, Any]] = []
self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case)

@property
def df1(self):
def df1(self) -> pd.DataFrame:
return self._df1

@df1.setter
def df1(self, df1):
def df1(self, df1: pd.DataFrame) -> None:
"""Check that it is a dataframe and has the join columns"""
self._df1 = df1
self._validate_dataframe(
"df1", cast_column_names_lower=self.cast_column_names_lower
)

@property
def df2(self):
def df2(self) -> pd.DataFrame:
return self._df2

@df2.setter
def df2(self, df2):
def df2(self, df2: pd.DataFrame) -> None:
"""Check that it is a dataframe and has the join columns"""
self._df2 = df2
self._validate_dataframe(
"df2", cast_column_names_lower=self.cast_column_names_lower
)

def _validate_dataframe(self, index, cast_column_names_lower=True):
def _validate_dataframe(
self, index: str, cast_column_names_lower: bool = True
) -> None:
"""Check that it is a dataframe and has the join columns
Parameters
Expand All @@ -163,9 +167,11 @@ def _validate_dataframe(self, index, cast_column_names_lower=True):
raise TypeError(f"{index} must be a pandas DataFrame")

if cast_column_names_lower:
dataframe.columns = [str(col).lower() for col in dataframe.columns]
dataframe.columns = pd.Index(
[str(col).lower() for col in dataframe.columns]
)
else:
dataframe.columns = [str(col) for col in dataframe.columns]
dataframe.columns = pd.Index([str(col) for col in dataframe.columns])
# Check if join_columns are present in the dataframe
if not set(self.join_columns).issubset(set(dataframe.columns)):
raise ValueError(f"{index} must have all columns from join_columns")
Expand All @@ -182,7 +188,7 @@ def _validate_dataframe(self, index, cast_column_names_lower=True):
):
self._any_dupes = True

def _compare(self, ignore_spaces, ignore_case):
def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
"""Actually run the comparison. This tries to run df1.equals(df2)
first so that if they're truly equal we can tell.
Expand Down Expand Up @@ -214,26 +220,31 @@ def _compare(self, ignore_spaces, ignore_case):
else:
LOG.info("df1 does not match df2")

def df1_unq_columns(self):
def df1_unq_columns(self) -> OrderedSet[str]:
"""Get columns that are unique to df1"""
return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
return cast(
OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
)

def df2_unq_columns(self):
def df2_unq_columns(self) -> OrderedSet[str]:
"""Get columns that are unique to df2"""
return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
return cast(
OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
)

def intersect_columns(self):
def intersect_columns(self) -> OrderedSet[str]:
"""Get columns that are shared between the two dataframes"""
return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)

def _dataframe_merge(self, ignore_spaces):
def _dataframe_merge(self, ignore_spaces: bool) -> None:
"""Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1
and df1 & df2
If ``on_index`` is True, this will join on index values, otherwise it
will join on the ``join_columns``.
"""

params: Dict[str, Any]
index_column: str
LOG.debug("Outer joining")
if self._any_dupes:
LOG.debug("Duplicate rows found, deduping by order of remaining fields")
Expand Down Expand Up @@ -275,11 +286,10 @@ def _dataframe_merge(self, ignore_spaces):
# Clean up temp columns for duplicate row matching
if self._any_dupes:
if self.on_index:
outer_join.index = outer_join[index_column]
outer_join.drop(index_column, axis=1, inplace=True)
outer_join.set_index(keys=index_column, drop=True, inplace=True)
self.df1.drop(index_column, axis=1, inplace=True)
self.df2.drop(index_column, axis=1, inplace=True)
outer_join.drop(order_column, axis=1, inplace=True)
outer_join.drop(labels=order_column, axis=1, inplace=True)
self.df1.drop(order_column, axis=1, inplace=True)
self.df2.drop(order_column, axis=1, inplace=True)

Expand All @@ -306,7 +316,7 @@ def _dataframe_merge(self, ignore_spaces):
f"Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}"
)

def _intersect_compare(self, ignore_spaces, ignore_case):
def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
"""Run the comparison on the intersect dataframe
This loops through all columns that are shared between df1 and df2, and
Expand All @@ -319,7 +329,7 @@ def _intersect_compare(self, ignore_spaces, ignore_case):
if column in self.join_columns:
match_cnt = row_cnt
col_match = ""
max_diff = 0
max_diff = 0.0
null_diff = 0
else:
col_1 = column + "_df1"
Expand Down Expand Up @@ -367,11 +377,11 @@ def _intersect_compare(self, ignore_spaces, ignore_case):
}
)

def all_columns_match(self):
def all_columns_match(self) -> bool:
"""Whether the columns all match in the dataframes"""
return self.df1_unq_columns() == self.df2_unq_columns() == set()

def all_rows_overlap(self):
def all_rows_overlap(self) -> bool:
"""Whether the rows are all present in both dataframes
Returns
Expand All @@ -382,7 +392,7 @@ def all_rows_overlap(self):
"""
return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0

def count_matching_rows(self):
def count_matching_rows(self) -> int:
"""Count the number of rows match (on overlapping fields)
Returns
Expand All @@ -396,12 +406,12 @@ def count_matching_rows(self):
match_columns.append(column + "_match")
return self.intersect_rows[match_columns].all(axis=1).sum()

def intersect_rows_match(self):
def intersect_rows_match(self) -> bool:
"""Check whether the intersect rows all match"""
actual_length = self.intersect_rows.shape[0]
return self.count_matching_rows() == actual_length

def matches(self, ignore_extra_columns=False):
def matches(self, ignore_extra_columns: bool = False) -> bool:
"""Return True or False if the dataframes match.
Parameters
Expand All @@ -418,7 +428,7 @@ def matches(self, ignore_extra_columns=False):
else:
return True

def subset(self):
def subset(self) -> bool:
"""Return True if dataframe 2 is a subset of dataframe 1.
Dataframe 2 is considered a subset if all of its columns are in
Expand All @@ -434,7 +444,9 @@ def subset(self):
else:
return True

def sample_mismatch(self, column, sample_count=10, for_display=False):
def sample_mismatch(
self, column: str, sample_count: int = 10, for_display: bool = False
) -> pd.DataFrame:
"""Returns a sample sub-dataframe which contains the identifying
columns, and df1 and df2 versions of the column.
Expand Down Expand Up @@ -463,13 +475,16 @@ def sample_mismatch(self, column, sample_count=10, for_display=False):
return_cols = self.join_columns + [column + "_df1", column + "_df2"]
to_return = sample[return_cols]
if for_display:
to_return.columns = self.join_columns + [
column + " (" + self.df1_name + ")",
column + " (" + self.df2_name + ")",
]
to_return.columns = pd.Index(
self.join_columns
+ [
column + " (" + self.df1_name + ")",
column + " (" + self.df2_name + ")",
]
)
return to_return

def all_mismatch(self, ignore_matching_cols=False):
def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
"""All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join
columns.
Expand Down Expand Up @@ -512,7 +527,12 @@ def all_mismatch(self, ignore_matching_cols=False):
mm_bool = self.intersect_rows[match_list].all(axis="columns")
return self.intersect_rows[~mm_bool][self.join_columns + return_list]

def report(self, sample_count=10, column_count=10, html_file=None):
def report(
self,
sample_count: int = 10,
column_count: int = 10,
html_file: Optional[str] = None,
) -> str:
"""Returns a string representation of a report. The representation can
then be printed or saved to a file.
Expand All @@ -533,7 +553,7 @@ def report(self, sample_count=10, column_count=10, html_file=None):
The report, formatted kinda nicely.
"""

def df_to_str(pdf):
def df_to_str(pdf: pd.DataFrame) -> str:
if not self.on_index:
pdf = pdf.reset_index(drop=True)
return pdf.to_string()
Expand Down Expand Up @@ -674,7 +694,7 @@ def df_to_str(pdf):
return report


def render(filename, *fields):
def render(filename: str, *fields: Union[int, float, str]) -> str:
"""Renders out an individual template. This basically just reads in a
template file, and applies ``.format()`` on the fields.
Expand All @@ -697,8 +717,13 @@ def render(filename, *fields):


def columns_equal(
col_1, col_2, rel_tol=0, abs_tol=0, ignore_spaces=False, ignore_case=False
):
col_1: "pd.Series[Any]",
col_2: "pd.Series[Any]",
rel_tol: float = 0,
abs_tol: float = 0,
ignore_spaces: bool = False,
ignore_case: bool = False,
) -> "pd.Series[bool]":
"""Compares two columns from a dataframe, returning a True/False series,
with the same index as column 1.
Expand Down Expand Up @@ -731,6 +756,7 @@ def columns_equal(
A series of Boolean values. True == the values match, False == the
values don't match.
"""
compare: pd.Series[bool]
try:
compare = pd.Series(
np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
Expand Down Expand Up @@ -773,7 +799,9 @@ def columns_equal(
return compare


def compare_string_and_date_columns(col_1, col_2):
def compare_string_and_date_columns(
col_1: "pd.Series[Any]", col_2: "pd.Series[Any]"
) -> "pd.Series[bool]":
"""Compare a string column and date column, value-wise. This tries to
convert a string column to a date column and compare that way.
Expand Down Expand Up @@ -812,7 +840,9 @@ def compare_string_and_date_columns(col_1, col_2):
return pd.Series(False, index=col_1.index)


def get_merged_columns(original_df, merged_df, suffix):
def get_merged_columns(
original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str
) -> List[str]:
"""Gets the columns from an original dataframe, in the new merged dataframe
Parameters
Expand All @@ -836,7 +866,7 @@ def get_merged_columns(original_df, merged_df, suffix):
return columns


def temp_column_name(*dataframes):
def temp_column_name(*dataframes: pd.DataFrame) -> str:
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
Expand All @@ -861,7 +891,7 @@ def temp_column_name(*dataframes):
return temp_column


def calculate_max_diff(col_1, col_2):
def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> float:
"""Get a maximum difference between two columns
Parameters
Expand All @@ -877,12 +907,14 @@ def calculate_max_diff(col_1, col_2):
Numeric field, or zero.
"""
try:
return (col_1.astype(float) - col_2.astype(float)).abs().max()
return cast(float, (col_1.astype(float) - col_2.astype(float)).abs().max())
except:
return 0
return 0.0


def generate_id_within_group(dataframe, join_columns):
def generate_id_within_group(
dataframe: pd.DataFrame, join_columns: List[str]
) -> "pd.Series[int]":
"""Generate an ID column that can be used to deduplicate identical rows. The series generated
is the order within a unique group, and it handles nulls.
Expand Down
Loading

0 comments on commit 2ed5d51

Please sign in to comment.