From 9fd27e09122aba4f9fce86c0e63f87ff26843546 Mon Sep 17 00:00:00 2001 From: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Date: Fri, 6 Sep 2024 22:54:12 +0200 Subject: [PATCH] Custom column validator for pdf2parquet (#577) * use @classmethod to allow overloading Signed-off-by: Michele Dolfi * custom validator for document types in pdf2parquet Signed-off-by: Michele Dolfi --------- Signed-off-by: Michele Dolfi --- .../test_support/abstract_test.py | 44 ++++---- .../test_support/launch/transform_test.py | 2 +- .../python/test/test_pdf2parquet_python.py | 104 +++++++++++++++++- 3 files changed, 126 insertions(+), 24 deletions(-) diff --git a/data-processing-lib/python/src/data_processing/test_support/abstract_test.py b/data-processing-lib/python/src/data_processing/test_support/abstract_test.py index 4ec398862..76fd29000 100644 --- a/data-processing-lib/python/src/data_processing/test_support/abstract_test.py +++ b/data-processing-lib/python/src/data_processing/test_support/abstract_test.py @@ -75,8 +75,8 @@ def pytest_generate_tests(metafunc): def _install_test_fixtures(self, metafunc): raise NotImplemented("Sub-class must implemented this to install the fixtures for its tests.") - @staticmethod - def validate_expected_tables(table_list: list[pa.Table], expected_table_list: list[pa.Table]): + @classmethod + def validate_expected_tables(cls, table_list: list[pa.Table], expected_table_list: list[pa.Table]): """ Verify with assertion messages that the two lists of Tables are equivalent. :param table_list: @@ -100,10 +100,10 @@ def validate_expected_tables(table_list: list[pa.Table], expected_table_list: li r1 = t1.take([j]) r2 = t2.take([j]) # assert r1 == r2, f"Row {j} of table {i} are not equal\n\tTransformed: {r1}\n\tExpected : {r2}" - AbstractTest.validate_expected_row(i, j, r1, r2) + cls.validate_expected_row(i, j, r1, r2) - @staticmethod - def validate_expected_row(table_index: int, row_index: int, test_row: pa.Table, expected_row: pa.Table): + @classmethod + def validate_expected_row(cls, table_index: int, row_index: int, test_row: pa.Table, expected_row: pa.Table): """ Compare the two rows for equality, allowing float values to be within a percentage of each other as defined by global _allowed_float_percent_diff. @@ -139,8 +139,8 @@ def validate_expected_row(table_index: int, row_index: int, test_row: pa.Table, diff = abs(test_value - expected_value) assert diff <= allowed_diff, msg - @staticmethod - def validate_expected_files(files_list: list[tuple[bytes, str]], expected_files_list: list[tuple[bytes, str]]): + @classmethod + def validate_expected_files(cls, files_list: list[tuple[bytes, str]], expected_files_list: list[tuple[bytes, str]]): """ Verify with assertion messages that the two lists of Tables are equivalent. :param files_list: @@ -171,15 +171,15 @@ def validate_expected_files(files_list: list[tuple[bytes, str]], expected_files_ diff <= diff_allowed ), f"produced file length {lenf1} vs expected {lenf2}, exceeds allowance of {diff_allowed}" - @staticmethod - def validate_expected_metadata_lists(metadata: list[dict[str, float]], expected_metadata: list[dict[str, float]]): + @classmethod + def validate_expected_metadata_lists(cls, metadata: list[dict[str, float]], expected_metadata: list[dict[str, float]]): elen = len(expected_metadata) assert len(metadata) == elen, f"Number of metadata dictionaries not the expected of {elen}" for index in range(elen): - AbstractTest.validate_expected_metadata(metadata[index], expected_metadata[index]) + cls.validate_expected_metadata(metadata[index], expected_metadata[index]) - @staticmethod - def validate_expected_metadata(metadata: dict[str, float], expected_metadata: dict[str, float]): + @classmethod + def validate_expected_metadata(cls, metadata: dict[str, float], expected_metadata: dict[str, float]): """ Verify with assertion messages that the two dictionaries are as expected. :param metadata: @@ -194,8 +194,8 @@ def validate_expected_metadata(metadata: dict[str, float], expected_metadata: di f"Metadata not equal\n" "\tTransformed: {metadata} Expected : {expected_metadata}" ) - @staticmethod - def validate_directory_contents(directory: str, expected_dir: str, drop_columns: list[str] = []): + @classmethod + def validate_directory_contents(cls, directory: str, expected_dir: str, drop_columns: list[str] = []): """ Make sure the directory contents are the same. :param directory: @@ -217,28 +217,28 @@ def validate_directory_contents(directory: str, expected_dir: str, drop_columns: expected_diffs = 0 failed = len(dir_cmp.diff_files) != expected_diffs if failed: - AbstractTest.__confirm_diffs(directory, expected_dir, dir_cmp.diff_files, "/tmp", drop_columns) + cls.__confirm_diffs(directory, expected_dir, dir_cmp.diff_files, "/tmp", drop_columns) # Traverse into the subdirs since dircmp doesn't seem to do that. subdirs = [f.name for f in os.scandir(expected_dir) if f.is_dir()] for subdir in subdirs: d1 = os.path.join(directory, subdir) d2 = os.path.join(expected_dir, subdir) - AbstractTest.validate_directory_contents(d1, d2, drop_columns) + cls.validate_directory_contents(d1, d2, drop_columns) - @staticmethod - def _validate_table_files(parquet1: str, parquet2: str, drop_columns: list[str] = []): + @classmethod + def _validate_table_files(cls, parquet1: str, parquet2: str, drop_columns: list[str] = []): da = DataAccessLocal() t1, _ = da.get_table(parquet1) t2, _ = da.get_table(parquet2) if len(drop_columns) > 0: t1 = t1.drop_columns(drop_columns) t2 = t2.drop_columns(drop_columns) - AbstractTest.validate_expected_tables([t1], [t2]) + cls.validate_expected_tables([t1], [t2]) - @staticmethod + @classmethod def __confirm_diffs( - src_dir: str, expected_dir: str, diff_files: list, dest_dir: str, drop_columns: list[str] = [] + cls, src_dir: str, expected_dir: str, diff_files: list, dest_dir: str, drop_columns: list[str] = [] ): """ Copy all files from the source dir to the dest dir. @@ -256,7 +256,7 @@ def __confirm_diffs( # It seems file can be different on disk, but contain the same column/row values. # so for these, do the inmemory comparison. try: - AbstractTest._validate_table_files(expected, src, drop_columns) + cls._validate_table_files(expected, src, drop_columns) except AssertionError as e: logger.info(f"Copying file with difference: {src} to {dest}") shutil.copyfile(src, dest) diff --git a/data-processing-lib/python/src/data_processing/test_support/launch/transform_test.py b/data-processing-lib/python/src/data_processing/test_support/launch/transform_test.py index 63381dca2..77d21fc0d 100644 --- a/data-processing-lib/python/src/data_processing/test_support/launch/transform_test.py +++ b/data-processing-lib/python/src/data_processing/test_support/launch/transform_test.py @@ -65,7 +65,7 @@ def _validate_directory_contents_match(self, dir: str, expected: str, ignore_col Confirm that the two directories contains the same files. Stubbed out like this to allow spark tests to override this since spark tends to rename the files. """ - AbstractTest.validate_directory_contents(dir, expected, ignore_columns) + self.validate_directory_contents(dir, expected, ignore_columns) def _install_test_fixtures(self, metafunc): # Apply the fixtures for the method with these input names (i.e. test_transform()). diff --git a/transforms/language/pdf2parquet/python/test/test_pdf2parquet_python.py b/transforms/language/pdf2parquet/python/test/test_pdf2parquet_python.py index ba8102822..6f51a8317 100644 --- a/transforms/language/pdf2parquet/python/test/test_pdf2parquet_python.py +++ b/transforms/language/pdf2parquet/python/test/test_pdf2parquet_python.py @@ -13,11 +13,16 @@ import ast import os +import pyarrow as pa from data_processing.runtime.pure_python import PythonTransformLauncher +from data_processing.test_support.abstract_test import _allowed_float_percent_diff from data_processing.test_support.launch.transform_test import ( AbstractTransformLauncherTest, ) +from docling_core.types import Document +from docling_core.types.doc.base import BaseText from pdf2parquet_transform_python import Pdf2ParquetPythonTransformConfiguration +from pydantic import ValidationError class TestPythonPdf2ParquetTransform(AbstractTransformLauncherTest): @@ -35,7 +40,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: } # this is added as a fixture to remove these columns from comparison - ignore_columns = ["date_acquired", "document_id", "pdf_convert_time"] + ignore_columns = ["date_acquired", "document_id", "pdf_convert_time", "hash"] fixtures = [] launcher = PythonTransformLauncher(Pdf2ParquetPythonTransformConfiguration()) @@ -84,3 +89,100 @@ def get_test_transform_fixtures(self) -> list[tuple]: ) return fixtures + + @classmethod + def validate_expected_row( + cls, + table_index: int, + row_index: int, + test_row: pa.Table, + expected_row: pa.Table, + ): + """ + Compare the two rows for equality, allowing float values to be within a percentage + of each other as defined by global _allowed_float_percent_diff. + We assume the schema has already been compared and is equivalent. + Args: + table_index: index of tables that is the source of the rows. + row_index: + test_row: + expected_row: + """ + + assert test_row.num_rows == 1, "Invalid usage. Expected test table with 1 row" + assert ( + expected_row.num_rows == 1 + ), "Invalid usage. Expected expected table with 1 row" + if test_row != expected_row: + # Else look for floating point values that might differ within the allowance + msg = f"Row {row_index} of table {table_index} are not equal\n\tTransformed: {test_row}\n\tExpected : {expected_row}" + assert test_row.num_columns == expected_row.num_columns, msg + num_columns = test_row.num_columns + for i in range(num_columns): + # Over each cell/column in the row + test_column = test_row.column(i) + expected_column = expected_row.column(i) + if test_column != expected_column: + # Check if the value is a float and if so, allow a fuzzy match + test_value = test_column.to_pylist()[0] + expected_value = expected_column.to_pylist()[0] + + # Test for Document type + try: + + expected_doc = Document.model_validate_json(expected_value) + test_doc = Document.model_validate_json(test_value) + cls.validate_documents( + row_index=row_index, + table_index=table_index, + test_doc=test_doc, + expected_doc=expected_doc, + ) + + continue + + except ValidationError: + pass + + # Test for floats + is_float = isinstance(test_value, float) and isinstance( + expected_value, float + ) + if is_float: + # It IS a float, so allow a fuzzy match + allowed_diff = abs(_allowed_float_percent_diff * expected_value) + diff = abs(test_value - expected_value) + assert diff <= allowed_diff, msg + + continue + + # Its NOT a float or other managed types, so do a normal compare + assert test_column == expected_column, msg + + @classmethod + def validate_documents( + cls, + row_index: int, + table_index: int, + test_doc: Document, + expected_doc: Document, + ): + msg = f"Row {row_index} of table {table_index} are not equal\n\t" + assert len(test_doc.main_text) == len(expected_doc.main_text), ( + msg + f"Main Text lengths do not match." + ) + + for i in range(len(expected_doc.main_text)): + expected_item = expected_doc.main_text[i] + test_item = test_doc.main_text[i] + + # Validate type + assert expected_item.obj_type == test_item.obj_type, ( + msg + f"Object type does not match." + ) + + # Validate text content + if isinstance(expected_item, BaseText): + assert expected_item.text == test_item.text, ( + msg + f"Text does not match." + )