Commit: more tables

wjones127 committed Dec 4, 2022
1 parent c4d2292 commit 1876315
Showing 4 changed files with 130 additions and 48 deletions.
Makefile: 3 changes (2 additions & 1 deletion)
@@ -25,7 +25,8 @@ lint-bandit: ## Run bandit
@echo "\n${BLUE}Running bandit...${NC}\n"
@${POETRY_RUN} bandit -r ${PROJ}

lint-base: lint-flake8 lint-bandit ## Just run the linters without autolinting
#lint-base: lint-flake8 lint-bandit ## Just run the linters without autolinting
lint-base: lint-flake8 # TODO: Can we drop bandit?

lint: autolint lint-base lint-mypy ## Autolint and code linting

dat/generated_tables.py: 169 changes (124 additions & 45 deletions)
@@ -1,14 +1,14 @@
from decimal import Decimal
import os
import random
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
import random
from typing import Callable, List, Tuple

from delta.tables import DeltaTable
import pyspark.sql
from pyspark.sql import SparkSession
import pyspark.sql.types as types
from delta.tables import DeltaTable
from pyspark.sql import SparkSession

from dat.models import TableVersionMetadata, TestCaseInfo
from dat.spark_builder import get_spark_session
@@ -158,16 +158,17 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):


@reference_table(
name="multi_partitioned_2",
description="Multiple levels of partitioning, with boolean, timestamp, and decimal partition columns"
name='multi_partitioned_2',
description=('Multiple levels of partitioning, with boolean, timestamp, and '
'decimal partition columns')
)
def create_multi_partitioned_2(case: TestCaseInfo, spark: SparkSession):
columns = ['bool', 'time', 'amount', 'int']
partition_columns = ['bool', 'time', 'amount']
data = [
(True, datetime(1970, 1, 1), Decimal("200.00"), 1),
(True, datetime(1970, 1, 1, 12, 30), Decimal("200.00"), 2),
(False, datetime(1970, 1, 2, 8, 45), Decimal("12.00"), 3)
(True, datetime(1970, 1, 1), Decimal('200.00'), 1),
(True, datetime(1970, 1, 1, 12, 30), Decimal('200.00'), 2),
(False, datetime(1970, 1, 2, 8, 45), Decimal('12.00'), 3)
]
df = spark.createDataFrame(data, schema=columns)
df.repartition(1).write.format('delta').partitionBy(
@@ -194,24 +195,25 @@ def with_schema_change(case: TestCaseInfo, spark: SparkSession):
case.delta_root)
save_expected(case)


@reference_table(
name='all_primitive_types',
description='Table containing all non-nested types',
)
def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
schema = types.StructType([
types.StructField("utf8", types.StringType()),
types.StructField("int64", types.LongType()),
types.StructField("int32", types.IntegerType()),
types.StructField("int16", types.ShortType()),
types.StructField("int8", types.ByteType()),
types.StructField("float32", types.FloatType()),
types.StructField("float64", types.DoubleType()),
types.StructField("bool", types.BooleanType()),
types.StructField("binary", types.BinaryType()),
types.StructField("decimal", types.DecimalType(5, 3)),
types.StructField("date32", types.DateType()),
types.StructField("timestamp", types.TimestampType()),
types.StructField('utf8', types.StringType()),
types.StructField('int64', types.LongType()),
types.StructField('int32', types.IntegerType()),
types.StructField('int16', types.ShortType()),
types.StructField('int8', types.ByteType()),
types.StructField('float32', types.FloatType()),
types.StructField('float64', types.DoubleType()),
types.StructField('bool', types.BooleanType()),
types.StructField('binary', types.BinaryType()),
types.StructField('decimal', types.DecimalType(5, 3)),
types.StructField('date32', types.DateType()),
types.StructField('timestamp', types.TimestampType()),
])

df = spark.createDataFrame([
@@ -225,7 +227,7 @@ def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
float(i),
i % 2 == 0,
bytes(i),
Decimal("10.000") + i,
Decimal('10.000') + i,
date(1970, 1, 1) + timedelta(days=i),
datetime(1970, 1, 1) + timedelta(hours=i)
)
@@ -240,63 +242,140 @@ def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
description='Table containing various nested types',
)
def create_nested_types(case: TestCaseInfo, spark: SparkSession):
schema = types.StructType([
types.StructField("struct", types.StructType([
types.StructField("float64", types.DoubleType()),
types.StructField("bool", types.BooleanType()),
])),
types.StructField("array", types.ArrayType(types.ShortType())),
types.StructField("map", types.MapType(types.StringType(), types.IntegerType())),
])
schema = types.StructType([types.StructField(
'struct', types.StructType(
[types.StructField(
'float64', types.DoubleType()),
types.StructField(
'bool', types.BooleanType()), ])),
types.StructField(
'array', types.ArrayType(
types.ShortType())),
types.StructField(
'map', types.MapType(
types.StringType(),
types.IntegerType())), ])

df = spark.createDataFrame([
(
{ "float64": float(i), "bool": i % 2 == 0 },
{'float64': float(i), 'bool': i % 2 == 0},
list(range(i + 1)),
{ str(i): i for i in range(i) }
{str(i): i for i in range(i)}
)
for i in range(5)
], schema=schema)

df.repartition(1).write.format('delta').save(case.delta_root)


def get_sample_data(spark: SparkSession, seed: int=42, nrows: int=5) -> pyspark.sql.DataFrame:
def get_sample_data(
spark: SparkSession, seed: int = 42, nrows: int = 5) -> pyspark.sql.DataFrame:
# Use seed to get consistent data between runs, for reproducibility
random.seed(seed)
return spark.createDataFrame([
(
random.choice(["a", "b", "c", None]),
random.choice(['a', 'b', 'c', None]),
random.randint(0, 1000),
date(random.randint(1970, 2020), random.randint(1, 12), 1)
)
for i in range(nrows)
], schema=["letter", "int", "date"])
], schema=['letter', 'int', 'date'])


@reference_table(
name='with_checkpoint',
description='Table with a checkpoint',
)
def create_with_checkpoint(case: TestCaseInfo, spark: SparkSession):
spark.conf.set("spark.databricks.delta.retentionDurationCheck.enabled", "false")
df = get_sample_data(spark)

(DeltaTable.create(spark)
.location(str(Path(case.delta_root).absolute()))
.addColumns(df.schema)
.property('delta.checkpointInterval', '2')
.execute())

for i in range(3):
df = get_sample_data(spark, seed=i, nrows=5)
df.repartition(1).write.format('delta').mode(
'overwrite').save(case.delta_root)

assert any(path.suffixes == ['.checkpoint', '.parquet']
for path in (Path(case.delta_root) / '_delta_log').iterdir())


def remove_log_file(delta_root: str, version: int):
os.remove(os.path.join(delta_root, '_delta_log', f'{version:0>20}.json'))


@reference_table(
name='no_replay',
description='Table with a checkpoint and prior commits cleaned up',
)
def create_no_replay(case: TestCaseInfo, spark: SparkSession):
spark.conf.set(
'spark.databricks.delta.retentionDurationCheck.enabled', 'false')

df = get_sample_data(spark)

table = (DeltaTable.create(spark)
.location(str(Path(case.delta_root).absolute()))
.addColumns(df.schema)
.property('delta.checkpointInterval', '2')
.execute())

for i in range(3):
df = get_sample_data(spark, seed=i, nrows=5)
df.repartition(1).write.format('delta').mode(
'overwrite').save(case.delta_root)

table.vacuum(retentionHours=0)

remove_log_file(case.delta_root, version=0)
remove_log_file(case.delta_root, version=1)

files_in_log = list((Path(case.delta_root) / '_delta_log').iterdir())
assert any(path.suffixes == ['.checkpoint', '.parquet']
for path in files_in_log)
assert not any(path.name == f'{0:0>20}.json' for path in files_in_log)


@reference_table(
name='stats_as_struct',
description='Table with stats only written as struct (not JSON) with Checkpoint',
)
def create_stats_as_struct(case: TestCaseInfo, spark: SparkSession):
df = get_sample_data(spark)
(DeltaTable.create(spark)
.location(str(Path(case.delta_root).absolute()))
.addColumns(df.schema)
.property("delta.checkpointInterval", "2")
.property("delta.logRetentionDuration", "0 days")
.property('delta.checkpointInterval', '2')
.property('delta.checkpoint.writeStatsAsStruct', 'true')
.property('delta.checkpoint.writeStatsAsJson', 'false')
.execute())

for i in range(5):
for i in range(3):
df = get_sample_data(spark, seed=i, nrows=5)
df.repartition(1).write.format('delta').mode('overwrite').save(case.delta_root)

assert any(path.suffixes == [".checkpoint", ".parquet"]
for path in (Path(case.delta_root) / "_delta_log").iterdir())
df.repartition(1).write.format('delta').mode(
'overwrite').save(case.delta_root)

table.vacuum(retentionHours=0)

@reference_table(
name='no_stats',
description='Table with no stats',
)
def create_no_stats(case: TestCaseInfo, spark: SparkSession):
df = get_sample_data(spark)
(DeltaTable.create(spark)
.location(str(Path(case.delta_root).absolute()))
.addColumns(df.schema)
.property('delta.checkpointInterval', '2')
.property('delta.checkpoint.writeStatsAsStruct', 'false')
.property('delta.checkpoint.writeStatsAsJson', 'false')
.property('delta.dataSkippingNumIndexedCols', '0')
.execute())

for i in range(3):
df = get_sample_data(spark, seed=i, nrows=5)
df.repartition(1).write.format('delta').mode(
'overwrite').save(case.delta_root)
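
For context, the generators added above all use the @reference_table decorator that is used throughout dat/generated_tables.py (its definition is not part of this diff). A minimal sketch of the registration pattern it implies is shown below; the registry name and structure here are assumptions for illustration, not the repository's actual implementation.

from typing import Callable, Dict, Tuple

# Hypothetical registry name, for illustration only; the real module keeps
# its own structure and also prepares a TestCaseInfo for each generator.
REFERENCE_TABLES: Dict[str, Tuple[str, Callable]] = {}


def reference_table(name: str, description: str) -> Callable:
    def decorator(func: Callable) -> Callable:
        # Register the generator under its table name so a driver such as
        # dat/main.py can look it up by name, or iterate all of them, and call
        # func(case, spark) to write the Delta table and its expected outputs.
        REFERENCE_TABLES[name] = (description, func)
        return func
    return decorator
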
dat/main.py: 3 changes (2 additions & 1 deletion)
@@ -45,7 +45,8 @@ def write_generated_reference_tables(table_name: Optional[str]):
create_table()
break
else:
raise ValueError(f"Could not find generated table named '{table_name}'")
raise ValueError(
f"Could not find generated table named '{table_name}'")
else:
out_base = Path('out/reader_tests/generated')
shutil.rmtree(out_base)
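
As a hedged usage sketch (the diff only shows a fragment of dat/main.py, so the import path and calling convention are assumptions), the reworded error above is hit when a requested table name is not among the registered generators:

# Assumed to be importable; exact CLI wiring is not shown in this diff.
from dat.main import write_generated_reference_tables

# Regenerate a single table by name; the ValueError above is raised if the
# name does not match any registered generator.
write_generated_reference_tables('no_stats')

# Passing None takes the other branch and rebuilds everything under
# out/reader_tests/generated.
write_generated_reference_tables(None)
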
setup.cfg: 3 changes (2 additions & 1 deletion)
@@ -9,4 +9,5 @@ per-file-ignores =
# WPS202 Found too many module members
tests/*: S101 WPS114 WPS226 WPS202
dat/external_tables.py: WPS226 WPS114
dat/generated_tables.py: WPS226 WPS114
dat/generated_tables.py: WPS226 WPS114
max-line-length = 90
