Use Arrow stream interface for public API (#69)
* Use Arrow stream interface for public API
kylebarron authored Jun 25, 2024
1 parent ab2701f commit 7df15b3
Showing 4 changed files with 105 additions and 47 deletions.
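
The commit swaps generator-based return values for `pyarrow.RecordBatchReader` across the public API. As a quick orientation before the diff, here is a stand-in sketch (placeholder data, not STAC content) of the two pyarrow primitives the change leans on: wrapping a lazy iterable of batches with `RecordBatchReader.from_batches`, and coercing any object that implements the Arrow PyCapsule stream interface with `RecordBatchReader.from_stream`.

```python
# Stand-in sketch of the pyarrow primitives used by this commit; the data
# here is placeholder content, not STAC items.
import pyarrow as pa

schema = pa.schema([("id", pa.string()), ("value", pa.int64())])
batches = [pa.record_batch([pa.array(["a", "b"]), pa.array([1, 2])], schema=schema)]

# Wrap a (possibly lazy) iterable of batches in a stream without materializing it.
reader = pa.RecordBatchReader.from_batches(schema, iter(batches))

# Coerce anything exposing __arrow_c_stream__ (a Table, another reader, an
# object from a different Arrow-native library) into a reader.
stream = pa.RecordBatchReader.from_stream(pa.table({"id": ["c"], "value": [3]}))

for batch in reader:          # consume lazily, batch by batch...
    print(batch.num_rows)
print(stream.read_all())      # ...or materialize everything at once
```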
82 changes: 63 additions & 19 deletions stac_geoparquet/arrow/_api.py
@@ -1,15 +1,17 @@
from __future__ import annotations

import itertools
import os
from pathlib import Path
from typing import Any, Iterable, Iterator
from typing import Any, Iterable

import pyarrow as pa

from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch
from stac_geoparquet.arrow._constants import DEFAULT_JSON_CHUNK_SIZE
from stac_geoparquet.arrow._schema.models import InferredSchema
from stac_geoparquet.arrow._util import batched_iter
from stac_geoparquet.arrow.types import ArrowStreamExportable
from stac_geoparquet.json_reader import read_json_chunked


@@ -18,7 +20,7 @@ def parse_stac_items_to_arrow(
*,
chunk_size: int = 8192,
schema: pa.Schema | InferredSchema | None = None,
) -> Iterable[pa.RecordBatch]:
) -> pa.RecordBatchReader:
"""
Parse a collection of STAC Items to an iterable of
[`pyarrow.RecordBatch`][pyarrow.RecordBatch].
@@ -37,23 +39,27 @@ def parse_stac_items_to_arrow(
inference. Defaults to None.
Returns:
an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items.
pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
"""
if schema is not None:
if isinstance(schema, InferredSchema):
schema = schema.inner

# If schema is provided, then for better memory usage we parse input STAC items
# to Arrow batches in chunks.
for chunk in batched_iter(items, chunk_size):
yield stac_items_to_arrow(chunk, schema=schema)
batches = (
stac_items_to_arrow(batch, schema=schema)
for batch in batched_iter(items, chunk_size)
)
return pa.RecordBatchReader.from_batches(schema, batches)

else:
# If schema is _not_ provided, then we must convert to Arrow all at once, or
# else it would be possible for a STAC item late in the collection (after the
# first chunk) to have a different schema and not match the schema inferred for
# the first chunk.
yield stac_items_to_arrow(items)
batch = stac_items_to_arrow(items)
return pa.RecordBatchReader.from_batches(batch.schema, [batch])


def parse_stac_ndjson_to_arrow(
@@ -62,7 +68,7 @@ def parse_stac_ndjson_to_arrow(
chunk_size: int = DEFAULT_JSON_CHUNK_SIZE,
schema: pa.Schema | None = None,
limit: int | None = None,
) -> Iterator[pa.RecordBatch]:
) -> pa.RecordBatchReader:
"""
Convert one or more newline-delimited JSON STAC files to a generator of Arrow
RecordBatches.
@@ -81,39 +87,77 @@ def parse_stac_ndjson_to_arrow(
Keyword Args:
limit: The maximum number of JSON Items to use for schema inference
Yields:
Arrow RecordBatch with a single chunk of Item data.
Returns:
pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
"""
# If the schema was not provided, then we need to load all data into memory at once
# to perform schema resolution.
if schema is None:
inferred_schema = InferredSchema()
inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
inferred_schema.manual_updates()
yield from parse_stac_ndjson_to_arrow(
return parse_stac_ndjson_to_arrow(
path, chunk_size=chunk_size, schema=inferred_schema
)
return

if isinstance(schema, InferredSchema):
schema = schema.inner

for batch in read_json_chunked(path, chunk_size=chunk_size):
yield stac_items_to_arrow(batch, schema=schema)
batches_iter = (
stac_items_to_arrow(batch, schema=schema)
for batch in read_json_chunked(path, chunk_size=chunk_size)
)
first_batch = next(batches_iter)
# Need to take this schema from the iterator; the existing `schema` is the schema of
# JSON batch
resolved_schema = first_batch.schema
return pa.RecordBatchReader.from_batches(
resolved_schema, itertools.chain([first_batch], batches_iter)
)


def stac_table_to_items(
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
) -> Iterable[dict]:
"""Convert STAC Arrow to a generator of STAC Item `dict`s.
Args:
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.
Yields:
A STAC `dict` for each input row.
"""
# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
for batch in table.to_batches():
for batch in reader:
clean_batch = StacArrowBatch(batch)
yield from clean_batch.to_json_batch().iter_dicts()


def stac_table_to_ndjson(
table: pa.Table, dest: str | Path | os.PathLike[bytes]
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
dest: str | Path | os.PathLike[bytes],
) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
for batch in table.to_batches():
"""Write STAC Arrow to a newline-delimited JSON file.
Args:
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.
dest: The destination where newline-delimited JSON should be written.
"""

# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

for batch in reader:
clean_batch = StacArrowBatch(batch)
clean_batch.to_json_batch().to_ndjson(dest)

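
A hypothetical end-to-end use of the reworked `_api.py` functions above. The `items` list is only a placeholder (real STAC Item dicts are needed for the conversion to succeed), and the import path assumes these functions are re-exported from `stac_geoparquet.arrow`.

```python
import pyarrow as pa

from stac_geoparquet.arrow import (  # assumed public re-exports of _api.py
    parse_stac_items_to_arrow,
    stac_table_to_items,
    stac_table_to_ndjson,
)

# Placeholder: replace with real STAC Item dicts before running the conversion.
items: list[dict] = []

reader = parse_stac_items_to_arrow(items)   # now a pa.RecordBatchReader
table = reader.read_all()                   # materialize only if you want to

# Round-trip back to Item dicts; a Table, a RecordBatchReader, or any other
# __arrow_c_stream__-exporting object is accepted.
for item in stac_table_to_items(table):
    print(item["id"])

# Or stream straight to newline-delimited JSON without materializing.
stac_table_to_ndjson(parse_stac_items_to_arrow(items), "items.ndjson")
```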
13 changes: 6 additions & 7 deletions stac_geoparquet/arrow/_delta_lake.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import itertools
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable

@@ -47,14 +46,14 @@ def parse_stac_ndjson_to_delta_lake(
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.
"""
batches_iter = parse_stac_ndjson_to_arrow(
record_batch_reader = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
schema = record_batch_reader.schema.with_metadata(
create_geoparquet_metadata(
pa.Table.from_batches([first_batch]), schema_version=schema_version
record_batch_reader.schema, schema_version=schema_version
)
)
combined_iter = itertools.chain([first_batch], batches_iter)
write_deltalake(table_or_uri, combined_iter, schema=schema, engine="rust", **kwargs)
write_deltalake(
table_or_uri, record_batch_reader, schema=schema, engine="rust", **kwargs
)
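
The Delta Lake writer above no longer peels off a first batch to build metadata; it takes the schema straight from the `RecordBatchReader` and hands the still-lazy stream to `write_deltalake`. A stand-in sketch of that pattern, with placeholder data, a placeholder table path, and placeholder metadata standing in for `create_geoparquet_metadata` (the `engine="rust"` and `schema` keywords mirror the call in the diff and depend on the deltalake version in use):

```python
import pyarrow as pa
from deltalake import write_deltalake

# Placeholder stream; the real code gets this from parse_stac_ndjson_to_arrow().
reader = pa.RecordBatchReader.from_stream(pa.table({"id": ["a", "b"]}))

# Attach extra metadata to the schema without touching the stream itself.
schema = reader.schema.with_metadata({b"example": b"placeholder metadata"})

# The reader is consumed lazily by the writer; nothing is materialized up front.
write_deltalake("stac_delta_table", reader, schema=schema, engine="rust")
```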
52 changes: 31 additions & 21 deletions stac_geoparquet/arrow/_to_parquet.py
@@ -15,6 +15,7 @@
)
from stac_geoparquet.arrow._crs import WGS84_CRS_JSON
from stac_geoparquet.arrow._schema.models import InferredSchema
from stac_geoparquet.arrow.types import ArrowStreamExportable


def parse_stac_ndjson_to_parquet(
@@ -43,26 +44,24 @@ def parse_stac_ndjson_to_parquet(
limit: The maximum number of JSON records to convert.
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.
"""
batches_iter = parse_stac_ndjson_to_arrow(
All other keyword args are passed on to
[`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter].
"""
record_batch_reader = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
create_geoparquet_metadata(
pa.Table.from_batches([first_batch]), schema_version=schema_version
)
to_parquet(
record_batch_reader,
output_path=output_path,
schema_version=schema_version,
**kwargs,
)
with pq.ParquetWriter(output_path, schema, **kwargs) as writer:
writer.write_batch(first_batch)
for batch in batches_iter:
writer.write_batch(batch)


def to_parquet(
table: pa.Table,
where: Any,
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
output_path: str | Path,
*,
schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS = DEFAULT_PARQUET_SCHEMA_VERSION,
**kwargs: Any,
@@ -72,22 +71,33 @@ def to_parquet(
This writes metadata compliant with either GeoParquet 1.0 or 1.1.
Args:
table: The table to write to Parquet
where: The destination for saving.
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.
output_path: The destination for saving.
Keyword Args:
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.
All other keyword args are passed on to
[`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter].
"""
metadata = table.schema.metadata or {}
metadata.update(create_geoparquet_metadata(table, schema_version=schema_version))
table = table.replace_schema_metadata(metadata)
# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

pq.write_table(table, where, **kwargs)
schema = reader.schema.with_metadata(
create_geoparquet_metadata(reader.schema, schema_version=schema_version)
)
with pq.ParquetWriter(output_path, schema, **kwargs) as writer:
for batch in reader:
writer.write_batch(batch)


def create_geoparquet_metadata(
table: pa.Table,
schema: pa.Schema,
*,
schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS,
) -> dict[bytes, bytes]:
@@ -116,7 +126,7 @@
"primary_column": "geometry",
}

if "proj:geometry" in table.schema.names:
if "proj:geometry" in schema.names:
# Note we don't include proj:bbox as a covering here for a couple different
# reasons. For one, it's very common for the projected geometries to have a
# different CRS in each row, so having statistics for proj:bbox wouldn't be
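
`to_parquet` above replaces a single `pq.write_table` call with a `pyarrow.parquet.ParquetWriter` loop, so the input is written batch by batch instead of being materialized. A stand-in sketch of that pattern, using placeholder data and a placeholder `geo` metadata value in place of the real `create_geoparquet_metadata` output:

```python
import pyarrow as pa
import pyarrow.parquet as pq

source = pa.table({"id": ["a", "b"], "value": [1, 2]})   # placeholder for STAC Arrow data
reader = pa.RecordBatchReader.from_stream(source)        # any Arrow stream works

schema = reader.schema.with_metadata({b"geo": b"{}"})    # placeholder metadata
with pq.ParquetWriter("example.parquet", schema) as writer:
    for batch in reader:
        writer.write_batch(batch)
```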
5 changes: 5 additions & 0 deletions stac_geoparquet/arrow/types.py
@@ -0,0 +1,5 @@
from typing import Protocol


class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... # noqa
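
`ArrowStreamExportable` is a structural type for anything implementing the Arrow PyCapsule stream interface. A quick check with stand-in objects that a pyarrow `Table` satisfies it on recent pyarrow versions:

```python
import pyarrow as pa

from stac_geoparquet.arrow.types import ArrowStreamExportable


def read_stream(stream: ArrowStreamExportable) -> pa.Table:
    # Anything exporting __arrow_c_stream__ can be coerced to a reader.
    return pa.RecordBatchReader.from_stream(stream).read_all()


table = pa.table({"id": ["a", "b"]})
print(hasattr(table, "__arrow_c_stream__"))   # True on recent pyarrow
print(read_stream(table).num_rows)            # 2
```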
