Use iterables and record batches in to_arrow (#53)
* add json reader that can accommodate json and ndjson; use iterables of items and record batches for converting to arrow

* only use batch, not union

* remove downcast parameter

* pass along schema

* Fix reading json

* restore inferredSchema class

* simplify chunked json reading

* Fixed schema inference

* handle iterable output

* fix error in refactoring

* Add schema inference with limit

---------

Co-authored-by: Kyle Barron <kyle@developmentseed.org>
bitner and kylebarron authored May 22, 2024
1 parent 2eaf22d commit 22abad6
Showing 9 changed files with 446 additions and 457 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -19,11 +19,12 @@ Note that `stac_geoparquet` lifts the keys in the item `properties` up to the to
>>> import requests
>>> import stac_geoparquet.arrow
>>> import pyarrow.parquet
>>> import pyarrow as pa

>>> items = requests.get(
... "https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items"
... ).json()["features"]
>>> table = stac_geoparquet.arrow.parse_stac_items_to_arrow(items)
>>> table = pa.Table.from_batches(stac_geoparquet.arrow.parse_stac_items_to_arrow(items))
>>> stac_geoparquet.arrow.to_parquet(table, "items.parquet")
>>> table2 = pyarrow.parquet.read_table("items.parquet")
>>> items2 = list(stac_geoparquet.arrow.stac_table_to_items(table2))
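
Because `parse_stac_items_to_arrow` now yields record batches rather than a single table, large collections can also be streamed to disk batch by batch. A minimal sketch, reusing the `items` list from the snippet above and assuming every batch shares one schema; note that, unlike `stac_geoparquet.arrow.to_parquet`, this plain writer does not attach any GeoParquet metadata:

import pyarrow.parquet as pq
import stac_geoparquet.arrow

batches = iter(stac_geoparquet.arrow.parse_stac_items_to_arrow(items))
first = next(batches)  # peek one batch to learn the schema
with pq.ParquetWriter("items_stream.parquet", first.schema) as writer:
    writer.write_batch(first)
    for batch in batches:
        writer.write_batch(batch)
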
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -15,10 +15,12 @@ dependencies = [
"geopandas",
"packaging",
"pandas",
"pyarrow",
# Needed for RecordBatch.append_column
"pyarrow>=16",
"pyproj",
"pystac",
"shapely",
"orjson",
]

[tool.hatch.version]
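
The floor is raised because `RecordBatch.append_column`, used below in `_from_arrow.py`, is only available on record batches from pyarrow 16 onward; on older releases it is a `Table`-only method. A small sketch of the call in isolation:

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": ["a", "b"]})
# RecordBatch.append_column requires pyarrow >= 16
batch = batch.append_column("count", pa.array([1, 2]))
print(batch.column_names)  # ['id', 'count']
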
225 changes: 111 additions & 114 deletions stac_geoparquet/arrow/_from_arrow.py
@@ -1,6 +1,6 @@
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""

import json
import orjson
import operator
import os
from functools import reduce
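
The new `stac_batch_to_items` below walks each nested geometry path with `pyarrow.compute.struct_field` and decodes the WKB bytes with `shapely.from_wkb`. A toy illustration of that pattern on a single-row struct column (not code from this repository):

import pyarrow as pa
import pyarrow.compute as pc
import shapely

wkb = shapely.to_wkb(shapely.Point(1.0, 2.0))
props = pa.array(
    [{"proj:geometry": wkb}],
    type=pa.struct([("proj:geometry", pa.binary())]),
)
geom_wkb = pc.struct_field(props, "proj:geometry")  # pull out the nested binary field
geoms = shapely.from_wkb(geom_wkb)                  # array of shapely geometries
print(geoms[0].__geo_interface__)                   # {'type': 'Point', 'coordinates': (1.0, 2.0)}
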
@@ -14,74 +14,76 @@
import shapely.geometry


def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
with open(dest, "w") as f:
for item_dict in stac_table_to_items(table):
json.dump(item_dict, f, separators=(",", ":"))
f.write("\n")


def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
table = _undo_stac_table_transformations(table)

def stac_batch_to_items(batch: pa.RecordBatch) -> Iterable[dict]:
"""Convert a stac arrow recordbatch to item dicts."""
batch = _undo_stac_transformations(batch)
# Find all paths in the schema that have a WKB geometry
geometry_paths = [["geometry"]]
try:
table.schema.field("properties").type.field("proj:geometry")
batch.schema.field("properties").type.field("proj:geometry")
geometry_paths.append(["properties", "proj:geometry"])
except KeyError:
pass

assets_struct = table.schema.field("assets").type
assets_struct = batch.schema.field("assets").type
for asset_idx in range(assets_struct.num_fields):
asset_field = assets_struct.field(asset_idx)
if "proj:geometry" in pa.schema(asset_field).names:
geometry_paths.append(["assets", asset_field.name, "proj:geometry"])

# Convert each geometry column to a Shapely geometry, and then assign the
# geojson geometry when converting each row to a dictionary.
geometries: List[NDArray[np.object_]] = []
for geometry_path in geometry_paths:
col = batch
for path_segment in geometry_path:
if isinstance(col, pa.RecordBatch):
col = col[path_segment]
elif pa.types.is_struct(col.type):
col = pc.struct_field(col, path_segment)
else:
raise AssertionError(f"unexpected type {type(col)}")

geometries.append(shapely.from_wkb(col))

struct_batch = batch.to_struct_array()
for row_idx in range(len(struct_batch)):
row_dict = struct_batch[row_idx].as_py()
for geometry_path, geometry_column in zip(geometry_paths, geometries):
geojson_g = geometry_column[row_idx].__geo_interface__
geojson_g["coordinates"] = convert_tuples_to_lists(geojson_g["coordinates"])
set_by_path(row_dict, geometry_path, geojson_g)

yield row_dict


def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
with open(dest, "wb") as f:
for item_dict in stac_table_to_items(table):
f.write(orjson.dumps(item_dict))
f.write(b"\n")


def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
for batch in table.to_batches():
# Convert each geometry column to a Shapely geometry, and then assign the
# geojson geometry when converting each row to a dictionary.
geometries: List[NDArray[np.object_]] = []
for geometry_path in geometry_paths:
col = batch
for path_segment in geometry_path:
if isinstance(col, pa.RecordBatch):
col = col[path_segment]
elif pa.types.is_struct(col.type):
col = pc.struct_field(col, path_segment)
else:
raise AssertionError(f"unexpected type {type(col)}")

geometries.append(shapely.from_wkb(col))

struct_batch = batch.to_struct_array()
for row_idx in range(len(struct_batch)):
row_dict = struct_batch[row_idx].as_py()
for geometry_path, geometry_column in zip(geometry_paths, geometries):
geojson_g = geometry_column[row_idx].__geo_interface__
geojson_g["coordinates"] = convert_tuples_to_lists(
geojson_g["coordinates"]
)
set_by_path(row_dict, geometry_path, geojson_g)

yield row_dict


def _undo_stac_table_transformations(table: pa.Table) -> pa.Table:
yield from stac_batch_to_items(batch)


def _undo_stac_transformations(batch: pa.RecordBatch) -> pa.RecordBatch:
"""Undo the transformations done to convert STAC Json into an Arrow Table
Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation,
as that is easier to do when converting each item in the table to a dict.
"""
table = _convert_timestamp_columns_to_string(table)
table = _lower_properties_from_top_level(table)
table = _convert_bbox_to_array(table)
return table
batch = _convert_timestamp_columns_to_string(batch)
batch = _lower_properties_from_top_level(batch)
batch = _convert_bbox_to_array(batch)
return batch


def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch:
"""Convert any datetime columns in the table to a string representation"""
allowed_column_names = {
"datetime", # common metadata
@@ -95,18 +97,18 @@ def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
}
for column_name in allowed_column_names:
try:
column = table[column_name]
column = batch[column_name]
except KeyError:
continue

table = table.drop(column_name).append_column(
batch = batch.drop_columns((column_name,)).append_column(
column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ")
)

return table
return batch


def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:
def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch:
"""Take properties columns from the top level and wrap them in a struct column"""
stac_top_level_keys = {
"stac_version",
@@ -122,78 +124,73 @@ def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:

properties_column_names: List[str] = []
properties_column_fields: List[pa.Field] = []
for column_idx in range(table.num_columns):
column_name = table.column_names[column_idx]
for column_idx in range(batch.num_columns):
column_name = batch.column_names[column_idx]
if column_name in stac_top_level_keys:
continue

properties_column_names.append(column_name)
properties_column_fields.append(table.schema.field(column_idx))
properties_column_fields.append(batch.schema.field(column_idx))

properties_array_chunks = []
for batch in table.select(properties_column_names).to_batches():
struct_arr = pa.StructArray.from_arrays(
batch.columns, fields=properties_column_fields
)
properties_array_chunks.append(struct_arr)
struct_arr = pa.StructArray.from_arrays(
batch.select(properties_column_names).columns, fields=properties_column_fields
)

return table.drop_columns(properties_column_names).append_column(
"properties", pa.chunked_array(properties_array_chunks)
return batch.drop_columns(properties_column_names).append_column(
"properties", struct_arr
)


def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch:
"""Convert the struct bbox column back to a list column for writing to JSON"""

bbox_col_idx = table.schema.get_field_index("bbox")
bbox_col = table.column(bbox_col_idx)

new_chunks = []
for chunk in bbox_col.chunks:
assert pa.types.is_struct(chunk.type)

if bbox_col.type.num_fields == 4:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)

elif bbox_col.type.num_fields == 6:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
zmin = chunk.field("zmin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
zmax = chunk.field("zmax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)

else:
raise ValueError("Expected 4 or 6 fields in bbox struct.")

new_chunks.append(list_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
bbox_col_idx = batch.schema.get_field_index("bbox")
bbox_col = batch.column(bbox_col_idx)

# new_chunks = []
# for chunk in bbox_col.chunks:
assert pa.types.is_struct(bbox_col.type)

if bbox_col.type.num_fields == 4:
xmin = bbox_col.field("xmin").to_numpy()
ymin = bbox_col.field("ymin").to_numpy()
xmax = bbox_col.field("xmax").to_numpy()
ymax = bbox_col.field("ymax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)

elif bbox_col.type.num_fields == 6:
xmin = bbox_col.field("xmin").to_numpy()
ymin = bbox_col.field("ymin").to_numpy()
zmin = bbox_col.field("zmin").to_numpy()
xmax = bbox_col.field("xmax").to_numpy()
ymax = bbox_col.field("ymax").to_numpy()
zmax = bbox_col.field("zmax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)

else:
raise ValueError("Expected 4 or 6 fields in bbox struct.")

return batch.set_column(bbox_col_idx, "bbox", list_arr)


def convert_tuples_to_lists(t: List | Tuple) -> List[Any]:
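
The bbox conversion above interleaves the struct fields with `np.column_stack` and then reinterprets the flattened, row-major buffer as fixed-size lists. A toy example of the same trick with two 2D boxes:

import numpy as np
import pyarrow as pa

xmin = np.array([0.0, 10.0])
ymin = np.array([1.0, 11.0])
xmax = np.array([2.0, 12.0])
ymax = np.array([3.0, 13.0])
coords = np.column_stack([xmin, ymin, xmax, ymax])  # shape (2, 4), one row per bbox
bbox = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
print(bbox.to_pylist())  # [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]
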
35 changes: 10 additions & 25 deletions stac_geoparquet/arrow/_schema/models.py
@@ -1,10 +1,10 @@
import json
from pathlib import Path
from typing import Any, Dict, Iterable, Sequence, Union
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import pyarrow as pa

from stac_geoparquet.arrow._util import stac_items_to_arrow
from stac_geoparquet.json_reader import read_json_chunked


class InferredSchema:
@@ -25,40 +25,25 @@ def __init__(self) -> None:
self.inner = pa.schema([])
self.count = 0

def update_from_ndjson(
def update_from_json(
self,
path: Union[Union[str, Path], Iterable[Union[str, Path]]],
path: Union[str, Path, Iterable[Union[str, Path]]],
*,
chunk_size: int = 65536,
limit: Optional[int] = None,
) -> None:
"""
Update this inferred schema from one or more newline-delimited JSON STAC files.
Args:
path: One or more paths to files with STAC items.
chunk_size: The chunk size to load into memory at a time. Defaults to 65536.
"""
# Handle multi-path input
if not isinstance(path, (str, Path)):
for p in path:
self.update_from_ndjson(p)

return

# Handle single-path input
with open(path) as f:
items = []
for line in f:
item = json.loads(line)
items.append(item)
if len(items) >= chunk_size:
self.update_from_items(items)
items = []

# Handle remainder
if len(items) > 0:
self.update_from_items(items)
Other args:
limit: The maximum number of JSON Items to use for schema inference
"""
for batch in read_json_chunked(path, chunk_size=chunk_size, limit=limit):
self.update_from_items(batch)

def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None:
"""Update this inferred schema from a sequence of STAC Items."""
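
A hedged usage sketch of the updated API: `update_from_json` now accepts a single path, or an iterable of paths, to JSON or newline-delimited JSON files, and `limit` caps how many items are scanned during inference (the file name here is hypothetical):

from stac_geoparquet.arrow._schema.models import InferredSchema

inferred = InferredSchema()
inferred.update_from_json("items.ndjson", chunk_size=8192, limit=10_000)
print(inferred.inner)  # the accumulated pyarrow.Schema
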