diff --git a/README.md b/README.md index e8faa7a..f9859de 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,12 @@ Note that `stac_geoparquet` lifts the keys in the item `properties` up to the to >>> import requests >>> import stac_geoparquet.arrow >>> import pyarrow.parquet +>>> import pyarrow as pa >>> items = requests.get( ... "https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items" ... ).json()["features"] ->>> table = stac_geoparquet.arrow.parse_stac_items_to_arrow(items) +>>> table = pa.Table.from_batches(stac_geoparquet.arrow.parse_stac_items_to_arrow(items)) >>> stac_geoparquet.arrow.to_parquet(table, "items.parquet") >>> table2 = pyarrow.parquet.read_table("items.parquet") >>> items2 = list(stac_geoparquet.arrow.stac_table_to_items(table2)) diff --git a/pyproject.toml b/pyproject.toml index 0003844..319914c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,10 +15,12 @@ dependencies = [ "geopandas", "packaging", "pandas", - "pyarrow", + # Needed for RecordBatch.append_column + "pyarrow>=16", "pyproj", "pystac", "shapely", + "orjson", ] [tool.hatch.version] diff --git a/stac_geoparquet/arrow/_from_arrow.py b/stac_geoparquet/arrow/_from_arrow.py index ec717ae..c06b0ad 100644 --- a/stac_geoparquet/arrow/_from_arrow.py +++ b/stac_geoparquet/arrow/_from_arrow.py @@ -1,6 +1,6 @@ """Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" -import json +import orjson import operator import os from functools import reduce @@ -14,74 +14,76 @@ import shapely.geometry -def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: - """Write a STAC Table to a newline-delimited JSON file.""" - with open(dest, "w") as f: - for item_dict in stac_table_to_items(table): - json.dump(item_dict, f, separators=(",", ":")) - f.write("\n") - - -def stac_table_to_items(table: pa.Table) -> Iterable[dict]: - """Convert a STAC Table to a generator of STAC Item `dict`s""" - table = _undo_stac_table_transformations(table) - +def stac_batch_to_items(batch: pa.RecordBatch) -> Iterable[dict]: + """Convert a stac arrow recordbatch to item dicts.""" + batch = _undo_stac_transformations(batch) # Find all paths in the schema that have a WKB geometry geometry_paths = [["geometry"]] try: - table.schema.field("properties").type.field("proj:geometry") + batch.schema.field("properties").type.field("proj:geometry") geometry_paths.append(["properties", "proj:geometry"]) except KeyError: pass - assets_struct = table.schema.field("assets").type + assets_struct = batch.schema.field("assets").type for asset_idx in range(assets_struct.num_fields): asset_field = assets_struct.field(asset_idx) if "proj:geometry" in pa.schema(asset_field).names: geometry_paths.append(["assets", asset_field.name, "proj:geometry"]) + # Convert each geometry column to a Shapely geometry, and then assign the + # geojson geometry when converting each row to a dictionary. + geometries: List[NDArray[np.object_]] = [] + for geometry_path in geometry_paths: + col = batch + for path_segment in geometry_path: + if isinstance(col, pa.RecordBatch): + col = col[path_segment] + elif pa.types.is_struct(col.type): + col = pc.struct_field(col, path_segment) + else: + raise AssertionError(f"unexpected type {type(col)}") + + geometries.append(shapely.from_wkb(col)) + + struct_batch = batch.to_struct_array() + for row_idx in range(len(struct_batch)): + row_dict = struct_batch[row_idx].as_py() + for geometry_path, geometry_column in zip(geometry_paths, geometries): + geojson_g = geometry_column[row_idx].__geo_interface__ + geojson_g["coordinates"] = convert_tuples_to_lists(geojson_g["coordinates"]) + set_by_path(row_dict, geometry_path, geojson_g) + + yield row_dict + + +def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: + """Write a STAC Table to a newline-delimited JSON file.""" + with open(dest, "wb") as f: + for item_dict in stac_table_to_items(table): + f.write(orjson.dumps(item_dict)) + f.write(b"\n") + + +def stac_table_to_items(table: pa.Table) -> Iterable[dict]: + """Convert a STAC Table to a generator of STAC Item `dict`s""" for batch in table.to_batches(): - # Convert each geometry column to a Shapely geometry, and then assign the - # geojson geometry when converting each row to a dictionary. - geometries: List[NDArray[np.object_]] = [] - for geometry_path in geometry_paths: - col = batch - for path_segment in geometry_path: - if isinstance(col, pa.RecordBatch): - col = col[path_segment] - elif pa.types.is_struct(col.type): - col = pc.struct_field(col, path_segment) - else: - raise AssertionError(f"unexpected type {type(col)}") - - geometries.append(shapely.from_wkb(col)) - - struct_batch = batch.to_struct_array() - for row_idx in range(len(struct_batch)): - row_dict = struct_batch[row_idx].as_py() - for geometry_path, geometry_column in zip(geometry_paths, geometries): - geojson_g = geometry_column[row_idx].__geo_interface__ - geojson_g["coordinates"] = convert_tuples_to_lists( - geojson_g["coordinates"] - ) - set_by_path(row_dict, geometry_path, geojson_g) - - yield row_dict - - -def _undo_stac_table_transformations(table: pa.Table) -> pa.Table: + yield from stac_batch_to_items(batch) + + +def _undo_stac_transformations(batch: pa.RecordBatch) -> pa.RecordBatch: """Undo the transformations done to convert STAC Json into an Arrow Table Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, as that is easier to do when converting each item in the table to a dict. """ - table = _convert_timestamp_columns_to_string(table) - table = _lower_properties_from_top_level(table) - table = _convert_bbox_to_array(table) - return table + batch = _convert_timestamp_columns_to_string(batch) + batch = _lower_properties_from_top_level(batch) + batch = _convert_bbox_to_array(batch) + return batch -def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table: +def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert any datetime columns in the table to a string representation""" allowed_column_names = { "datetime", # common metadata @@ -95,18 +97,18 @@ def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table: } for column_name in allowed_column_names: try: - column = table[column_name] + column = batch[column_name] except KeyError: continue - table = table.drop(column_name).append_column( + batch = batch.drop_columns((column_name,)).append_column( column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ") ) - return table + return batch -def _lower_properties_from_top_level(table: pa.Table) -> pa.Table: +def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: """Take properties columns from the top level and wrap them in a struct column""" stac_top_level_keys = { "stac_version", @@ -122,78 +124,73 @@ def _lower_properties_from_top_level(table: pa.Table) -> pa.Table: properties_column_names: List[str] = [] properties_column_fields: List[pa.Field] = [] - for column_idx in range(table.num_columns): - column_name = table.column_names[column_idx] + for column_idx in range(batch.num_columns): + column_name = batch.column_names[column_idx] if column_name in stac_top_level_keys: continue properties_column_names.append(column_name) - properties_column_fields.append(table.schema.field(column_idx)) + properties_column_fields.append(batch.schema.field(column_idx)) - properties_array_chunks = [] - for batch in table.select(properties_column_names).to_batches(): - struct_arr = pa.StructArray.from_arrays( - batch.columns, fields=properties_column_fields - ) - properties_array_chunks.append(struct_arr) + struct_arr = pa.StructArray.from_arrays( + batch.select(properties_column_names).columns, fields=properties_column_fields + ) - return table.drop_columns(properties_column_names).append_column( - "properties", pa.chunked_array(properties_array_chunks) + return batch.drop_columns(properties_column_names).append_column( + "properties", struct_arr ) -def _convert_bbox_to_array(table: pa.Table) -> pa.Table: +def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert the struct bbox column back to a list column for writing to JSON""" - bbox_col_idx = table.schema.get_field_index("bbox") - bbox_col = table.column(bbox_col_idx) - - new_chunks = [] - for chunk in bbox_col.chunks: - assert pa.types.is_struct(chunk.type) - - if bbox_col.type.num_fields == 4: - xmin = chunk.field("xmin").to_numpy() - ymin = chunk.field("ymin").to_numpy() - xmax = chunk.field("xmax").to_numpy() - ymax = chunk.field("ymax").to_numpy() - coords = np.column_stack( - [ - xmin, - ymin, - xmax, - ymax, - ] - ) - - list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4) - - elif bbox_col.type.num_fields == 6: - xmin = chunk.field("xmin").to_numpy() - ymin = chunk.field("ymin").to_numpy() - zmin = chunk.field("zmin").to_numpy() - xmax = chunk.field("xmax").to_numpy() - ymax = chunk.field("ymax").to_numpy() - zmax = chunk.field("zmax").to_numpy() - coords = np.column_stack( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ] - ) - - list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6) - - else: - raise ValueError("Expected 4 or 6 fields in bbox struct.") - - new_chunks.append(list_arr) - - return table.set_column(bbox_col_idx, "bbox", new_chunks) + bbox_col_idx = batch.schema.get_field_index("bbox") + bbox_col = batch.column(bbox_col_idx) + + # new_chunks = [] + # for chunk in bbox_col.chunks: + assert pa.types.is_struct(bbox_col.type) + + if bbox_col.type.num_fields == 4: + xmin = bbox_col.field("xmin").to_numpy() + ymin = bbox_col.field("ymin").to_numpy() + xmax = bbox_col.field("xmax").to_numpy() + ymax = bbox_col.field("ymax").to_numpy() + coords = np.column_stack( + [ + xmin, + ymin, + xmax, + ymax, + ] + ) + + list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4) + + elif bbox_col.type.num_fields == 6: + xmin = bbox_col.field("xmin").to_numpy() + ymin = bbox_col.field("ymin").to_numpy() + zmin = bbox_col.field("zmin").to_numpy() + xmax = bbox_col.field("xmax").to_numpy() + ymax = bbox_col.field("ymax").to_numpy() + zmax = bbox_col.field("zmax").to_numpy() + coords = np.column_stack( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ] + ) + + list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6) + + else: + raise ValueError("Expected 4 or 6 fields in bbox struct.") + + return batch.set_column(bbox_col_idx, "bbox", list_arr) def convert_tuples_to_lists(t: List | Tuple) -> List[Any]: diff --git a/stac_geoparquet/arrow/_schema/models.py b/stac_geoparquet/arrow/_schema/models.py index c3c0ecb..06fcbd2 100644 --- a/stac_geoparquet/arrow/_schema/models.py +++ b/stac_geoparquet/arrow/_schema/models.py @@ -1,10 +1,10 @@ -import json from pathlib import Path -from typing import Any, Dict, Iterable, Sequence, Union +from typing import Any, Dict, Iterable, Optional, Sequence, Union import pyarrow as pa from stac_geoparquet.arrow._util import stac_items_to_arrow +from stac_geoparquet.json_reader import read_json_chunked class InferredSchema: @@ -25,11 +25,12 @@ def __init__(self) -> None: self.inner = pa.schema([]) self.count = 0 - def update_from_ndjson( + def update_from_json( self, - path: Union[Union[str, Path], Iterable[Union[str, Path]]], + path: Union[str, Path, Iterable[Union[str, Path]]], *, chunk_size: int = 65536, + limit: Optional[int] = None, ) -> None: """ Update this inferred schema from one or more newline-delimited JSON STAC files. @@ -37,28 +38,12 @@ def update_from_ndjson( Args: path: One or more paths to files with STAC items. chunk_size: The chunk size to load into memory at a time. Defaults to 65536. - """ - # Handle multi-path input - if not isinstance(path, (str, Path)): - for p in path: - self.update_from_ndjson(p) - - return - - # Handle single-path input - with open(path) as f: - items = [] - for line in f: - item = json.loads(line) - items.append(item) - if len(items) >= chunk_size: - self.update_from_items(items) - items = [] - - # Handle remainder - if len(items) > 0: - self.update_from_items(items) + Other args: + limit: The maximum number of JSON Items to use for schema inference + """ + for batch in read_json_chunked(path, chunk_size=chunk_size, limit=limit): + self.update_from_items(batch) def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None: """Update this inferred schema from a sequence of STAC Items.""" diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index 0918d34..b99dcec 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -1,48 +1,29 @@ """Convert STAC data into Arrow tables""" -import json -from datetime import datetime from pathlib import Path from typing import ( Any, Dict, - Generator, Iterable, Iterator, - List, Optional, - Sequence, Union, ) -import ciso8601 -import numpy as np import pyarrow as pa -import pyarrow.compute as pc -import shapely -import shapely.geometry from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.arrow._crs import WGS84_CRS_JSON -from stac_geoparquet.arrow._util import stac_items_to_arrow - - -def _chunks( - lst: Sequence[Dict[str, Any]], n: int -) -> Generator[Sequence[Dict[str, Any]], None, None]: - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] +from stac_geoparquet.json_reader import read_json_chunked +from stac_geoparquet.arrow._util import stac_items_to_arrow, batched_iter def parse_stac_items_to_arrow( - items: Sequence[Dict[str, Any]], + items: Iterable[Dict[str, Any]], *, chunk_size: int = 8192, schema: Optional[Union[pa.Schema, InferredSchema]] = None, - downcast: bool = False, -) -> pa.Table: - """Parse a collection of STAC Items to a :class:`pyarrow.Table`. +) -> Iterable[pa.RecordBatch]: + """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`. The objects under `properties` are moved up to the top-level of the Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. @@ -55,38 +36,33 @@ def parse_stac_items_to_arrow( schema: The schema of the input data. If provided, can improve memory use; otherwise all items need to be parsed into a single array for schema inference. Defaults to None. - downcast: if True, store bbox as float32 for memory and disk saving. Returns: - a pyarrow Table with the STAC-GeoParquet representation of items. + an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items. """ - if schema is not None: if isinstance(schema, InferredSchema): schema = schema.inner # If schema is provided, then for better memory usage we parse input STAC items # to Arrow batches in chunks. - batches = [] - for chunk in _chunks(items, chunk_size): - batches.append(stac_items_to_arrow(chunk, schema=schema)) + for chunk in batched_iter(items, chunk_size): + yield stac_items_to_arrow(chunk, schema=schema) - table = pa.Table.from_batches(batches, schema=schema) else: # If schema is _not_ provided, then we must convert to Arrow all at once, or # else it would be possible for a STAC item late in the collection (after the # first chunk) to have a different schema and not match the schema inferred for # the first chunk. - table = pa.Table.from_batches([stac_items_to_arrow(items)]) - - return _process_arrow_table(table, downcast=downcast) + yield stac_items_to_arrow(items) def parse_stac_ndjson_to_arrow( - path: Union[Union[str, Path], Iterable[Union[str, Path]]], + path: Union[str, Path, Iterable[Union[str, Path]]], *, chunk_size: int = 65536, - schema: Optional[Union[pa.Schema, InferredSchema]] = None, + schema: Optional[pa.Schema] = None, + limit: Optional[int] = None, ) -> Iterator[pa.RecordBatch]: """ Convert one or more newline-delimited JSON STAC files to a generator of Arrow @@ -103,286 +79,24 @@ def parse_stac_ndjson_to_arrow( In this case, there will be two full passes over the input data: one to infer a common schema across all data and another to read the data. + Other args: + limit: The maximum number of JSON Items to use for schema inference + Yields: Arrow RecordBatch with a single chunk of Item data. """ - # Define outside of if/else to make mypy happy - items: List[dict] = [] - # If the schema was not provided, then we need to load all data into memory at once # to perform schema resolution. if schema is None: inferred_schema = InferredSchema() - inferred_schema.update_from_ndjson(path, chunk_size=chunk_size) + inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit) yield from parse_stac_ndjson_to_arrow( path, chunk_size=chunk_size, schema=inferred_schema ) return - # Check if path is an iterable - # If so, recursively call this function on each item in the iterable - if not isinstance(path, (str, Path)): - for p in path: - yield from parse_stac_ndjson_to_arrow( - p, chunk_size=chunk_size, schema=schema - ) - - return - if isinstance(schema, InferredSchema): schema = schema.inner - # Otherwise, we can stream over the input, converting each batch of `chunk_size` - # into an Arrow RecordBatch at a time. This is much more memory efficient. - with open(path) as f: - for line in f: - items.append(json.loads(line)) - - if len(items) >= chunk_size: - batch = stac_items_to_arrow(items, schema=schema) - yield from _process_arrow_table( - pa.Table.from_batches([batch]), downcast=False - ).to_batches() - items = [] - - # Don't forget the last chunk in case the total number of items is not a multiple of - # chunk_size. - if len(items) > 0: - batch = stac_items_to_arrow(items, schema=schema) - yield from _process_arrow_table( - pa.Table.from_batches([batch]), downcast=False - ).to_batches() - - -def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table: - table = _bring_properties_to_top_level(table) - table = _convert_timestamp_columns(table) - table = _convert_bbox_to_struct(table, downcast=downcast) - table = _assign_geoarrow_metadata(table) - return table - - -def _bring_properties_to_top_level(table: pa.Table) -> pa.Table: - """Bring all the fields inside of the nested "properties" struct to the top level""" - properties_field = table.schema.field("properties") - properties_column = table["properties"] - - for field_idx in range(properties_field.type.num_fields): - inner_prop_field = properties_field.type.field(field_idx) - table = table.append_column( - inner_prop_field, pc.struct_field(properties_column, field_idx) - ) - - table = table.drop("properties") - return table - - -def _convert_geometry_to_wkb(table: pa.Table) -> pa.Table: - """Convert the geometry column in the table to WKB""" - geoms = shapely.from_geojson( - [json.dumps(item) for item in table["geometry"].to_pylist()] - ) - wkb_geoms = shapely.to_wkb(geoms) - return table.drop("geometry").append_column("geometry", pa.array(wkb_geoms)) - - -def _convert_timestamp_columns(table: pa.Table) -> pa.Table: - """Convert all timestamp columns from a string to an Arrow Timestamp data type""" - allowed_column_names = { - "datetime", # common metadata - "start_datetime", - "end_datetime", - "created", - "updated", - "expires", # timestamps extension - "published", - "unpublished", - } - for column_name in allowed_column_names: - try: - column = table[column_name] - except KeyError: - continue - - field_index = table.schema.get_field_index(column_name) - - if pa.types.is_timestamp(column.type): - continue - - # STAC allows datetimes to be null. If all rows are null, the column type may be - # inferred as null. We cast this to a timestamp column. - elif pa.types.is_null(column.type): - table = table.set_column( - field_index, column_name, column.cast(pa.timestamp("us")) - ) - - elif pa.types.is_string(column.type): - table = table.set_column( - field_index, column_name, _convert_timestamp_column(column) - ) - else: - raise ValueError( - f"Inferred time column '{column_name}' was expected to be a string or" - f" timestamp data type but got {column.type}" - ) - - return table - - -def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray: - """Convert an individual timestamp column from string to a Timestamp type""" - chunks = [] - for chunk in column.chunks: - parsed_chunk: List[Optional[datetime]] = [] - for item in chunk: - if not item.is_valid: - parsed_chunk.append(None) - else: - parsed_chunk.append(ciso8601.parse_rfc3339(item.as_py())) - - pyarrow_chunk = pa.array(parsed_chunk) - chunks.append(pyarrow_chunk) - - return pa.chunked_array(chunks) - - -def _is_bbox_3d(bbox_col: pa.ChunkedArray) -> bool: - """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" - offsets_set = set() - for chunk in bbox_col.chunks: - offsets = chunk.offsets.to_numpy() - offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) - - if len(offsets_set) > 1: - raise ValueError("Mixed 2d-3d bounding boxes not yet supported") - - offset = list(offsets_set)[0] - if offset == 6: - return True - elif offset == 4: - return False - else: - raise ValueError(f"Unexpected bbox offset: {offset=}") - - -def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table: - """Convert bbox column to a struct representation - - Since the bbox in JSON is stored as an array, pyarrow automatically converts the - bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox - column as a StructArray, which allows for Parquet statistics to infer any spatial - partitioning in the dataset. - - Args: - table: _description_ - downcast: if True, will use float32 coordinates for the bounding boxes instead - of float64. Float rounding is applied to ensure the float32 bounding box - strictly contains the original float64 box. This is recommended when - possible to minimize file size. - - Returns: - New table - """ - bbox_col_idx = table.schema.get_field_index("bbox") - bbox_col = table.column(bbox_col_idx) - bbox_3d = _is_bbox_3d(bbox_col) - - new_chunks = [] - for chunk in bbox_col.chunks: - assert ( - pa.types.is_list(chunk.type) - or pa.types.is_large_list(chunk.type) - or pa.types.is_fixed_size_list(chunk.type) - ) - if bbox_3d: - coords = chunk.flatten().to_numpy().reshape(-1, 6) - else: - coords = chunk.flatten().to_numpy().reshape(-1, 4) - - if downcast: - coords = coords.astype(np.float32) - - if bbox_3d: - xmin = coords[:, 0] - ymin = coords[:, 1] - zmin = coords[:, 2] - xmax = coords[:, 3] - ymax = coords[:, 4] - zmax = coords[:, 5] - - if downcast: - # Round min values down to the next float32 value - # Round max values up to the next float32 value - xmin = np.nextafter(xmin, -np.Infinity) - ymin = np.nextafter(ymin, -np.Infinity) - zmin = np.nextafter(zmin, -np.Infinity) - xmax = np.nextafter(xmax, np.Infinity) - ymax = np.nextafter(ymax, np.Infinity) - zmax = np.nextafter(zmax, np.Infinity) - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ], - names=[ - "xmin", - "ymin", - "zmin", - "xmax", - "ymax", - "zmax", - ], - ) - - else: - xmin = coords[:, 0] - ymin = coords[:, 1] - xmax = coords[:, 2] - ymax = coords[:, 3] - - if downcast: - # Round min values down to the next float32 value - # Round max values up to the next float32 value - xmin = np.nextafter(xmin, -np.Infinity) - ymin = np.nextafter(ymin, -np.Infinity) - xmax = np.nextafter(xmax, np.Infinity) - ymax = np.nextafter(ymax, np.Infinity) - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - xmax, - ymax, - ], - names=[ - "xmin", - "ymin", - "xmax", - "ymax", - ], - ) - - new_chunks.append(struct_arr) - - return table.set_column(bbox_col_idx, "bbox", new_chunks) - - -def _assign_geoarrow_metadata(table: pa.Table) -> pa.Table: - """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" - existing_field_idx = table.schema.get_field_index("geometry") - existing_field = table.schema.field(existing_field_idx) - ext_metadata = {"crs": WGS84_CRS_JSON} - field_metadata = { - b"ARROW:extension:name": b"geoarrow.wkb", - b"ARROW:extension:metadata": json.dumps(ext_metadata).encode("utf-8"), - } - new_field = existing_field.with_metadata(field_metadata) - return table.set_column( - existing_field_idx, new_field, table.column(existing_field_idx) - ) + for batch in read_json_chunked(path, chunk_size=chunk_size): + yield stac_items_to_arrow(batch, schema=schema) diff --git a/stac_geoparquet/arrow/_to_parquet.py b/stac_geoparquet/arrow/_to_parquet.py index 94da692..294e216 100644 --- a/stac_geoparquet/arrow/_to_parquet.py +++ b/stac_geoparquet/arrow/_to_parquet.py @@ -11,7 +11,7 @@ def parse_stac_ndjson_to_parquet( - input_path: Union[Union[str, Path], Iterable[Union[str, Path]]], + input_path: Union[str, Path, Iterable[Union[str, Path]]], output_path: Union[str, Path], *, chunk_size: int = 65536, @@ -30,6 +30,7 @@ def parse_stac_ndjson_to_parquet( infer a common schema across all data and another to read the data and iteratively convert to GeoParquet. """ + batches_iter = parse_stac_ndjson_to_arrow( input_path, chunk_size=chunk_size, schema=schema ) diff --git a/stac_geoparquet/arrow/_util.py b/stac_geoparquet/arrow/_util.py index 8714064..5390af5 100644 --- a/stac_geoparquet/arrow/_util.py +++ b/stac_geoparquet/arrow/_util.py @@ -1,13 +1,49 @@ from copy import deepcopy -from typing import Any, Dict, Optional, Sequence +from typing import ( + Any, + Dict, + Iterable, + Optional, + Sequence, +) +import ciso8601 +import numpy as np import pyarrow as pa +import pyarrow.compute as pc import shapely import shapely.geometry +import orjson +from itertools import islice + +from stac_geoparquet.arrow._crs import WGS84_CRS_JSON + + +def update_batch_schema( + batch: pa.RecordBatch, + schema: pa.Schema, +) -> pa.RecordBatch: + """Update a batch with new schema.""" + return pa.record_batch(batch.to_pydict(), schema=schema) + + +def batched_iter( + lst: Iterable[Dict[str, Any]], n: int, *, limit: Optional[int] = None +) -> Iterable[Sequence[Dict[str, Any]]]: + """Yield successive n-sized chunks from iterable.""" + if n < 1: + raise ValueError("n must be at least one") + it = iter(lst) + count = 0 + while batch := tuple(islice(it, n)): + yield batch + count += len(batch) + if limit and count >= limit: + return def stac_items_to_arrow( - items: Sequence[Dict[str, Any]], *, schema: Optional[pa.Schema] = None + items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None ) -> pa.RecordBatch: """Convert dicts representing STAC Items to Arrow @@ -60,4 +96,214 @@ def stac_items_to_arrow( else: array = pa.array(wkb_items) - return pa.RecordBatch.from_struct_array(array) + return _process_arrow_batch(pa.RecordBatch.from_struct_array(array)) + + +def _bring_properties_to_top_level( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Bring all the fields inside of the nested "properties" struct to the top level""" + properties_field = batch.schema.field("properties") + properties_column = batch["properties"] + + for field_idx in range(properties_field.type.num_fields): + inner_prop_field = properties_field.type.field(field_idx) + batch = batch.append_column( + inner_prop_field, pc.struct_field(properties_column, field_idx) + ) + + batch = batch.drop_columns( + [ + "properties", + ] + ) + return batch + + +def _convert_geometry_to_wkb( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Convert the geometry column in the table to WKB""" + geoms = shapely.from_geojson( + [orjson.dumps(item) for item in batch["geometry"].to_pylist()] + ) + wkb_geoms = shapely.to_wkb(geoms) + return batch.drop_columns( + [ + "geometry", + ] + ).append_column("geometry", pa.array(wkb_geoms)) + + +def _convert_timestamp_columns( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Convert all timestamp columns from a string to an Arrow Timestamp data type""" + allowed_column_names = { + "datetime", # common metadata + "start_datetime", + "end_datetime", + "created", + "updated", + "expires", # timestamps extension + "published", + "unpublished", + } + for column_name in allowed_column_names: + try: + column = batch[column_name] + except KeyError: + continue + + field_index = batch.schema.get_field_index(column_name) + + if pa.types.is_timestamp(column.type): + continue + + # STAC allows datetimes to be null. If all rows are null, the column type may be + # inferred as null. We cast this to a timestamp column. + elif pa.types.is_null(column.type): + batch = batch.set_column( + field_index, column_name, column.cast(pa.timestamp("us")) + ) + + elif pa.types.is_string(column.type): + batch = batch.set_column( + field_index, column_name, _convert_timestamp_column(column) + ) + else: + raise ValueError( + f"Inferred time column '{column_name}' was expected to be a string or" + f" timestamp data type but got {column.type}" + ) + + return batch + + +def _convert_timestamp_column(column: pa.Array) -> pa.TimestampArray: + """Convert an individual timestamp column from string to a Timestamp type""" + return pa.array( + [ciso8601.parse_rfc3339(str(t)) for t in column], pa.timestamp("us", tz="UTC") + ) + + +def _is_bbox_3d(bbox_col: pa.Array) -> bool: + """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" + offsets_set = set() + offsets = bbox_col.offsets.to_numpy() + offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) + + if len(offsets_set) > 1: + raise ValueError("Mixed 2d-3d bounding boxes not yet supported") + + offset = list(offsets_set)[0] + if offset == 6: + return True + elif offset == 4: + return False + else: + raise ValueError(f"Unexpected bbox offset: {offset=}") + + +def _convert_bbox_to_struct(batch: pa.RecordBatch) -> pa.RecordBatch: + """Convert bbox column to a struct representation + + Since the bbox in JSON is stored as an array, pyarrow automatically converts the + bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox + column as a StructArray, which allows for Parquet statistics to infer any spatial + partitioning in the dataset. + + Args: + batch: _description_ + + Returns: + New record batch + """ + bbox_col_idx = batch.schema.get_field_index("bbox") + bbox_col = batch.column(bbox_col_idx) + bbox_3d = _is_bbox_3d(bbox_col) + + assert ( + pa.types.is_list(bbox_col.type) + or pa.types.is_large_list(bbox_col.type) + or pa.types.is_fixed_size_list(bbox_col.type) + ) + if bbox_3d: + coords = bbox_col.flatten().to_numpy().reshape(-1, 6) + else: + coords = bbox_col.flatten().to_numpy().reshape(-1, 4) + + if bbox_3d: + xmin = coords[:, 0] + ymin = coords[:, 1] + zmin = coords[:, 2] + xmax = coords[:, 3] + ymax = coords[:, 4] + zmax = coords[:, 5] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ], + names=[ + "xmin", + "ymin", + "zmin", + "xmax", + "ymax", + "zmax", + ], + ) + + else: + xmin = coords[:, 0] + ymin = coords[:, 1] + xmax = coords[:, 2] + ymax = coords[:, 3] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + xmax, + ymax, + ], + names=[ + "xmin", + "ymin", + "xmax", + "ymax", + ], + ) + + return batch.set_column(bbox_col_idx, "bbox", struct_arr) + + +def _assign_geoarrow_metadata( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" + existing_field_idx = batch.schema.get_field_index("geometry") + existing_field = batch.schema.field(existing_field_idx) + ext_metadata = {"crs": WGS84_CRS_JSON} + field_metadata = { + b"ARROW:extension:name": b"geoarrow.wkb", + b"ARROW:extension:metadata": orjson.dumps(ext_metadata), + } + new_field = existing_field.with_metadata(field_metadata) + return batch.set_column( + existing_field_idx, new_field, batch.column(existing_field_idx) + ) + + +def _process_arrow_batch(batch: pa.RecordBatch) -> pa.RecordBatch: + batch = _bring_properties_to_top_level(batch) + batch = _convert_timestamp_columns(batch) + batch = _convert_bbox_to_struct(batch) + batch = _assign_geoarrow_metadata(batch) + return batch diff --git a/stac_geoparquet/json_reader.py b/stac_geoparquet/json_reader.py new file mode 100644 index 0000000..62589d7 --- /dev/null +++ b/stac_geoparquet/json_reader.py @@ -0,0 +1,48 @@ +"""Return an iterator of items from an ndjson, a json array of items, or a featurecollection of items.""" + +from pathlib import Path +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import orjson + +from stac_geoparquet.arrow._util import batched_iter + + +def read_json( + path: Union[str, Path, Iterable[Union[str, Path]]], +) -> Iterable[Dict[str, Any]]: + """Read a json or ndjson file.""" + if isinstance(path, (str, Path)): + path = [path] + + for p in path: + with open(p) as f: + try: + # Support ndjson or json list/FeatureCollection without any whitespace + # (all on first line) + for line in f: + item = orjson.loads(line.strip()) + if isinstance(item, list): + yield from item + elif "features" in item: + yield from item["features"] + else: + yield item + except orjson.JSONDecodeError: + f.seek(0) + # read full json file as either a list or FeatureCollection + json = orjson.loads(f.read()) + if isinstance(json, list): + yield from json + else: + yield from json["features"] + + +def read_json_chunked( + path: Union[str, Path, Iterable[Union[str, Path]]], + chunk_size: int, + *, + limit: Optional[int] = None, +) -> Iterable[Sequence[Dict[str, Any]]]: + """Read from a JSON or NDJSON file in chunks of `chunk_size`.""" + return batched_iter(read_json(path), chunk_size, limit=limit) diff --git a/tests/test_arrow.py b/tests/test_arrow.py index e3f1291..2b9bca4 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Dict, Sequence, Union +import pyarrow as pa import pytest from ciso8601 import parse_rfc3339 @@ -197,13 +198,7 @@ def test_round_trip(collection_id: str): with open(HERE / "data" / f"{collection_id}-pc.json") as f: items = json.load(f) - table = parse_stac_items_to_arrow(items, downcast=True) - items_result = list(stac_table_to_items(table)) - - for result, expected in zip(items_result, items): - assert_json_value_equal(result, expected, precision=0.001) - - table = parse_stac_items_to_arrow(items, downcast=False) + table = pa.Table.from_batches(parse_stac_items_to_arrow(items)) items_result = list(stac_table_to_items(table)) for result, expected in zip(items_result, items): @@ -215,7 +210,7 @@ def test_table_contains_geoarrow_metadata(): with open(HERE / "data" / f"{collection_id}-pc.json") as f: items = json.load(f) - table = parse_stac_items_to_arrow(items) + table = pa.Table.from_batches(parse_stac_items_to_arrow(items)) field_meta = table.schema.field("geometry").metadata assert field_meta[b"ARROW:extension:name"] == b"geoarrow.wkb" assert json.loads(field_meta[b"ARROW:extension:metadata"])["crs"]["id"] == {