diff --git a/Makefile b/Makefile
index 3dab636..6304e37 100644
--- a/Makefile
+++ b/Makefile
@@ -25,7 +25,8 @@ lint-bandit: ## Run bandit
 	@echo "\n${BLUE}Running bandit...${NC}\n"
 	@${POETRY_RUN} bandit -r ${PROJ}
 
-lint-base: lint-flake8 lint-bandit ## Just run the linters without autolinting
+#lint-base: lint-flake8 lint-bandit ## Just run the linters without autolinting
+lint-base: lint-flake8 # TODO: Can we drop bandit?
 
 lint: autolint lint-base lint-mypy ## Autolint and code linting
 
diff --git a/dat/generated_tables.py b/dat/generated_tables.py
index cd15c68..79e7a35 100644
--- a/dat/generated_tables.py
+++ b/dat/generated_tables.py
@@ -1,14 +1,14 @@
-from decimal import Decimal
 import os
+import random
 from datetime import date, datetime, timedelta
+from decimal import Decimal
 from pathlib import Path
-import random
 from typing import Callable, List, Tuple
 
-from delta.tables import DeltaTable
 import pyspark.sql
-from pyspark.sql import SparkSession
 import pyspark.sql.types as types
+from delta.tables import DeltaTable
+from pyspark.sql import SparkSession
 
 from dat.models import TableVersionMetadata, TestCaseInfo
 from dat.spark_builder import get_spark_session
@@ -158,16 +158,17 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):
 
 
 @reference_table(
-    name="multi_partitioned_2",
-    description="Multiple levels of partitioning, with boolean, timestamp, and decimal partition columns"
+    name='multi_partitioned_2',
+    description=('Multiple levels of partitioning, with boolean, timestamp, and '
+                 'decimal partition columns')
 )
 def create_multi_partitioned_2(case: TestCaseInfo, spark: SparkSession):
     columns = ['bool', 'time', 'amount', 'int']
     partition_columns = ['bool', 'time', 'amount']
     data = [
-        (True, datetime(1970, 1, 1), Decimal("200.00"), 1),
-        (True, datetime(1970, 1, 1, 12, 30), Decimal("200.00"), 2),
-        (False, datetime(1970, 1, 2, 8, 45), Decimal("12.00"), 3)
+        (True, datetime(1970, 1, 1), Decimal('200.00'), 1),
+        (True, datetime(1970, 1, 1, 12, 30), Decimal('200.00'), 2),
+        (False, datetime(1970, 1, 2, 8, 45), Decimal('12.00'), 3)
     ]
     df = spark.createDataFrame(data, schema=columns)
     df.repartition(1).write.format('delta').partitionBy(
@@ -194,24 +195,25 @@ def with_schema_change(case: TestCaseInfo, spark: SparkSession):
         case.delta_root)
     save_expected(case)
 
+
 @reference_table(
     name='all_primitive_types',
     description='Table containing all non-nested types',
 )
 def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
     schema = types.StructType([
-        types.StructField("utf8", types.StringType()),
-        types.StructField("int64", types.LongType()),
-        types.StructField("int32", types.IntegerType()),
-        types.StructField("int16", types.ShortType()),
-        types.StructField("int8", types.ByteType()),
-        types.StructField("float32", types.FloatType()),
-        types.StructField("float64", types.DoubleType()),
-        types.StructField("bool", types.BooleanType()),
-        types.StructField("binary", types.BinaryType()),
-        types.StructField("decimal", types.DecimalType(5, 3)),
-        types.StructField("date32", types.DateType()),
-        types.StructField("timestamp", types.TimestampType()),
+        types.StructField('utf8', types.StringType()),
+        types.StructField('int64', types.LongType()),
+        types.StructField('int32', types.IntegerType()),
+        types.StructField('int16', types.ShortType()),
+        types.StructField('int8', types.ByteType()),
+        types.StructField('float32', types.FloatType()),
+        types.StructField('float64', types.DoubleType()),
+        types.StructField('bool', types.BooleanType()),
+        types.StructField('binary', types.BinaryType()),
+        types.StructField('decimal', types.DecimalType(5, 3)),
+        types.StructField('date32', types.DateType()),
+        types.StructField('timestamp', types.TimestampType()),
     ])
 
     df = spark.createDataFrame([
@@ -225,7 +227,7 @@ def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
             float(i),
             i % 2 == 0,
             bytes(i),
-            Decimal("10.000") + i,
+            Decimal('10.000') + i,
             date(1970, 1, 1) + timedelta(days=i),
             datetime(1970, 1, 1) + timedelta(hours=i)
         )
@@ -240,20 +242,25 @@ def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
     description='Table containing various nested types',
 )
 def create_nested_types(case: TestCaseInfo, spark: SparkSession):
-    schema = types.StructType([
-        types.StructField("struct", types.StructType([
-            types.StructField("float64", types.DoubleType()),
-            types.StructField("bool", types.BooleanType()),
-        ])),
-        types.StructField("array", types.ArrayType(types.ShortType())),
-        types.StructField("map", types.MapType(types.StringType(), types.IntegerType())),
-    ])
+    schema = types.StructType([types.StructField(
+        'struct', types.StructType(
+            [types.StructField(
+                'float64', types.DoubleType()),
+                types.StructField(
+                    'bool', types.BooleanType()), ])),
+        types.StructField(
+            'array', types.ArrayType(
+                types.ShortType())),
+        types.StructField(
+            'map', types.MapType(
+                types.StringType(),
+                types.IntegerType())), ])
 
     df = spark.createDataFrame([
         (
-            { "float64": float(i), "bool": i % 2 == 0 },
+            {'float64': float(i), 'bool': i % 2 == 0},
             list(range(i + 1)),
-            { str(i): i for i in range(i) }
+            {str(i): i for i in range(i)}
         )
         for i in range(5)
     ], schema=schema)
@@ -261,17 +268,18 @@ def create_nested_types(case: TestCaseInfo, spark: SparkSession):
     df.repartition(1).write.format('delta').save(case.delta_root)
 
 
-def get_sample_data(spark: SparkSession, seed: int=42, nrows: int=5) -> pyspark.sql.DataFrame:
+def get_sample_data(
+        spark: SparkSession, seed: int = 42, nrows: int = 5) -> pyspark.sql.DataFrame:
     # Use seed to get consistent data between runs, for reproducibility
     random.seed(seed)
     return spark.createDataFrame([
         (
-            random.choice(["a", "b", "c", None]),
+            random.choice(['a', 'b', 'c', None]),
             random.randint(0, 1000),
             date(random.randint(1970, 2020), random.randint(1, 12), 1)
         )
         for i in range(nrows)
-    ], schema=["letter", "int", "date"])
+    ], schema=['letter', 'int', 'date'])
 
 
 @reference_table(
@@ -279,24 +287,95 @@ def get_sample_data(spark: SparkSession, seed: int=42, nrows: int=5) -> pyspark.
     description='Table with a checkpoint',
 )
 def create_with_checkpoint(case: TestCaseInfo, spark: SparkSession):
-    spark.conf.set("spark.databricks.delta.retentionDurationCheck.enabled", "false")
+    df = get_sample_data(spark)
+
+    (DeltaTable.create(spark)
+        .location(str(Path(case.delta_root).absolute()))
+        .addColumns(df.schema)
+        .property('delta.checkpointInterval', '2')
+        .execute())
+
+    for i in range(3):
+        df = get_sample_data(spark, seed=i, nrows=5)
+        df.repartition(1).write.format('delta').mode(
+            'overwrite').save(case.delta_root)
+
+    assert any(path.suffixes == ['.checkpoint', '.parquet']
+               for path in (Path(case.delta_root) / '_delta_log').iterdir())
+
+
+def remove_log_file(delta_root: str, version: int):
+    os.remove(os.path.join(delta_root, '_delta_log', f'{version:0>20}.json'))
+
+
+@reference_table(
+    name='no_replay',
+    description='Table with a checkpoint and prior commits cleaned up',
+)
+def create_no_replay(case: TestCaseInfo, spark: SparkSession):
+    spark.conf.set(
+        'spark.databricks.delta.retentionDurationCheck.enabled', 'false')
 
     df = get_sample_data(spark)
-
+    table = (DeltaTable.create(spark)
+             .location(str(Path(case.delta_root).absolute()))
+             .addColumns(df.schema)
+             .property('delta.checkpointInterval', '2')
+             .execute())
+
+    for i in range(3):
+        df = get_sample_data(spark, seed=i, nrows=5)
+        df.repartition(1).write.format('delta').mode(
+            'overwrite').save(case.delta_root)
+
+    table.vacuum(retentionHours=0)
+
+    remove_log_file(case.delta_root, version=0)
+    remove_log_file(case.delta_root, version=1)
+
+    files_in_log = list((Path(case.delta_root) / '_delta_log').iterdir())
+    assert any(path.suffixes == ['.checkpoint', '.parquet']
+               for path in files_in_log)
+    assert not any(path.name == f'{0:0>20}.json' for path in files_in_log)
+
+
+@reference_table(
+    name='stats_as_struct',
+    description='Table with stats only written as struct (not JSON) with Checkpoint',
+)
+def create_stats_as_struct(case: TestCaseInfo, spark: SparkSession):
+    df = get_sample_data(spark)
+
     (DeltaTable.create(spark)
         .location(str(Path(case.delta_root).absolute()))
         .addColumns(df.schema)
-        .property("delta.checkpointInterval", "2")
-        .property("delta.logRetentionDuration", "0 days")
+        .property('delta.checkpointInterval', '2')
+        .property('delta.checkpoint.writeStatsAsStruct', 'true')
+        .property('delta.checkpoint.writeStatsAsJson', 'false')
         .execute())
 
-    for i in range(5):
+    for i in range(3):
         df = get_sample_data(spark, seed=i, nrows=5)
-        df.repartition(1).write.format('delta').mode('overwrite').save(case.delta_root)
-
-    assert any(path.suffixes == [".checkpoint", ".parquet"]
-               for path in (Path(case.delta_root) / "_delta_log").iterdir())
+        df.repartition(1).write.format('delta').mode(
+            'overwrite').save(case.delta_root)
 
-    table.vacuum(retentionHours=0)
-
+
+@reference_table(
+    name='no_stats',
+    description='Table with no stats',
+)
+def create_no_stats(case: TestCaseInfo, spark: SparkSession):
+    df = get_sample_data(spark)
+    (DeltaTable.create(spark)
+        .location(str(Path(case.delta_root).absolute()))
+        .addColumns(df.schema)
+        .property('delta.checkpointInterval', '2')
+        .property('delta.checkpoint.writeStatsAsStruct', 'false')
+        .property('delta.checkpoint.writeStatsAsJson', 'false')
+        .property('delta.dataSkippingNumIndexedCols', '0')
+        .execute())
+    for i in range(3):
+        df = get_sample_data(spark, seed=i, nrows=5)
+        df.repartition(1).write.format('delta').mode(
+            'overwrite').save(case.delta_root)
 
diff --git a/dat/main.py b/dat/main.py
index 1980db7..8d30b73 100644
--- a/dat/main.py
+++ b/dat/main.py
@@ -45,7 +45,8 @@ def write_generated_reference_tables(table_name: Optional[str]):
                 create_table()
                 break
         else:
-            raise ValueError(f"Could not find generated table named '{table_name}'")
+            raise ValueError(
+                f"Could not find generated table named '{table_name}'")
     else:
         out_base = Path('out/reader_tests/generated')
         shutil.rmtree(out_base)
diff --git a/setup.cfg b/setup.cfg
index f6b2ac6..9175508 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -9,4 +9,5 @@ per-file-ignores =
     # WPS202 Found too many module members
     tests/*: S101 WPS114 WPS226 WPS202
     dat/external_tables.py: WPS226 WPS114
-    dat/generated_tables.py: WPS226 WPS114
\ No newline at end of file
+    dat/generated_tables.py: WPS226 WPS114
+max-line-length = 90
\ No newline at end of file