Exhaustive schema inference (#50)
* exhaustive schema inference

* Update docstrings

* fix dict accessor

* fix setter

* Fix Arrow -> STAC conversion for nested proj:geometry

* fix types

* fix circular import
kylebarron authored May 15, 2024
1 parent 61282d1 commit 2eaf22d
Showing 6 changed files with 345 additions and 84 deletions.
3 changes: 3 additions & 0 deletions stac_geoparquet/arrow/_crs.py
@@ -0,0 +1,3 @@
from pyproj import CRS

WGS84_CRS_JSON = CRS.from_epsg(4326).to_json_dict()
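For reference, a quick sketch of what this constant holds (assuming pyproj's PROJJSON output; only two keys shown):

```py
from pyproj import CRS

crs_json = CRS.from_epsg(4326).to_json_dict()
print(crs_json["type"])  # "GeographicCRS"
print(crs_json["id"])    # {"authority": "EPSG", "code": 4326}
```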
107 changes: 95 additions & 12 deletions stac_geoparquet/arrow/_from_arrow.py
@@ -1,13 +1,17 @@
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""

import os
import json
from typing import Iterable, List, Union
import operator
import os
from functools import reduce
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import shapely
from numpy.typing import NDArray
import shapely.geometry


def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
@@ -22,20 +26,46 @@ def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
table = _undo_stac_table_transformations(table)

# Convert WKB geometry column to GeoJSON, and then assign the geojson geometry when
# converting each row to a dictionary.
for batch in table.to_batches():
geoms = shapely.from_wkb(batch["geometry"])
geojson_strings = shapely.to_geojson(geoms)
# Find all paths in the schema that have a WKB geometry
geometry_paths = [["geometry"]]
try:
table.schema.field("properties").type.field("proj:geometry")
geometry_paths.append(["properties", "proj:geometry"])
except KeyError:
pass

# RecordBatch is missing a `drop()` method, so we keep all columns other than
# geometry instead
keep_column_names = [name for name in batch.column_names if name != "geometry"]
struct_batch = batch.select(keep_column_names).to_struct_array()
assets_struct = table.schema.field("assets").type
for asset_idx in range(assets_struct.num_fields):
asset_field = assets_struct.field(asset_idx)
if "proj:geometry" in pa.schema(asset_field).names:
geometry_paths.append(["assets", asset_field.name, "proj:geometry"])

for batch in table.to_batches():
# Convert each geometry column to a Shapely geometry, and then assign the
# geojson geometry when converting each row to a dictionary.
geometries: List[NDArray[np.object_]] = []
for geometry_path in geometry_paths:
col = batch
for path_segment in geometry_path:
if isinstance(col, pa.RecordBatch):
col = col[path_segment]
elif pa.types.is_struct(col.type):
col = pc.struct_field(col, path_segment)
else:
raise AssertionError(f"unexpected type {type(col)}")

geometries.append(shapely.from_wkb(col))

struct_batch = batch.to_struct_array()
for row_idx in range(len(struct_batch)):
row_dict = struct_batch[row_idx].as_py()
row_dict["geometry"] = json.loads(geojson_strings[row_idx])
for geometry_path, geometry_column in zip(geometry_paths, geometries):
geojson_g = geometry_column[row_idx].__geo_interface__
geojson_g["coordinates"] = convert_tuples_to_lists(
geojson_g["coordinates"]
)
set_by_path(row_dict, geometry_path, geojson_g)

yield row_dict
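The path walk above descends through struct columns with `pc.struct_field`. A minimal, self-contained sketch of the same traversal, using a hypothetical one-row batch:

```py
import pyarrow as pa
import pyarrow.compute as pc
import shapely

# Hypothetical batch: a top-level WKB geometry plus a nested
# properties.proj:geometry, mirroring the paths handled above.
batch = pa.record_batch(
    {
        "geometry": pa.array([shapely.to_wkb(shapely.Point(1, 2))], pa.binary()),
        "properties": pa.array(
            [{"proj:geometry": shapely.to_wkb(shapely.Point(3, 4))}],
            pa.struct({"proj:geometry": pa.binary()}),
        ),
    }
)

col = batch["properties"]                    # a StructArray
col = pc.struct_field(col, "proj:geometry")  # its nested binary child
print(shapely.from_wkb(col)[0])              # POINT (3 4)
```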


@@ -164,3 +194,56 @@ def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
        new_chunks.append(list_arr)

    return table.set_column(bbox_col_idx, "bbox", new_chunks)
+
+
+def convert_tuples_to_lists(t: List | Tuple) -> List[Any]:
+    """Convert tuples to lists, recursively
+
+    For example, converts:
+
+    ```
+    (
+        (
+            (-112.4820566, 38.1261015),
+            (-112.4816283, 38.1331311),
+            (-112.4833551, 38.1338897),
+            (-112.4832919, 38.1307687),
+            (-112.4855415, 38.1291793),
+            (-112.4820566, 38.1261015),
+        ),
+    )
+    ```
+
+    to
+
+    ```py
+    [
+        [
+            [-112.4820566, 38.1261015],
+            [-112.4816283, 38.1331311],
+            [-112.4833551, 38.1338897],
+            [-112.4832919, 38.1307687],
+            [-112.4855415, 38.1291793],
+            [-112.4820566, 38.1261015],
+        ]
+    ]
+    ```
+
+    From https://stackoverflow.com/a/1014669.
+    """
+    return list(map(convert_tuples_to_lists, t)) if isinstance(t, (list, tuple)) else t
+
+
+def get_by_path(root: Dict[str, Any], keys: Sequence[str]) -> Any:
+    """Access a nested object in root by item sequence.
+
+    From https://stackoverflow.com/a/14692747
+    """
+    return reduce(operator.getitem, keys, root)
+
+
+def set_by_path(root: Dict[str, Any], keys: Sequence[str], value: Any) -> None:
+    """Set a value in a nested object in root by item sequence.
+
+    From https://stackoverflow.com/a/14692747
+    """
+    get_by_path(root, keys[:-1])[keys[-1]] = value  # type: ignore
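A short usage sketch for these three helpers (the item dict below is hypothetical):

```py
item = {"assets": {"image": {"href": "https://example.com/img.tif"}}}

set_by_path(item, ["assets", "image", "title"], "RGB image")
assert get_by_path(item, ["assets", "image", "title"]) == "RGB image"

# __geo_interface__ returns nested tuples; convert them so the dict is
# JSON-serializable:
coords = (((-112.48, 38.12), (-112.48, 38.13)),)
assert convert_tuples_to_lists(coords) == [[[-112.48, 38.12], [-112.48, 38.13]]]
```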
70 changes: 70 additions & 0 deletions stac_geoparquet/arrow/_schema/models.py
@@ -0,0 +1,70 @@
import json
from pathlib import Path
from typing import Any, Dict, Iterable, Sequence, Union

import pyarrow as pa

from stac_geoparquet.arrow._util import stac_items_to_arrow


class InferredSchema:
    """
    A schema representing the original STAC JSON with absolutely minimal modifications.

    The only modification from the data is converting any geometry fields from GeoJSON
    to WKB.
    """

    inner: pa.Schema
    """The underlying Arrow schema."""

    count: int
    """The total number of items scanned."""

    def __init__(self) -> None:
        self.inner = pa.schema([])
        self.count = 0

    def update_from_ndjson(
        self,
        path: Union[Union[str, Path], Iterable[Union[str, Path]]],
        *,
        chunk_size: int = 65536,
    ) -> None:
        """
        Update this inferred schema from one or more newline-delimited JSON STAC files.

        Args:
            path: One or more paths to files with STAC items.
            chunk_size: The chunk size to load into memory at a time. Defaults to 65536.
        """
        # Handle multi-path input
        if not isinstance(path, (str, Path)):
            for p in path:
                # Pass chunk_size through so it isn't silently ignored for
                # multi-path input.
                self.update_from_ndjson(p, chunk_size=chunk_size)

            return

        # Handle single-path input
        with open(path) as f:
            items = []
            for line in f:
                item = json.loads(line)
                items.append(item)

                if len(items) >= chunk_size:
                    self.update_from_items(items)
                    items = []

            # Handle remainder
            if len(items) > 0:
                self.update_from_items(items)

    def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None:
        """Update this inferred schema from a sequence of STAC Items."""
        self.count += len(items)
        current_schema = stac_items_to_arrow(items, schema=None).schema
        new_schema = pa.unify_schemas(
            [self.inner, current_schema], promote_options="permissive"
        )
        self.inner = new_schema
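A hedged usage sketch (the file paths are hypothetical): scan newline-delimited STAC files in chunks to build one unified Arrow schema, then reuse it for a full parse:

```py
from stac_geoparquet.arrow._schema.models import InferredSchema

schema = InferredSchema()
schema.update_from_ndjson(["sentinel-2a.ndjson", "sentinel-2b.ndjson"])

print(schema.count)  # total number of STAC items scanned
print(schema.inner)  # the unified pyarrow.Schema across all chunks
```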