Skip to content

Commit

Permalink
Custom column validator for pdf2parquet (#577)
Browse files Browse the repository at this point in the history
* use @classmethod to allow overloading

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* custom validator for document types in pdf2parquet

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
  • Loading branch information
dolfim-ibm authored Sep 6, 2024
1 parent 51c8676 commit 9fd27e0
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def pytest_generate_tests(metafunc):
def _install_test_fixtures(self, metafunc):
raise NotImplemented("Sub-class must implemented this to install the fixtures for its tests.")

@staticmethod
def validate_expected_tables(table_list: list[pa.Table], expected_table_list: list[pa.Table]):
@classmethod
def validate_expected_tables(cls, table_list: list[pa.Table], expected_table_list: list[pa.Table]):
"""
Verify with assertion messages that the two lists of Tables are equivalent.
:param table_list:
Expand All @@ -100,10 +100,10 @@ def validate_expected_tables(table_list: list[pa.Table], expected_table_list: li
r1 = t1.take([j])
r2 = t2.take([j])
# assert r1 == r2, f"Row {j} of table {i} are not equal\n\tTransformed: {r1}\n\tExpected : {r2}"
AbstractTest.validate_expected_row(i, j, r1, r2)
cls.validate_expected_row(i, j, r1, r2)

@staticmethod
def validate_expected_row(table_index: int, row_index: int, test_row: pa.Table, expected_row: pa.Table):
@classmethod
def validate_expected_row(cls, table_index: int, row_index: int, test_row: pa.Table, expected_row: pa.Table):
"""
Compare the two rows for equality, allowing float values to be within a percentage
of each other as defined by global _allowed_float_percent_diff.
Expand Down Expand Up @@ -139,8 +139,8 @@ def validate_expected_row(table_index: int, row_index: int, test_row: pa.Table,
diff = abs(test_value - expected_value)
assert diff <= allowed_diff, msg

@staticmethod
def validate_expected_files(files_list: list[tuple[bytes, str]], expected_files_list: list[tuple[bytes, str]]):
@classmethod
def validate_expected_files(cls, files_list: list[tuple[bytes, str]], expected_files_list: list[tuple[bytes, str]]):
"""
Verify with assertion messages that the two lists of Tables are equivalent.
:param files_list:
Expand Down Expand Up @@ -171,15 +171,15 @@ def validate_expected_files(files_list: list[tuple[bytes, str]], expected_files_
diff <= diff_allowed
), f"produced file length {lenf1} vs expected {lenf2}, exceeds allowance of {diff_allowed}"

@staticmethod
def validate_expected_metadata_lists(metadata: list[dict[str, float]], expected_metadata: list[dict[str, float]]):
@classmethod
def validate_expected_metadata_lists(cls, metadata: list[dict[str, float]], expected_metadata: list[dict[str, float]]):
elen = len(expected_metadata)
assert len(metadata) == elen, f"Number of metadata dictionaries not the expected of {elen}"
for index in range(elen):
AbstractTest.validate_expected_metadata(metadata[index], expected_metadata[index])
cls.validate_expected_metadata(metadata[index], expected_metadata[index])

@staticmethod
def validate_expected_metadata(metadata: dict[str, float], expected_metadata: dict[str, float]):
@classmethod
def validate_expected_metadata(cls, metadata: dict[str, float], expected_metadata: dict[str, float]):
"""
Verify with assertion messages that the two dictionaries are as expected.
:param metadata:
Expand All @@ -194,8 +194,8 @@ def validate_expected_metadata(metadata: dict[str, float], expected_metadata: di
f"Metadata not equal\n" "\tTransformed: {metadata} Expected : {expected_metadata}"
)

@staticmethod
def validate_directory_contents(directory: str, expected_dir: str, drop_columns: list[str] = []):
@classmethod
def validate_directory_contents(cls, directory: str, expected_dir: str, drop_columns: list[str] = []):
"""
Make sure the directory contents are the same.
:param directory:
Expand All @@ -217,28 +217,28 @@ def validate_directory_contents(directory: str, expected_dir: str, drop_columns:
expected_diffs = 0
failed = len(dir_cmp.diff_files) != expected_diffs
if failed:
AbstractTest.__confirm_diffs(directory, expected_dir, dir_cmp.diff_files, "/tmp", drop_columns)
cls.__confirm_diffs(directory, expected_dir, dir_cmp.diff_files, "/tmp", drop_columns)

# Traverse into the subdirs since dircmp doesn't seem to do that.
subdirs = [f.name for f in os.scandir(expected_dir) if f.is_dir()]
for subdir in subdirs:
d1 = os.path.join(directory, subdir)
d2 = os.path.join(expected_dir, subdir)
AbstractTest.validate_directory_contents(d1, d2, drop_columns)
cls.validate_directory_contents(d1, d2, drop_columns)

@staticmethod
def _validate_table_files(parquet1: str, parquet2: str, drop_columns: list[str] = []):
@classmethod
def _validate_table_files(cls, parquet1: str, parquet2: str, drop_columns: list[str] = []):
da = DataAccessLocal()
t1, _ = da.get_table(parquet1)
t2, _ = da.get_table(parquet2)
if len(drop_columns) > 0:
t1 = t1.drop_columns(drop_columns)
t2 = t2.drop_columns(drop_columns)
AbstractTest.validate_expected_tables([t1], [t2])
cls.validate_expected_tables([t1], [t2])

@staticmethod
@classmethod
def __confirm_diffs(
src_dir: str, expected_dir: str, diff_files: list, dest_dir: str, drop_columns: list[str] = []
cls, src_dir: str, expected_dir: str, diff_files: list, dest_dir: str, drop_columns: list[str] = []
):
"""
Copy all files from the source dir to the dest dir.
Expand All @@ -256,7 +256,7 @@ def __confirm_diffs(
# It seems file can be different on disk, but contain the same column/row values.
# so for these, do the inmemory comparison.
try:
AbstractTest._validate_table_files(expected, src, drop_columns)
cls._validate_table_files(expected, src, drop_columns)
except AssertionError as e:
logger.info(f"Copying file with difference: {src} to {dest}")
shutil.copyfile(src, dest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _validate_directory_contents_match(self, dir: str, expected: str, ignore_col
Confirm that the two directories contains the same files.
Stubbed out like this to allow spark tests to override this since spark tends to rename the files.
"""
AbstractTest.validate_directory_contents(dir, expected, ignore_columns)
self.validate_directory_contents(dir, expected, ignore_columns)

def _install_test_fixtures(self, metafunc):
# Apply the fixtures for the method with these input names (i.e. test_transform()).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
import ast
import os

import pyarrow as pa
from data_processing.runtime.pure_python import PythonTransformLauncher
from data_processing.test_support.abstract_test import _allowed_float_percent_diff
from data_processing.test_support.launch.transform_test import (
AbstractTransformLauncherTest,
)
from docling_core.types import Document
from docling_core.types.doc.base import BaseText
from pdf2parquet_transform_python import Pdf2ParquetPythonTransformConfiguration
from pydantic import ValidationError


class TestPythonPdf2ParquetTransform(AbstractTransformLauncherTest):
Expand All @@ -35,7 +40,7 @@ def get_test_transform_fixtures(self) -> list[tuple]:
}

# this is added as a fixture to remove these columns from comparison
ignore_columns = ["date_acquired", "document_id", "pdf_convert_time"]
ignore_columns = ["date_acquired", "document_id", "pdf_convert_time", "hash"]

fixtures = []
launcher = PythonTransformLauncher(Pdf2ParquetPythonTransformConfiguration())
Expand Down Expand Up @@ -84,3 +89,100 @@ def get_test_transform_fixtures(self) -> list[tuple]:
)

return fixtures

@classmethod
def validate_expected_row(
cls,
table_index: int,
row_index: int,
test_row: pa.Table,
expected_row: pa.Table,
):
"""
Compare the two rows for equality, allowing float values to be within a percentage
of each other as defined by global _allowed_float_percent_diff.
We assume the schema has already been compared and is equivalent.
Args:
table_index: index of tables that is the source of the rows.
row_index:
test_row:
expected_row:
"""

assert test_row.num_rows == 1, "Invalid usage. Expected test table with 1 row"
assert (
expected_row.num_rows == 1
), "Invalid usage. Expected expected table with 1 row"
if test_row != expected_row:
# Else look for floating point values that might differ within the allowance
msg = f"Row {row_index} of table {table_index} are not equal\n\tTransformed: {test_row}\n\tExpected : {expected_row}"
assert test_row.num_columns == expected_row.num_columns, msg
num_columns = test_row.num_columns
for i in range(num_columns):
# Over each cell/column in the row
test_column = test_row.column(i)
expected_column = expected_row.column(i)
if test_column != expected_column:
# Check if the value is a float and if so, allow a fuzzy match
test_value = test_column.to_pylist()[0]
expected_value = expected_column.to_pylist()[0]

# Test for Document type
try:

expected_doc = Document.model_validate_json(expected_value)
test_doc = Document.model_validate_json(test_value)
cls.validate_documents(
row_index=row_index,
table_index=table_index,
test_doc=test_doc,
expected_doc=expected_doc,
)

continue

except ValidationError:
pass

# Test for floats
is_float = isinstance(test_value, float) and isinstance(
expected_value, float
)
if is_float:
# It IS a float, so allow a fuzzy match
allowed_diff = abs(_allowed_float_percent_diff * expected_value)
diff = abs(test_value - expected_value)
assert diff <= allowed_diff, msg

continue

# Its NOT a float or other managed types, so do a normal compare
assert test_column == expected_column, msg

@classmethod
def validate_documents(
cls,
row_index: int,
table_index: int,
test_doc: Document,
expected_doc: Document,
):
msg = f"Row {row_index} of table {table_index} are not equal\n\t"
assert len(test_doc.main_text) == len(expected_doc.main_text), (
msg + f"Main Text lengths do not match."
)

for i in range(len(expected_doc.main_text)):
expected_item = expected_doc.main_text[i]
test_item = test_doc.main_text[i]

# Validate type
assert expected_item.obj_type == test_item.obj_type, (
msg + f"Object type does not match."
)

# Validate text content
if isinstance(expected_item, BaseText):
assert expected_item.text == test_item.text, (
msg + f"Text does not match."
)

0 comments on commit 9fd27e0

Please sign in to comment.