Use ParquetWriter rather than Dataset.write_dataset
The pyarrow write_dataset function has been shown to consume a lot of memory
while also taking longer to execute. It has features we use around
partitioning and file visitor functions, but those features can be
replicated without much difficulty given the way we use them (a sketch of
the replacement follows the commit metadata below).
mzappitello committed Jan 4, 2024
1 parent 6ba05d5 commit f145166
Showing 1 changed file with 34 additions and 21 deletions.
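
In outline, the replacement works like this: each partition column is assumed to hold a single value, a hive-style col=value path is built from those values, the partition columns are dropped, and the table is written to that path with a single ParquetWriter, with the visitor callback handed the written path. Below is a minimal standalone sketch of that idea, assuming pyarrow with S3 support; the function name write_partitioned, the part-0.parquet basename, and the type hints are illustrative only, and the actual change is in the diff that follows.

import os
from typing import Callable, List, Optional

import pyarrow
import pyarrow.compute as pc
import pyarrow.parquet as pq
from pyarrow import fs


def write_partitioned(
    table: pyarrow.Table,
    s3_dir: str,
    partition_cols: List[str],
    visitor_func: Optional[Callable[[str], None]] = None,
) -> str:
    # build hive-style "col=value" segments; each partition column is assumed
    # to hold exactly one distinct value for the incoming table
    partition_strings = []
    for col in partition_cols:
        unique_values = pc.unique(table.column(col)).to_pylist()
        assert len(unique_values) == 1
        partition_strings.append(f"{col}={unique_values[0]}")

    # drop the partition columns, as hive-style partitioning would
    table = table.drop(partition_cols)

    # a single parquet file under the partitioned prefix (hypothetical basename)
    write_path = os.path.join(s3_dir, *partition_strings, "part-0.parquet")

    # write the table to one file on s3 instead of calling ds.write_dataset
    with pq.ParquetWriter(
        where=write_path, schema=table.schema, filesystem=fs.S3FileSystem()
    ) as writer:
        writer.write(table)

    # the visitor is handed the written path rather than a WrittenFile instance
    if visitor_func is not None:
        visitor_func(write_path)

    return write_path

The trade-off is that each call writes exactly one file into one partition directory, which matches the usage the commit message describes.
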
55 changes: 34 additions & 21 deletions python_src/src/lamp_py/aws/s3.py
@@ -20,7 +20,7 @@
 import boto3
 import pandas
 import pyarrow
-import pyarrow.dataset as ds
+import pyarrow.compute as pc
 import pyarrow.parquet as pq
 from pyarrow import Table, fs
 from pyarrow.util import guid
@@ -336,7 +336,8 @@ def write_parquet_file(
     @s3_dir - the s3 bucket plus prefix "subdirectory" path where the
         parquet files should be written
     @partition_cols - column names in the table to partition out into the
-        filepath.
+        filepath. NOTE: the assumption is that the values in the partition
+        columns of the incoming table are uniform.
     @visitor_func - if set, this function will be called with a WrittenFile
         instance for each file created during the call. a WrittenFile has
         path and metadata attributes.
@@ -349,27 +350,39 @@ def write_parquet_file(
     )
     process_logger.log_start()
 
-    # generate partitioning for this table write based on what columns
-    # we expect to be able to partition out for this input type
-    partitioning = ds.partitioning(
-        table.select(partition_cols).schema, flavor="hive"
-    )
+    try:
+        # pull out the partition information into a list of strings.
+        partition_strings = []
+        for col in partition_cols:
+            unique_list = pc.unique(table.column(col)).to_pylist()
+            print(unique_list)
+            assert len(unique_list) == 1
+            partition_strings.append(f"{col}={unique_list[0]}")
+
+        table = table.drop(partition_cols)
+
+        # generate an s3 path to write this file to
+        if basename_template is None:
+            basename_template = guid() + "-{i}.parquet"
+        write_path = os.path.join(
+            s3_dir, *partition_strings, basename_template.format(i=0)
+        )
 
-    if basename_template is None:
-        basename_template = guid() + "-{i}.parquet"
-
-    ds.write_dataset(
-        data=table,
-        base_dir=s3_dir,
-        filesystem=fs.S3FileSystem(),
-        format=ds.ParquetFileFormat(),
-        partitioning=partitioning,
-        file_visitor=visitor_func,
-        basename_template=basename_template,
-        existing_data_behavior="overwrite_or_ignore",
-    )
+        process_logger.add_metadata(write_path=write_path)
 
-    process_logger.log_complete()
+        # write the parquet file to the partitioned path
+        with pq.ParquetWriter(
+            where=write_path, schema=table.schema, filesystem=fs.S3FileSystem()
+        ) as pq_writer:
+            pq_writer.write(table)
+
+        # call the visitor function if it exists
+        if visitor_func is not None:
+            visitor_func(write_path)
+
+        process_logger.log_complete()
+    except Exception as exception:
+        process_logger.log_failure(exception)
 
 
 # pylint: enable=R0913

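For illustration, a hedged usage sketch of the updated function. The module path, table contents, bucket/prefix, and callback below are assumptions; the keyword names come from the docstring and diff above, and other parameters of write_parquet_file may exist that this diff does not show. Note that after this change the visitor function is called with the written S3 path string.

import pyarrow

# module path inferred from python_src/src/lamp_py/aws/s3.py; an assumption
from lamp_py.aws.s3 import write_parquet_file

# toy table; each partition column holds a single value, per the docstring NOTE
table = pyarrow.table(
    {
        "year": [2024, 2024, 2024],
        "month": [1, 1, 1],
        "trip_id": ["a", "b", "c"],
    }
)

write_parquet_file(
    table=table,
    # hypothetical bucket/prefix; no s3:// scheme, pyarrow S3FileSystem paths are bucket/key
    s3_dir="example-bucket/example-prefix",
    partition_cols=["year", "month"],
    visitor_func=lambda path: print(f"wrote {path}"),  # receives the written path
)
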