Refactor internally to RawBatch and CleanBatch wrapper types (#57)
* Move json equality logic outside of test_arrow

* Refactor to RawBatch and CleanBatch wrapper types

* Move _from_arrow functions to _api

* Update imports

* fix circular import

* keep deprecated api

* Add write-read test and fix typing

* add parquet tests

* fix ci

* Rename wrapper types
kylebarron authored Jun 4, 2024
1 parent c968f15 commit cf0699b
Showing 14 changed files with 833 additions and 669 deletions.
10 changes: 7 additions & 3 deletions stac_geoparquet/arrow/__init__.py
@@ -1,3 +1,7 @@
-from ._from_arrow import stac_table_to_items, stac_table_to_ndjson
-from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow
-from ._to_parquet import to_parquet
+from ._api import (
+    parse_stac_items_to_arrow,
+    parse_stac_ndjson_to_arrow,
+    stac_table_to_items,
+    stac_table_to_ndjson,
+)
+from ._to_parquet import parse_stac_ndjson_to_parquet, to_parquet
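
The package-level imports keep the public entry points in one place, so downstream code can stay on `stac_geoparquet.arrow` regardless of the internal module layout. An illustrative view of the surface after this commit (not part of the diff):

from stac_geoparquet.arrow import (
    parse_stac_items_to_arrow,
    parse_stac_ndjson_to_arrow,
    parse_stac_ndjson_to_parquet,  # newly exported in this commit
    stac_table_to_items,
    stac_table_to_ndjson,
    to_parquet,
)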
136 changes: 136 additions & 0 deletions stac_geoparquet/arrow/_api.py
@@ -0,0 +1,136 @@
import os
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, Union

import pyarrow as pa

from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch
from stac_geoparquet.arrow._schema.models import InferredSchema
from stac_geoparquet.arrow._util import batched_iter
from stac_geoparquet.json_reader import read_json_chunked


def parse_stac_items_to_arrow(
items: Iterable[Dict[str, Any]],
*,
chunk_size: int = 8192,
schema: Optional[Union[pa.Schema, InferredSchema]] = None,
) -> Iterable[pa.RecordBatch]:
"""Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`.
The objects under `properties` are moved up to the top-level of the
Table, similar to :meth:`geopandas.GeoDataFrame.from_features`.
Args:
items: the STAC Items to convert
chunk_size: The chunk size to use for Arrow record batches. This only takes
effect if `schema` is not None. When `schema` is None, the input will be
parsed into a single contiguous record batch. Defaults to 8192.
schema: The schema of the input data. If provided, can improve memory use;
otherwise all items need to be parsed into a single array for schema
inference. Defaults to None.
Returns:
an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items.
"""
if schema is not None:
if isinstance(schema, InferredSchema):
schema = schema.inner

# If schema is provided, then for better memory usage we parse input STAC items
# to Arrow batches in chunks.
for chunk in batched_iter(items, chunk_size):
yield stac_items_to_arrow(chunk, schema=schema)

else:
# If schema is _not_ provided, then we must convert to Arrow all at once, or
# else it would be possible for a STAC item late in the collection (after the
# first chunk) to have a different schema and not match the schema inferred for
# the first chunk.
yield stac_items_to_arrow(items)


def parse_stac_ndjson_to_arrow(
path: Union[str, Path, Iterable[Union[str, Path]]],
*,
chunk_size: int = 65536,
schema: Optional[pa.Schema] = None,
limit: Optional[int] = None,
) -> Iterator[pa.RecordBatch]:
"""
Convert one or more newline-delimited JSON STAC files to a generator of Arrow
RecordBatches.
Each RecordBatch in the returned iterator is guaranteed to have an identical schema,
and can be used to write to one or more Parquet files.
Args:
path: One or more paths to files with STAC items.
chunk_size: The chunk size. Defaults to 65536.
schema: The schema to represent the input STAC data. Defaults to None, in which
case the schema will first be inferred via a full pass over the input data.
In this case, there will be two full passes over the input data: one to
infer a common schema across all data and another to read the data.
Other args:
limit: The maximum number of JSON Items to use for schema inference
Yields:
Arrow RecordBatch with a single chunk of Item data.
"""
# If the schema was not provided, then we need to load all data into memory at once
# to perform schema resolution.
if schema is None:
inferred_schema = InferredSchema()
inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
yield from parse_stac_ndjson_to_arrow(
path, chunk_size=chunk_size, schema=inferred_schema
)
return

if isinstance(schema, InferredSchema):
schema = schema.inner

for batch in read_json_chunked(path, chunk_size=chunk_size):
yield stac_items_to_arrow(batch, schema=schema)


def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
for batch in table.to_batches():
clean_batch = StacArrowBatch(batch)
yield from clean_batch.to_raw_batch().iter_dicts()


def stac_table_to_ndjson(
table: pa.Table, dest: Union[str, Path, os.PathLike[bytes]]
) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
for batch in table.to_batches():
clean_batch = StacArrowBatch(batch)
clean_batch.to_raw_batch().to_ndjson(dest)


def stac_items_to_arrow(
items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None
) -> pa.RecordBatch:
"""Convert dicts representing STAC Items to Arrow
This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple
geometry types.
All items will be parsed into a single RecordBatch, meaning that each internal array
is fully contiguous in memory for the length of `items`.
Args:
items: STAC Items to convert to Arrow
Kwargs:
schema: An optional schema that describes the format of the data. Note that this
must represent the geometry column as binary type.
Returns:
Arrow RecordBatch with items in Arrow
"""
raw_batch = StacJsonBatch.from_dicts(items, schema=schema)
return raw_batch.to_clean_batch().inner
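
Taken together, these functions give a JSON → Arrow → JSON round trip. A short usage sketch (illustrative only; assumes an `items.ndjson` file of newline-delimited STAC Items exists on disk):

import pyarrow as pa

from stac_geoparquet.arrow import parse_stac_ndjson_to_arrow, stac_table_to_items

# Without an explicit schema, a first pass infers a common schema, then the data
# is read and converted in chunks; every yielded RecordBatch shares that schema.
batches = parse_stac_ndjson_to_arrow("items.ndjson", chunk_size=65536)
table = pa.Table.from_batches(batches)

# Convert the STAC-GeoParquet representation back into STAC Item dicts.
items = list(stac_table_to_items(table))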
190 changes: 190 additions & 0 deletions stac_geoparquet/arrow/_batch.py
@@ -0,0 +1,190 @@
from __future__ import annotations

import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Iterable

import numpy as np
import orjson
import pyarrow as pa
import pyarrow.compute as pc
import shapely
import shapely.geometry
from numpy.typing import NDArray
from typing_extensions import Self

from stac_geoparquet.arrow._from_arrow import (
convert_bbox_to_array,
convert_timestamp_columns_to_string,
lower_properties_from_top_level,
)
from stac_geoparquet.arrow._to_arrow import (
assign_geoarrow_metadata,
bring_properties_to_top_level,
convert_bbox_to_struct,
convert_timestamp_columns,
)
from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path


class StacJsonBatch:
"""
An Arrow RecordBatch of STAC Items that has been **minimally converted** to Arrow.
That is, it aligns as much as possible to the raw STAC JSON representation.
The **only** transformations that have already been applied here are those that are
necessary to represent the core STAC items in Arrow.
- `geometry` has been converted to WKB binary
- `properties.proj:geometry`, if it exists, has been converted to WKB binary
ISO encoding
- The `proj:geometry` in any asset properties, if it exists, has been converted to
WKB binary.
No other transformations have yet been applied. I.e. all properties are still in a
top-level `properties` struct column.
"""

inner: pa.RecordBatch
"""The underlying pyarrow RecordBatch"""

def __init__(self, batch: pa.RecordBatch) -> None:
self.inner = batch

@classmethod
def from_dicts(
cls, items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None
) -> Self:
"""Construct a StacJsonBatch from an iterable of dicts representing STAC items.
All items will be parsed into a single RecordBatch, meaning that each internal
array is fully contiguous in memory for the length of `items`.
Args:
items: STAC Items to convert to Arrow
Kwargs:
schema: An optional schema that describes the format of the data. Note that
this must represent the geometry column and any `proj:geometry` columns
as binary type.
Returns:
a new StacJsonBatch of data.
"""
# Preprocess GeoJSON to WKB in each STAC item
# Otherwise, pyarrow will try to parse coordinates into a native geometry type
# and if you have multiple geometry types pyarrow will error with
# `ArrowInvalid: cannot mix list and non-list, non-null values`
wkb_items = []
for item in items:
wkb_item = deepcopy(item)
wkb_item["geometry"] = shapely.to_wkb(
shapely.geometry.shape(wkb_item["geometry"]), flavor="iso"
)

# If a proj:geometry key exists in top-level properties, convert that to WKB
if "proj:geometry" in wkb_item["properties"]:
wkb_item["properties"]["proj:geometry"] = shapely.to_wkb(
shapely.geometry.shape(wkb_item["properties"]["proj:geometry"]),
flavor="iso",
)

# If a proj:geometry key exists in any asset properties, convert that to WKB
for asset_value in wkb_item["assets"].values():
if "proj:geometry" in asset_value:
asset_value["proj:geometry"] = shapely.to_wkb(
shapely.geometry.shape(asset_value["proj:geometry"]),
flavor="iso",
)

wkb_items.append(wkb_item)

if schema is not None:
array = pa.array(wkb_items, type=pa.struct(schema))
else:
array = pa.array(wkb_items)

return cls(pa.RecordBatch.from_struct_array(array))

def iter_dicts(self) -> Iterable[dict]:
batch = self.inner

# Find all paths in the schema that have a WKB geometry
geometry_paths = [["geometry"]]
try:
batch.schema.field("properties").type.field("proj:geometry")
geometry_paths.append(["properties", "proj:geometry"])
except KeyError:
pass

assets_struct = batch.schema.field("assets").type
for asset_idx in range(assets_struct.num_fields):
asset_field = assets_struct.field(asset_idx)
if "proj:geometry" in pa.schema(asset_field).names:
geometry_paths.append(["assets", asset_field.name, "proj:geometry"])

# Convert each geometry column to a Shapely geometry, and then assign the
# geojson geometry when converting each row to a dictionary.
geometries: list[NDArray[np.object_]] = []
for geometry_path in geometry_paths:
col = batch
for path_segment in geometry_path:
if isinstance(col, pa.RecordBatch):
col = col[path_segment]
elif pa.types.is_struct(col.type):
col = pc.struct_field(col, path_segment) # type: ignore
else:
raise AssertionError(f"unexpected type {type(col)}")

geometries.append(shapely.from_wkb(col))

struct_batch = batch.to_struct_array()
for row_idx in range(len(struct_batch)):
row_dict = struct_batch[row_idx].as_py()
for geometry_path, geometry_column in zip(geometry_paths, geometries):
geojson_g = geometry_column[row_idx].__geo_interface__
geojson_g["coordinates"] = convert_tuples_to_lists(
geojson_g["coordinates"]
)
set_by_path(row_dict, geometry_path, geojson_g)

yield row_dict

def to_clean_batch(self) -> StacArrowBatch:
batch = self.inner

batch = bring_properties_to_top_level(batch)
batch = convert_timestamp_columns(batch)
batch = convert_bbox_to_struct(batch)
batch = assign_geoarrow_metadata(batch)

return StacArrowBatch(batch)

def to_ndjson(self, dest: str | Path | os.PathLike[bytes]) -> None:
with open(dest, "ab") as f:
for item_dict in self.iter_dicts():
f.write(orjson.dumps(item_dict))
f.write(b"\n")


class StacArrowBatch:
"""
An Arrow RecordBatch of STAC Items that has been processed to match the
STAC-GeoParquet specification.
"""

inner: pa.RecordBatch
"""The underlying pyarrow RecordBatch"""

def __init__(self, batch: pa.RecordBatch) -> None:
self.inner = batch

def to_raw_batch(self) -> StacJsonBatch:
batch = self.inner

batch = convert_timestamp_columns_to_string(batch)
batch = lower_properties_from_top_level(batch)
batch = convert_bbox_to_array(batch)

return StacJsonBatch(batch)
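
The two wrapper types make the direction of each conversion explicit. A minimal sketch of the internal round trip (illustrative; `stac_items` is assumed to be a list of STAC Item dicts with `geometry`, `properties`, and `assets` keys, loaded elsewhere):

from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch

# `stac_items`: list of STAC Item dicts (assumption for this sketch)
raw = StacJsonBatch.from_dicts(stac_items)  # JSON-shaped batch; geometries stored as WKB
clean: StacArrowBatch = raw.to_clean_batch()  # STAC-GeoParquet layout (properties lifted, bbox struct, geoarrow metadata)
back: StacJsonBatch = clean.to_raw_batch()  # back to the JSON-shaped representation
for item in back.iter_dicts():
    ...  # each `item` is a STAC Item dict with its GeoJSON geometry restored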