From cf0699b1b3e24b381a2a617a2ccd949a4863f856 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 4 Jun 2024 12:46:45 -0400 Subject: [PATCH] Refactor interally to `RawBatch` and `CleanBatch` wrapper types (#57) * Move json equality logic outside of test_arrow * Refactor to RawBatch and CleanBatch wrapper types * Move _from_arrow functions to _api * Update imports * fix circular import * keep deprecated api * Add write-read test and fix typing * add parquet tests * fix ci * Rename wrapper types --- stac_geoparquet/arrow/__init__.py | 10 +- stac_geoparquet/arrow/_api.py | 136 +++++++++++ stac_geoparquet/arrow/_batch.py | 190 +++++++++++++++ stac_geoparquet/arrow/_from_arrow.py | 140 +---------- stac_geoparquet/arrow/_schema/models.py | 4 +- stac_geoparquet/arrow/_to_arrow.py | 265 ++++++++++++++------- stac_geoparquet/arrow/_to_parquet.py | 4 +- stac_geoparquet/arrow/_util.py | 303 ++++-------------------- stac_geoparquet/from_arrow.py | 5 +- stac_geoparquet/to_arrow.py | 10 +- tests/__init__.py | 0 tests/json_equals.py | 167 +++++++++++++ tests/test_arrow.py | 215 +++-------------- tests/test_parquet.py | 53 +++++ 14 files changed, 833 insertions(+), 669 deletions(-) create mode 100644 stac_geoparquet/arrow/_api.py create mode 100644 stac_geoparquet/arrow/_batch.py create mode 100644 tests/__init__.py create mode 100644 tests/json_equals.py create mode 100644 tests/test_parquet.py diff --git a/stac_geoparquet/arrow/__init__.py b/stac_geoparquet/arrow/__init__.py index ee781a3..c88deb2 100644 --- a/stac_geoparquet/arrow/__init__.py +++ b/stac_geoparquet/arrow/__init__.py @@ -1,3 +1,7 @@ -from ._from_arrow import stac_table_to_items, stac_table_to_ndjson -from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow -from ._to_parquet import to_parquet +from ._api import ( + parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, + stac_table_to_items, + stac_table_to_ndjson, +) +from ._to_parquet import parse_stac_ndjson_to_parquet, to_parquet diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py new file mode 100644 index 0000000..2fee475 --- /dev/null +++ b/stac_geoparquet/arrow/_api.py @@ -0,0 +1,136 @@ +import os +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, Optional, Union + +import pyarrow as pa + +from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch +from stac_geoparquet.arrow._schema.models import InferredSchema +from stac_geoparquet.arrow._util import batched_iter +from stac_geoparquet.json_reader import read_json_chunked + + +def parse_stac_items_to_arrow( + items: Iterable[Dict[str, Any]], + *, + chunk_size: int = 8192, + schema: Optional[Union[pa.Schema, InferredSchema]] = None, +) -> Iterable[pa.RecordBatch]: + """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`. + + The objects under `properties` are moved up to the top-level of the + Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. + + Args: + items: the STAC Items to convert + chunk_size: The chunk size to use for Arrow record batches. This only takes + effect if `schema` is not None. When `schema` is None, the input will be + parsed into a single contiguous record batch. Defaults to 8192. + schema: The schema of the input data. If provided, can improve memory use; + otherwise all items need to be parsed into a single array for schema + inference. Defaults to None. + + Returns: + an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items. 
+ """ + if schema is not None: + if isinstance(schema, InferredSchema): + schema = schema.inner + + # If schema is provided, then for better memory usage we parse input STAC items + # to Arrow batches in chunks. + for chunk in batched_iter(items, chunk_size): + yield stac_items_to_arrow(chunk, schema=schema) + + else: + # If schema is _not_ provided, then we must convert to Arrow all at once, or + # else it would be possible for a STAC item late in the collection (after the + # first chunk) to have a different schema and not match the schema inferred for + # the first chunk. + yield stac_items_to_arrow(items) + + +def parse_stac_ndjson_to_arrow( + path: Union[str, Path, Iterable[Union[str, Path]]], + *, + chunk_size: int = 65536, + schema: Optional[pa.Schema] = None, + limit: Optional[int] = None, +) -> Iterator[pa.RecordBatch]: + """ + Convert one or more newline-delimited JSON STAC files to a generator of Arrow + RecordBatches. + + Each RecordBatch in the returned iterator is guaranteed to have an identical schema, + and can be used to write to one or more Parquet files. + + Args: + path: One or more paths to files with STAC items. + chunk_size: The chunk size. Defaults to 65536. + schema: The schema to represent the input STAC data. Defaults to None, in which + case the schema will first be inferred via a full pass over the input data. + In this case, there will be two full passes over the input data: one to + infer a common schema across all data and another to read the data. + + Other args: + limit: The maximum number of JSON Items to use for schema inference + + Yields: + Arrow RecordBatch with a single chunk of Item data. + """ + # If the schema was not provided, then we need to load all data into memory at once + # to perform schema resolution. + if schema is None: + inferred_schema = InferredSchema() + inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit) + yield from parse_stac_ndjson_to_arrow( + path, chunk_size=chunk_size, schema=inferred_schema + ) + return + + if isinstance(schema, InferredSchema): + schema = schema.inner + + for batch in read_json_chunked(path, chunk_size=chunk_size): + yield stac_items_to_arrow(batch, schema=schema) + + +def stac_table_to_items(table: pa.Table) -> Iterable[dict]: + """Convert a STAC Table to a generator of STAC Item `dict`s""" + for batch in table.to_batches(): + clean_batch = StacArrowBatch(batch) + yield from clean_batch.to_raw_batch().iter_dicts() + + +def stac_table_to_ndjson( + table: pa.Table, dest: Union[str, Path, os.PathLike[bytes]] +) -> None: + """Write a STAC Table to a newline-delimited JSON file.""" + for batch in table.to_batches(): + clean_batch = StacArrowBatch(batch) + clean_batch.to_raw_batch().to_ndjson(dest) + + +def stac_items_to_arrow( + items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None +) -> pa.RecordBatch: + """Convert dicts representing STAC Items to Arrow + + This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple + geometry types. + + All items will be parsed into a single RecordBatch, meaning that each internal array + is fully contiguous in memory for the length of `items`. + + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that this + must represent the geometry column as binary type. 
+ + Returns: + Arrow RecordBatch with items in Arrow + """ + raw_batch = StacJsonBatch.from_dicts(items, schema=schema) + return raw_batch.to_clean_batch().inner diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py new file mode 100644 index 0000000..7130cb2 --- /dev/null +++ b/stac_geoparquet/arrow/_batch.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import os +from copy import deepcopy +from pathlib import Path +from typing import Any, Iterable + +import numpy as np +import orjson +import pyarrow as pa +import pyarrow.compute as pc +import shapely +import shapely.geometry +from numpy.typing import NDArray +from typing_extensions import Self + +from stac_geoparquet.arrow._from_arrow import ( + convert_bbox_to_array, + convert_timestamp_columns_to_string, + lower_properties_from_top_level, +) +from stac_geoparquet.arrow._to_arrow import ( + assign_geoarrow_metadata, + bring_properties_to_top_level, + convert_bbox_to_struct, + convert_timestamp_columns, +) +from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path + + +class StacJsonBatch: + """ + An Arrow RecordBatch of STAC Items that has been **minimally converted** to Arrow. + That is, it aligns as much as possible to the raw STAC JSON representation. + + The **only** transformations that have already been applied here are those that are + necessary to represent the core STAC items in Arrow. + + - `geometry` has been converted to WKB binary + - `properties.proj:geometry`, if it exists, has been converted to WKB binary + ISO encoding + - The `proj:geometry` in any asset properties, if it exists, has been converted to + WKB binary. + + No other transformations have yet been applied. I.e. all properties are still in a + top-level `properties` struct column. + """ + + inner: pa.RecordBatch + """The underlying pyarrow RecordBatch""" + + def __init__(self, batch: pa.RecordBatch) -> None: + self.inner = batch + + @classmethod + def from_dicts( + cls, items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None + ) -> Self: + """Construct a StacJsonBatch from an iterable of dicts representing STAC items. + + All items will be parsed into a single RecordBatch, meaning that each internal + array is fully contiguous in memory for the length of `items`. + + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that + this must represent the geometry column and any `proj:geometry` columns + as binary type. + + Returns: + a new StacJsonBatch of data. 
+ """ + # Preprocess GeoJSON to WKB in each STAC item + # Otherwise, pyarrow will try to parse coordinates into a native geometry type + # and if you have multiple geometry types pyarrow will error with + # `ArrowInvalid: cannot mix list and non-list, non-null values` + wkb_items = [] + for item in items: + wkb_item = deepcopy(item) + wkb_item["geometry"] = shapely.to_wkb( + shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" + ) + + # If a proj:geometry key exists in top-level properties, convert that to WKB + if "proj:geometry" in wkb_item["properties"]: + wkb_item["properties"]["proj:geometry"] = shapely.to_wkb( + shapely.geometry.shape(wkb_item["properties"]["proj:geometry"]), + flavor="iso", + ) + + # If a proj:geometry key exists in any asset properties, convert that to WKB + for asset_value in wkb_item["assets"].values(): + if "proj:geometry" in asset_value: + asset_value["proj:geometry"] = shapely.to_wkb( + shapely.geometry.shape(asset_value["proj:geometry"]), + flavor="iso", + ) + + wkb_items.append(wkb_item) + + if schema is not None: + array = pa.array(wkb_items, type=pa.struct(schema)) + else: + array = pa.array(wkb_items) + + return cls(pa.RecordBatch.from_struct_array(array)) + + def iter_dicts(self) -> Iterable[dict]: + batch = self.inner + + # Find all paths in the schema that have a WKB geometry + geometry_paths = [["geometry"]] + try: + batch.schema.field("properties").type.field("proj:geometry") + geometry_paths.append(["properties", "proj:geometry"]) + except KeyError: + pass + + assets_struct = batch.schema.field("assets").type + for asset_idx in range(assets_struct.num_fields): + asset_field = assets_struct.field(asset_idx) + if "proj:geometry" in pa.schema(asset_field).names: + geometry_paths.append(["assets", asset_field.name, "proj:geometry"]) + + # Convert each geometry column to a Shapely geometry, and then assign the + # geojson geometry when converting each row to a dictionary. + geometries: list[NDArray[np.object_]] = [] + for geometry_path in geometry_paths: + col = batch + for path_segment in geometry_path: + if isinstance(col, pa.RecordBatch): + col = col[path_segment] + elif pa.types.is_struct(col.type): + col = pc.struct_field(col, path_segment) # type: ignore + else: + raise AssertionError(f"unexpected type {type(col)}") + + geometries.append(shapely.from_wkb(col)) + + struct_batch = batch.to_struct_array() + for row_idx in range(len(struct_batch)): + row_dict = struct_batch[row_idx].as_py() + for geometry_path, geometry_column in zip(geometry_paths, geometries): + geojson_g = geometry_column[row_idx].__geo_interface__ + geojson_g["coordinates"] = convert_tuples_to_lists( + geojson_g["coordinates"] + ) + set_by_path(row_dict, geometry_path, geojson_g) + + yield row_dict + + def to_clean_batch(self) -> StacArrowBatch: + batch = self.inner + + batch = bring_properties_to_top_level(batch) + batch = convert_timestamp_columns(batch) + batch = convert_bbox_to_struct(batch) + batch = assign_geoarrow_metadata(batch) + + return StacArrowBatch(batch) + + def to_ndjson(self, dest: str | Path | os.PathLike[bytes]) -> None: + with open(dest, "ab") as f: + for item_dict in self.iter_dicts(): + f.write(orjson.dumps(item_dict)) + f.write(b"\n") + + +class StacArrowBatch: + """ + An Arrow RecordBatch of STAC Items that has been processed to match the + STAC-GeoParquet specification. 
+ """ + + inner: pa.RecordBatch + """The underlying pyarrow RecordBatch""" + + def __init__(self, batch: pa.RecordBatch) -> None: + self.inner = batch + + def to_raw_batch(self) -> StacJsonBatch: + batch = self.inner + + batch = convert_timestamp_columns_to_string(batch) + batch = lower_properties_from_top_level(batch) + batch = convert_bbox_to_array(batch) + + return StacJsonBatch(batch) diff --git a/stac_geoparquet/arrow/_from_arrow.py b/stac_geoparquet/arrow/_from_arrow.py index c06b0ad..a0dbfbe 100644 --- a/stac_geoparquet/arrow/_from_arrow.py +++ b/stac_geoparquet/arrow/_from_arrow.py @@ -1,89 +1,13 @@ """Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" -import orjson -import operator -import os -from functools import reduce -from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union +from typing import List import numpy as np import pyarrow as pa import pyarrow.compute as pc -import shapely -from numpy.typing import NDArray -import shapely.geometry - - -def stac_batch_to_items(batch: pa.RecordBatch) -> Iterable[dict]: - """Convert a stac arrow recordbatch to item dicts.""" - batch = _undo_stac_transformations(batch) - # Find all paths in the schema that have a WKB geometry - geometry_paths = [["geometry"]] - try: - batch.schema.field("properties").type.field("proj:geometry") - geometry_paths.append(["properties", "proj:geometry"]) - except KeyError: - pass - - assets_struct = batch.schema.field("assets").type - for asset_idx in range(assets_struct.num_fields): - asset_field = assets_struct.field(asset_idx) - if "proj:geometry" in pa.schema(asset_field).names: - geometry_paths.append(["assets", asset_field.name, "proj:geometry"]) - - # Convert each geometry column to a Shapely geometry, and then assign the - # geojson geometry when converting each row to a dictionary. - geometries: List[NDArray[np.object_]] = [] - for geometry_path in geometry_paths: - col = batch - for path_segment in geometry_path: - if isinstance(col, pa.RecordBatch): - col = col[path_segment] - elif pa.types.is_struct(col.type): - col = pc.struct_field(col, path_segment) - else: - raise AssertionError(f"unexpected type {type(col)}") - - geometries.append(shapely.from_wkb(col)) - - struct_batch = batch.to_struct_array() - for row_idx in range(len(struct_batch)): - row_dict = struct_batch[row_idx].as_py() - for geometry_path, geometry_column in zip(geometry_paths, geometries): - geojson_g = geometry_column[row_idx].__geo_interface__ - geojson_g["coordinates"] = convert_tuples_to_lists(geojson_g["coordinates"]) - set_by_path(row_dict, geometry_path, geojson_g) - - yield row_dict - - -def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: - """Write a STAC Table to a newline-delimited JSON file.""" - with open(dest, "wb") as f: - for item_dict in stac_table_to_items(table): - f.write(orjson.dumps(item_dict)) - f.write(b"\n") - - -def stac_table_to_items(table: pa.Table) -> Iterable[dict]: - """Convert a STAC Table to a generator of STAC Item `dict`s""" - for batch in table.to_batches(): - yield from stac_batch_to_items(batch) - - -def _undo_stac_transformations(batch: pa.RecordBatch) -> pa.RecordBatch: - """Undo the transformations done to convert STAC Json into an Arrow Table - - Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, - as that is easier to do when converting each item in the table to a dict. 
- """ - batch = _convert_timestamp_columns_to_string(batch) - batch = _lower_properties_from_top_level(batch) - batch = _convert_bbox_to_array(batch) - return batch -def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: +def convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert any datetime columns in the table to a string representation""" allowed_column_names = { "datetime", # common metadata @@ -102,13 +26,14 @@ def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatc continue batch = batch.drop_columns((column_name,)).append_column( - column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ") + column_name, + pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ"), # type: ignore ) return batch -def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: +def lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: """Take properties columns from the top level and wrap them in a struct column""" stac_top_level_keys = { "stac_version", @@ -141,7 +66,7 @@ def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: ) -def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: +def convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert the struct bbox column back to a list column for writing to JSON""" bbox_col_idx = batch.schema.get_field_index("bbox") @@ -191,56 +116,3 @@ def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: raise ValueError("Expected 4 or 6 fields in bbox struct.") return batch.set_column(bbox_col_idx, "bbox", list_arr) - - -def convert_tuples_to_lists(t: List | Tuple) -> List[Any]: - """Convert tuples to lists, recursively - - For example, converts: - ``` - ( - ( - (-112.4820566, 38.1261015), - (-112.4816283, 38.1331311), - (-112.4833551, 38.1338897), - (-112.4832919, 38.1307687), - (-112.4855415, 38.1291793), - (-112.4820566, 38.1261015), - ), - ) - ``` - - to - - ```py - [ - [ - [-112.4820566, 38.1261015], - [-112.4816283, 38.1331311], - [-112.4833551, 38.1338897], - [-112.4832919, 38.1307687], - [-112.4855415, 38.1291793], - [-112.4820566, 38.1261015], - ] - ] - ``` - - From https://stackoverflow.com/a/1014669. - """ - return list(map(convert_tuples_to_lists, t)) if isinstance(t, (list, tuple)) else t - - -def get_by_path(root: Dict[str, Any], keys: Sequence[str]) -> Any: - """Access a nested object in root by item sequence. - - From https://stackoverflow.com/a/14692747 - """ - return reduce(operator.getitem, keys, root) - - -def set_by_path(root: Dict[str, Any], keys: Sequence[str], value: Any) -> None: - """Set a value in a nested object in root by item sequence. 
- - From https://stackoverflow.com/a/14692747 - """ - get_by_path(root, keys[:-1])[keys[-1]] = value # type: ignore diff --git a/stac_geoparquet/arrow/_schema/models.py b/stac_geoparquet/arrow/_schema/models.py index 06fcbd2..17fd169 100644 --- a/stac_geoparquet/arrow/_schema/models.py +++ b/stac_geoparquet/arrow/_schema/models.py @@ -3,7 +3,7 @@ import pyarrow as pa -from stac_geoparquet.arrow._util import stac_items_to_arrow +from stac_geoparquet.arrow._batch import StacJsonBatch from stac_geoparquet.json_reader import read_json_chunked @@ -48,7 +48,7 @@ def update_from_json( def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None: """Update this inferred schema from a sequence of STAC Items.""" self.count += len(items) - current_schema = stac_items_to_arrow(items, schema=None).schema + current_schema = StacJsonBatch.from_dicts(items, schema=None).inner.schema new_schema = pa.unify_schemas( [self.inner, current_schema], promote_options="permissive" ) diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index b99dcec..38e1511 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -1,102 +1,197 @@ """Convert STAC data into Arrow tables""" -from pathlib import Path -from typing import ( - Any, - Dict, - Iterable, - Iterator, - Optional, - Union, -) - +import ciso8601 +import numpy as np +import orjson import pyarrow as pa +import pyarrow.compute as pc -from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.json_reader import read_json_chunked -from stac_geoparquet.arrow._util import stac_items_to_arrow, batched_iter - +from stac_geoparquet.arrow._crs import WGS84_CRS_JSON -def parse_stac_items_to_arrow( - items: Iterable[Dict[str, Any]], - *, - chunk_size: int = 8192, - schema: Optional[Union[pa.Schema, InferredSchema]] = None, -) -> Iterable[pa.RecordBatch]: - """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`. - The objects under `properties` are moved up to the top-level of the - Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. +def bring_properties_to_top_level( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Bring all the fields inside of the nested "properties" struct to the top level""" + properties_field = batch.schema.field("properties") + properties_column = batch["properties"] - Args: - items: the STAC Items to convert - chunk_size: The chunk size to use for Arrow record batches. This only takes - effect if `schema` is not None. When `schema` is None, the input will be - parsed into a single contiguous record batch. Defaults to 8192. - schema: The schema of the input data. If provided, can improve memory use; - otherwise all items need to be parsed into a single array for schema - inference. Defaults to None. + for field_idx in range(properties_field.type.num_fields): + inner_prop_field = properties_field.type.field(field_idx) + batch = batch.append_column( + inner_prop_field, + pc.struct_field(properties_column, field_idx), # type: ignore + ) - Returns: - an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items. 
- """ - if schema is not None: - if isinstance(schema, InferredSchema): - schema = schema.inner + batch = batch.drop_columns( + [ + "properties", + ] + ) + return batch + + +def convert_timestamp_columns( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Convert all timestamp columns from a string to an Arrow Timestamp data type""" + allowed_column_names = { + "datetime", # common metadata + "start_datetime", + "end_datetime", + "created", + "updated", + "expires", # timestamps extension + "published", + "unpublished", + } + for column_name in allowed_column_names: + try: + column = batch[column_name] + except KeyError: + continue + + field_index = batch.schema.get_field_index(column_name) + + if pa.types.is_timestamp(column.type): + continue + + # STAC allows datetimes to be null. If all rows are null, the column type may be + # inferred as null. We cast this to a timestamp column. + elif pa.types.is_null(column.type): + batch = batch.set_column( + field_index, column_name, column.cast(pa.timestamp("us")) + ) + + elif pa.types.is_string(column.type): + batch = batch.set_column( + field_index, column_name, _convert_single_timestamp_column(column) + ) + else: + raise ValueError( + f"Inferred time column '{column_name}' was expected to be a string or" + f" timestamp data type but got {column.type}" + ) + + return batch + + +def _convert_single_timestamp_column(column: pa.Array) -> pa.TimestampArray: + """Convert an individual timestamp column from string to a Timestamp type""" + return pa.array( + [ciso8601.parse_rfc3339(str(t)) for t in column], pa.timestamp("us", tz="UTC") + ) + + +def _is_bbox_3d(bbox_col: pa.Array) -> bool: + """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" + offsets_set = set() + offsets = bbox_col.offsets.to_numpy() + offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) + + if len(offsets_set) > 1: + raise ValueError("Mixed 2d-3d bounding boxes not yet supported") + + offset = list(offsets_set)[0] + if offset == 6: + return True + elif offset == 4: + return False + else: + raise ValueError(f"Unexpected bbox offset: {offset=}") - # If schema is provided, then for better memory usage we parse input STAC items - # to Arrow batches in chunks. - for chunk in batched_iter(items, chunk_size): - yield stac_items_to_arrow(chunk, schema=schema) - else: - # If schema is _not_ provided, then we must convert to Arrow all at once, or - # else it would be possible for a STAC item late in the collection (after the - # first chunk) to have a different schema and not match the schema inferred for - # the first chunk. - yield stac_items_to_arrow(items) - - -def parse_stac_ndjson_to_arrow( - path: Union[str, Path, Iterable[Union[str, Path]]], - *, - chunk_size: int = 65536, - schema: Optional[pa.Schema] = None, - limit: Optional[int] = None, -) -> Iterator[pa.RecordBatch]: - """ - Convert one or more newline-delimited JSON STAC files to a generator of Arrow - RecordBatches. +def convert_bbox_to_struct(batch: pa.RecordBatch) -> pa.RecordBatch: + """Convert bbox column to a struct representation - Each RecordBatch in the returned iterator is guaranteed to have an identical schema, - and can be used to write to one or more Parquet files. + Since the bbox in JSON is stored as an array, pyarrow automatically converts the + bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox + column as a StructArray, which allows for Parquet statistics to infer any spatial + partitioning in the dataset. 
Args: - path: One or more paths to files with STAC items. - chunk_size: The chunk size. Defaults to 65536. - schema: The schema to represent the input STAC data. Defaults to None, in which - case the schema will first be inferred via a full pass over the input data. - In this case, there will be two full passes over the input data: one to - infer a common schema across all data and another to read the data. - - Other args: - limit: The maximum number of JSON Items to use for schema inference - - Yields: - Arrow RecordBatch with a single chunk of Item data. + batch: _description_ + + Returns: + New record batch """ - # If the schema was not provided, then we need to load all data into memory at once - # to perform schema resolution. - if schema is None: - inferred_schema = InferredSchema() - inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit) - yield from parse_stac_ndjson_to_arrow( - path, chunk_size=chunk_size, schema=inferred_schema + bbox_col_idx = batch.schema.get_field_index("bbox") + bbox_col = batch.column(bbox_col_idx) + bbox_3d = _is_bbox_3d(bbox_col) + + assert ( + pa.types.is_list(bbox_col.type) + or pa.types.is_large_list(bbox_col.type) + or pa.types.is_fixed_size_list(bbox_col.type) + ) + if bbox_3d: + coords = bbox_col.flatten().to_numpy().reshape(-1, 6) + else: + coords = bbox_col.flatten().to_numpy().reshape(-1, 4) + + if bbox_3d: + xmin = coords[:, 0] + ymin = coords[:, 1] + zmin = coords[:, 2] + xmax = coords[:, 3] + ymax = coords[:, 4] + zmax = coords[:, 5] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ], + names=[ + "xmin", + "ymin", + "zmin", + "xmax", + "ymax", + "zmax", + ], ) - return - if isinstance(schema, InferredSchema): - schema = schema.inner + else: + xmin = coords[:, 0] + ymin = coords[:, 1] + xmax = coords[:, 2] + ymax = coords[:, 3] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + xmax, + ymax, + ], + names=[ + "xmin", + "ymin", + "xmax", + "ymax", + ], + ) - for batch in read_json_chunked(path, chunk_size=chunk_size): - yield stac_items_to_arrow(batch, schema=schema) + return batch.set_column(bbox_col_idx, "bbox", struct_arr) + + +def assign_geoarrow_metadata( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" + existing_field_idx = batch.schema.get_field_index("geometry") + existing_field = batch.schema.field(existing_field_idx) + ext_metadata = {"crs": WGS84_CRS_JSON} + field_metadata = { + b"ARROW:extension:name": b"geoarrow.wkb", + b"ARROW:extension:metadata": orjson.dumps(ext_metadata), + } + new_field = existing_field.with_metadata(field_metadata) + return batch.set_column( + existing_field_idx, new_field, batch.column(existing_field_idx) + ) diff --git a/stac_geoparquet/arrow/_to_parquet.py b/stac_geoparquet/arrow/_to_parquet.py index 294e216..7197d16 100644 --- a/stac_geoparquet/arrow/_to_parquet.py +++ b/stac_geoparquet/arrow/_to_parquet.py @@ -5,9 +5,9 @@ import pyarrow as pa import pyarrow.parquet as pq -from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.arrow._to_arrow import parse_stac_ndjson_to_arrow +from stac_geoparquet.arrow._api import parse_stac_ndjson_to_arrow from stac_geoparquet.arrow._crs import WGS84_CRS_JSON +from stac_geoparquet.arrow._schema.models import InferredSchema def parse_stac_ndjson_to_parquet( diff --git a/stac_geoparquet/arrow/_util.py b/stac_geoparquet/arrow/_util.py index 5390af5..6b73dac 100644 --- 
a/stac_geoparquet/arrow/_util.py +++ b/stac_geoparquet/arrow/_util.py @@ -1,23 +1,18 @@ -from copy import deepcopy +import operator +from functools import reduce from typing import ( Any, Dict, Iterable, + List, Optional, Sequence, + Union, ) -import ciso8601 -import numpy as np import pyarrow as pa -import pyarrow.compute as pc -import shapely -import shapely.geometry -import orjson from itertools import islice -from stac_geoparquet.arrow._crs import WGS84_CRS_JSON - def update_batch_schema( batch: pa.RecordBatch, @@ -42,268 +37,54 @@ def batched_iter( return -def stac_items_to_arrow( - items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None -) -> pa.RecordBatch: - """Convert dicts representing STAC Items to Arrow - - This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple - geometry types. - - All items will be parsed into a single RecordBatch, meaning that each internal array - is fully contiguous in memory for the length of `items`. - - Args: - items: STAC Items to convert to Arrow - - Kwargs: - schema: An optional schema that describes the format of the data. Note that this - must represent the geometry column as binary type. - - Returns: - Arrow RecordBatch with items in Arrow - """ - # Preprocess GeoJSON to WKB in each STAC item - # Otherwise, pyarrow will try to parse coordinates into a native geometry type and - # if you have multiple geometry types pyarrow will error with - # `ArrowInvalid: cannot mix list and non-list, non-null values` - wkb_items = [] - for item in items: - wkb_item = deepcopy(item) - wkb_item["geometry"] = shapely.to_wkb( - shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" - ) - - # If a proj:geometry key exists in top-level properties, convert that to WKB - if "proj:geometry" in wkb_item["properties"]: - wkb_item["properties"]["proj:geometry"] = shapely.to_wkb( - shapely.geometry.shape(wkb_item["properties"]["proj:geometry"]), - flavor="iso", - ) - - # If a proj:geometry key exists in any asset properties, convert that to WKB - for asset_value in wkb_item["assets"].values(): - if "proj:geometry" in asset_value: - asset_value["proj:geometry"] = shapely.to_wkb( - shapely.geometry.shape(asset_value["proj:geometry"]), - flavor="iso", - ) - - wkb_items.append(wkb_item) - - if schema is not None: - array = pa.array(wkb_items, type=pa.struct(schema)) - else: - array = pa.array(wkb_items) - - return _process_arrow_batch(pa.RecordBatch.from_struct_array(array)) - - -def _bring_properties_to_top_level( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Bring all the fields inside of the nested "properties" struct to the top level""" - properties_field = batch.schema.field("properties") - properties_column = batch["properties"] - - for field_idx in range(properties_field.type.num_fields): - inner_prop_field = properties_field.type.field(field_idx) - batch = batch.append_column( - inner_prop_field, pc.struct_field(properties_column, field_idx) - ) - - batch = batch.drop_columns( - [ - "properties", - ] +def convert_tuples_to_lists(t: Union[list, tuple]) -> List[Any]: + """Convert tuples to lists, recursively + + For example, converts: + ``` + ( + ( + (-112.4820566, 38.1261015), + (-112.4816283, 38.1331311), + (-112.4833551, 38.1338897), + (-112.4832919, 38.1307687), + (-112.4855415, 38.1291793), + (-112.4820566, 38.1261015), + ), ) - return batch + ``` + to -def _convert_geometry_to_wkb( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Convert the geometry column in the table to WKB""" - geoms = shapely.from_geojson( 
- [orjson.dumps(item) for item in batch["geometry"].to_pylist()] - ) - wkb_geoms = shapely.to_wkb(geoms) - return batch.drop_columns( + ```py + [ [ - "geometry", + [-112.4820566, 38.1261015], + [-112.4816283, 38.1331311], + [-112.4833551, 38.1338897], + [-112.4832919, 38.1307687], + [-112.4855415, 38.1291793], + [-112.4820566, 38.1261015], ] - ).append_column("geometry", pa.array(wkb_geoms)) - - -def _convert_timestamp_columns( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Convert all timestamp columns from a string to an Arrow Timestamp data type""" - allowed_column_names = { - "datetime", # common metadata - "start_datetime", - "end_datetime", - "created", - "updated", - "expires", # timestamps extension - "published", - "unpublished", - } - for column_name in allowed_column_names: - try: - column = batch[column_name] - except KeyError: - continue - - field_index = batch.schema.get_field_index(column_name) - - if pa.types.is_timestamp(column.type): - continue - - # STAC allows datetimes to be null. If all rows are null, the column type may be - # inferred as null. We cast this to a timestamp column. - elif pa.types.is_null(column.type): - batch = batch.set_column( - field_index, column_name, column.cast(pa.timestamp("us")) - ) - - elif pa.types.is_string(column.type): - batch = batch.set_column( - field_index, column_name, _convert_timestamp_column(column) - ) - else: - raise ValueError( - f"Inferred time column '{column_name}' was expected to be a string or" - f" timestamp data type but got {column.type}" - ) - - return batch - - -def _convert_timestamp_column(column: pa.Array) -> pa.TimestampArray: - """Convert an individual timestamp column from string to a Timestamp type""" - return pa.array( - [ciso8601.parse_rfc3339(str(t)) for t in column], pa.timestamp("us", tz="UTC") - ) - - -def _is_bbox_3d(bbox_col: pa.Array) -> bool: - """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" - offsets_set = set() - offsets = bbox_col.offsets.to_numpy() - offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) - - if len(offsets_set) > 1: - raise ValueError("Mixed 2d-3d bounding boxes not yet supported") - - offset = list(offsets_set)[0] - if offset == 6: - return True - elif offset == 4: - return False - else: - raise ValueError(f"Unexpected bbox offset: {offset=}") - + ] + ``` -def _convert_bbox_to_struct(batch: pa.RecordBatch) -> pa.RecordBatch: - """Convert bbox column to a struct representation - - Since the bbox in JSON is stored as an array, pyarrow automatically converts the - bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox - column as a StructArray, which allows for Parquet statistics to infer any spatial - partitioning in the dataset. - - Args: - batch: _description_ - - Returns: - New record batch + From https://stackoverflow.com/a/1014669. 
""" - bbox_col_idx = batch.schema.get_field_index("bbox") - bbox_col = batch.column(bbox_col_idx) - bbox_3d = _is_bbox_3d(bbox_col) - - assert ( - pa.types.is_list(bbox_col.type) - or pa.types.is_large_list(bbox_col.type) - or pa.types.is_fixed_size_list(bbox_col.type) - ) - if bbox_3d: - coords = bbox_col.flatten().to_numpy().reshape(-1, 6) - else: - coords = bbox_col.flatten().to_numpy().reshape(-1, 4) - - if bbox_3d: - xmin = coords[:, 0] - ymin = coords[:, 1] - zmin = coords[:, 2] - xmax = coords[:, 3] - ymax = coords[:, 4] - zmax = coords[:, 5] - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ], - names=[ - "xmin", - "ymin", - "zmin", - "xmax", - "ymax", - "zmax", - ], - ) + return list(map(convert_tuples_to_lists, t)) if isinstance(t, (list, tuple)) else t - else: - xmin = coords[:, 0] - ymin = coords[:, 1] - xmax = coords[:, 2] - ymax = coords[:, 3] - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - xmax, - ymax, - ], - names=[ - "xmin", - "ymin", - "xmax", - "ymax", - ], - ) - - return batch.set_column(bbox_col_idx, "bbox", struct_arr) +def get_by_path(root: Dict[str, Any], keys: Sequence[str]) -> Any: + """Access a nested object in root by item sequence. + From https://stackoverflow.com/a/14692747 + """ + return reduce(operator.getitem, keys, root) -def _assign_geoarrow_metadata( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" - existing_field_idx = batch.schema.get_field_index("geometry") - existing_field = batch.schema.field(existing_field_idx) - ext_metadata = {"crs": WGS84_CRS_JSON} - field_metadata = { - b"ARROW:extension:name": b"geoarrow.wkb", - b"ARROW:extension:metadata": orjson.dumps(ext_metadata), - } - new_field = existing_field.with_metadata(field_metadata) - return batch.set_column( - existing_field_idx, new_field, batch.column(existing_field_idx) - ) +def set_by_path(root: Dict[str, Any], keys: Sequence[str], value: Any) -> None: + """Set a value in a nested object in root by item sequence. -def _process_arrow_batch(batch: pa.RecordBatch) -> pa.RecordBatch: - batch = _bring_properties_to_top_level(batch) - batch = _convert_timestamp_columns(batch) - batch = _convert_bbox_to_struct(batch) - batch = _assign_geoarrow_metadata(batch) - return batch + From https://stackoverflow.com/a/14692747 + """ + get_by_path(root, keys[:-1])[keys[-1]] = value # type: ignore diff --git a/stac_geoparquet/from_arrow.py b/stac_geoparquet/from_arrow.py index dc19fca..2af5920 100644 --- a/stac_geoparquet/from_arrow.py +++ b/stac_geoparquet/from_arrow.py @@ -5,4 +5,7 @@ FutureWarning, ) -from stac_geoparquet.arrow._from_arrow import * # noqa + +from stac_geoparquet.arrow._api import stac_items_to_arrow # noqa +from stac_geoparquet.arrow._api import stac_table_to_items # noqa +from stac_geoparquet.arrow._api import stac_table_to_ndjson # noqa diff --git a/stac_geoparquet/to_arrow.py b/stac_geoparquet/to_arrow.py index 9b3f81d..2802d6e 100644 --- a/stac_geoparquet/to_arrow.py +++ b/stac_geoparquet/to_arrow.py @@ -1,8 +1,14 @@ +# This doesn't work inline on these imports for some reason +# flake8: noqa: F401 + import warnings +from stac_geoparquet.arrow._api import ( + parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, +) + warnings.warn( "stac_geoparquet.to_arrow is deprecated. 
Please use stac_geoparquet.arrow instead.", FutureWarning, ) - -from stac_geoparquet.arrow._to_arrow import * # noqa diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/json_equals.py b/tests/json_equals.py new file mode 100644 index 0000000..79a3ad5 --- /dev/null +++ b/tests/json_equals.py @@ -0,0 +1,167 @@ +import math +from typing import Any, Dict, Sequence, Union + +from ciso8601 import parse_rfc3339 + + +JsonValue = Union[list, tuple, int, float, dict, str, bool, None] + + +def assert_json_value_equal( + result: JsonValue, + expected: JsonValue, + *, + key_name: str = "root", + precision: float = 0.0001, +) -> None: + """Assert that the JSON value in `result` and `expected` are equal for our purposes. + + We allow these variations between result and expected: + + - We allow numbers to vary up to `precision`. + - We consider `key: None` and a missing key to be equivalent. + - We allow RFC3339 date strings with varying precision levels, as long as they + represent the same parsed datetime. + + Args: + result: The result to assert against. + expected: The expected item to compare against. + key_name: The key name of the current path in the JSON. Used for error messages. + precision: The precision to use for comparing integers and floats. + + Raises: + AssertionError: If the two values are not equal + """ + if isinstance(result, list) and isinstance(expected, list): + assert_sequence_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, tuple) and isinstance(expected, tuple): + assert_sequence_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, (int, float)) and isinstance(expected, (int, float)): + assert_number_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, dict) and isinstance(expected, dict): + assert_dict_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, str) and isinstance(expected, str): + assert_string_equal(result, expected, key_name=key_name) + + elif isinstance(result, bool) and isinstance(expected, bool): + assert_bool_equal(result, expected, key_name=key_name) + + elif result is None and expected is None: + pass + + else: + raise AssertionError( + f"Mismatched types at {key_name}. {type(result)=}, {type(expected)=}" + ) + + +def assert_sequence_equal( + result: Sequence, expected: Sequence, *, key_name: str, precision: float +) -> None: + """Compare two JSON arrays, recursively""" + assert len(result) == len(expected), ( + f"List at {key_name} has different lengths." f"{len(result)=}, {len(expected)=}" + ) + + for i in range(len(result)): + assert_json_value_equal( + result[i], expected[i], key_name=f"{key_name}.[{i}]", precision=precision + ) + + +def assert_number_equal( + result: Union[int, float], + expected: Union[int, float], + *, + precision: float, + key_name: str, +) -> None: + """Compare two JSON numbers""" + # Allow NaN equality + if math.isnan(result) and math.isnan(expected): + return + + assert abs(result - expected) <= precision, ( + f"Number at {key_name} not within precision. " + f"{result=}, {expected=}, {precision=}." + ) + + +def assert_string_equal( + result: str, + expected: str, + *, + key_name: str, +) -> None: + """Compare two JSON strings. + + We attempt to parse each string to a datetime. If this succeeds, then we compare the + datetime.datetime representations instead of the bare strings. 
+ """ + + # Check if both strings are dates, then assert the parsed datetimes are equal + try: + result_datetime = parse_rfc3339(result) + expected_datetime = parse_rfc3339(expected) + + assert result_datetime == expected_datetime, ( + f"Date string at {key_name} not equal. " + f"{result=}, {expected=}." + f"{result_datetime=}, {expected_datetime=}." + ) + + except ValueError: + assert ( + result == expected + ), f"String at {key_name} not equal. {result=}, {expected=}." + + +def assert_bool_equal( + result: bool, + expected: bool, + *, + key_name: str, +) -> None: + """Compare two JSON booleans.""" + assert result == expected, f"Bool at {key_name} not equal. {result=}, {expected=}." + + +def assert_dict_equal( + result: Dict[str, Any], + expected: Dict[str, Any], + *, + key_name: str, + precision: float, +) -> None: + """ + Assert that two JSON dicts are equal, recursively, allowing missing keys to equal + None. + """ + result_keys = set(result.keys()) + expected_keys = set(expected.keys()) + + # For any keys that exist in result but not expected, assert that the result value + # is None + for key in result_keys - expected_keys: + assert ( + result[key] is None + ), f"Expected key at {key_name} to be None in result. Got {result['key']}" + + # And vice versa + for key in expected_keys - result_keys: + assert ( + expected[key] is None + ), f"Expected key at {key_name} to be None in expected. Got {expected['key']}" + + # For any overlapping keys, assert that their values are equal + for key in result_keys & expected_keys: + assert_json_value_equal( + result[key], + expected[key], + key_name=f"{key_name}.{key}", + precision=precision, + ) diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 2b9bca4..f51787b 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -1,178 +1,19 @@ import json -import math from pathlib import Path -from typing import Any, Dict, Sequence, Union import pyarrow as pa import pytest -from ciso8601 import parse_rfc3339 -from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items - -HERE = Path(__file__).parent +from stac_geoparquet.arrow import ( + parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, + stac_table_to_items, + stac_table_to_ndjson, +) -JsonValue = Union[list, tuple, int, float, dict, str, bool, None] - - -def assert_json_value_equal( - result: JsonValue, - expected: JsonValue, - *, - key_name: str = "root", - precision: float = 0.0001, -) -> None: - """Assert that the JSON value in `result` and `expected` are equal for our purposes. - - We allow these variations between result and expected: - - - We allow numbers to vary up to `precision`. - - We consider `key: None` and a missing key to be equivalent. - - We allow RFC3339 date strings with varying precision levels, as long as they - represent the same parsed datetime. - - Args: - result: The result to assert against. - expected: The expected item to compare against. - key_name: The key name of the current path in the JSON. Used for error messages. - precision: The precision to use for comparing integers and floats. 
- - Raises: - AssertionError: If the two values are not equal - """ - if isinstance(result, list) and isinstance(expected, list): - assert_sequence_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, tuple) and isinstance(expected, tuple): - assert_sequence_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, (int, float)) and isinstance(expected, (int, float)): - assert_number_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, dict) and isinstance(expected, dict): - assert_dict_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, str) and isinstance(expected, str): - assert_string_equal(result, expected, key_name=key_name) - - elif isinstance(result, bool) and isinstance(expected, bool): - assert_bool_equal(result, expected, key_name=key_name) - - elif result is None and expected is None: - pass - - else: - raise AssertionError( - f"Mismatched types at {key_name}. {type(result)=}, {type(expected)=}" - ) - - -def assert_sequence_equal( - result: Sequence, expected: Sequence, *, key_name: str, precision: float -) -> None: - """Compare two JSON arrays, recursively""" - assert len(result) == len(expected), ( - f"List at {key_name} has different lengths." f"{len(result)=}, {len(expected)=}" - ) - - for i in range(len(result)): - assert_json_value_equal( - result[i], expected[i], key_name=f"{key_name}.[{i}]", precision=precision - ) - - -def assert_number_equal( - result: Union[int, float], - expected: Union[int, float], - *, - precision: float, - key_name: str, -) -> None: - """Compare two JSON numbers""" - # Allow NaN equality - if math.isnan(result) and math.isnan(expected): - return - - assert abs(result - expected) <= precision, ( - f"Number at {key_name} not within precision. " - f"{result=}, {expected=}, {precision=}." - ) - - -def assert_string_equal( - result: str, - expected: str, - *, - key_name: str, -) -> None: - """Compare two JSON strings. - - We attempt to parse each string to a datetime. If this succeeds, then we compare the - datetime.datetime representations instead of the bare strings. - """ - - # Check if both strings are dates, then assert the parsed datetimes are equal - try: - result_datetime = parse_rfc3339(result) - expected_datetime = parse_rfc3339(expected) - - assert result_datetime == expected_datetime, ( - f"Date string at {key_name} not equal. " - f"{result=}, {expected=}." - f"{result_datetime=}, {expected_datetime=}." - ) - - except ValueError: - assert ( - result == expected - ), f"String at {key_name} not equal. {result=}, {expected=}." - - -def assert_bool_equal( - result: bool, - expected: bool, - *, - key_name: str, -) -> None: - """Compare two JSON booleans.""" - assert result == expected, f"Bool at {key_name} not equal. {result=}, {expected=}." - - -def assert_dict_equal( - result: Dict[str, Any], - expected: Dict[str, Any], - *, - key_name: str, - precision: float, -) -> None: - """ - Assert that two JSON dicts are equal, recursively, allowing missing keys to equal - None. - """ - result_keys = set(result.keys()) - expected_keys = set(expected.keys()) - - # For any keys that exist in result but not expected, assert that the result value - # is None - for key in result_keys - expected_keys: - assert ( - result[key] is None - ), f"Expected key at {key_name} to be None in result. 
Got {result['key']}" - - # And vice versa - for key in expected_keys - result_keys: - assert ( - expected[key] is None - ), f"Expected key at {key_name} to be None in expected. Got {expected['key']}" - - # For any overlapping keys, assert that their values are equal - for key in result_keys & expected_keys: - assert_json_value_equal( - result[key], - expected[key], - key_name=f"{key_name}.{key}", - precision=precision, - ) +from .json_equals import assert_json_value_equal +HERE = Path(__file__).parent TEST_COLLECTIONS = [ "3dep-lidar-copc", @@ -190,11 +31,8 @@ def assert_dict_equal( ] -@pytest.mark.parametrize( - "collection_id", - TEST_COLLECTIONS, -) -def test_round_trip(collection_id: str): +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_round_trip_read_write(collection_id: str): with open(HERE / "data" / f"{collection_id}-pc.json") as f: items = json.load(f) @@ -205,6 +43,19 @@ def test_round_trip(collection_id: str): assert_json_value_equal(result, expected, precision=0) +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_round_trip_write_read_ndjson(collection_id: str, tmp_path: Path): + # First load into a STAC-GeoParquet table + path = HERE / "data" / f"{collection_id}-pc.json" + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path)) + + # Then write to disk + stac_table_to_ndjson(table, tmp_path / "tmp.ndjson") + + # Then read back and assert tables match + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(tmp_path / "tmp.ndjson")) + + def test_table_contains_geoarrow_metadata(): collection_id = "naip" with open(HERE / "data" / f"{collection_id}-pc.json") as f: @@ -219,19 +70,25 @@ def test_table_contains_geoarrow_metadata(): } +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_parse_json_to_arrow(collection_id: str): + path = HERE / "data" / f"{collection_id}-pc.json" + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path)) + items_result = list(stac_table_to_items(table)) + + with open(HERE / "data" / f"{collection_id}-pc.json") as f: + items = json.load(f) + + for result, expected in zip(items_result, items): + assert_json_value_equal(result, expected, precision=0) + + def test_to_arrow_deprecated(): with pytest.warns(FutureWarning): import stac_geoparquet.to_arrow stac_geoparquet.to_arrow.parse_stac_items_to_arrow -def test_to_parquet_deprecated(): - with pytest.warns(FutureWarning): - import stac_geoparquet.to_parquet - - stac_geoparquet.to_parquet.to_parquet - - def test_from_arrow_deprecated(): with pytest.warns(FutureWarning): import stac_geoparquet.from_arrow diff --git a/tests/test_parquet.py b/tests/test_parquet.py new file mode 100644 index 0000000..10dd938 --- /dev/null +++ b/tests/test_parquet.py @@ -0,0 +1,53 @@ +import json +from pathlib import Path + +import pyarrow.parquet as pq +import pytest + +from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet, stac_table_to_items + +from .json_equals import assert_json_value_equal + +HERE = Path(__file__).parent + + +def test_to_parquet_deprecated(): + with pytest.warns(FutureWarning): + import stac_geoparquet.to_parquet + + stac_geoparquet.to_parquet.to_parquet + + +TEST_COLLECTIONS = [ + "3dep-lidar-copc", + "3dep-lidar-dsm", + "cop-dem-glo-30", + "io-lulc-annual-v02", + "io-lulc", + "landsat-c2-l1", + "landsat-c2-l2", + "naip", + "planet-nicfi-analytic", + "sentinel-1-rtc", + "sentinel-2-l2a", + "us-census", +] + + +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_round_trip_via_parquet(collection_id: 
str, tmp_path: Path): + path = HERE / "data" / f"{collection_id}-pc.json" + out_path = tmp_path / "file.parquet" + # Convert to Parquet + parse_stac_ndjson_to_parquet(path, out_path) + + # Read back into table and convert to json + table = pq.read_table(out_path) + items_result = list(stac_table_to_items(table)) + + # Compare with original json + with open(HERE / "data" / f"{collection_id}-pc.json") as f: + items = json.load(f) + + for result, expected in zip(items_result, items): + assert_json_value_equal(result, expected, precision=0)
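
The tests above exercise this round trip against the checked-in collection fixtures. As a rough sketch of how the two wrapper types introduced by this patch relate to the public API, something like the following should work; the inline `item` is a hypothetical minimal STAC Item, and `StacJsonBatch`/`StacArrowBatch` live in the private `stac_geoparquet.arrow._batch` module added here:

```py
import pyarrow as pa

from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items
from stac_geoparquet.arrow._batch import StacJsonBatch  # private module from this patch

# Hypothetical minimal Item, used only for illustration.
item = {
    "type": "Feature",
    "stac_version": "1.0.0",
    "id": "example-item",
    "geometry": {"type": "Point", "coordinates": [-112.48, 38.13]},
    "bbox": [-112.49, 38.12, -112.47, 38.14],
    "properties": {"datetime": "2024-06-04T12:46:45Z"},
    "assets": {"data": {"href": "https://example.com/example-item.tif"}},
}

# Low-level path: StacJsonBatch is the JSON-aligned ("raw") representation,
# StacArrowBatch is the STAC-GeoParquet-aligned ("clean") representation.
raw = StacJsonBatch.from_dicts([item])
clean = raw.to_clean_batch()
round_tripped_dicts = list(clean.to_raw_batch().iter_dicts())

# Public API path, which wraps the same conversions.
table = pa.Table.from_batches(parse_stac_items_to_arrow([item]))
round_tripped = list(stac_table_to_items(table))
```

`StacJsonBatch` stays as close to the raw STAC JSON as possible (only geometries are pre-encoded to WKB), while `StacArrowBatch` carries the fully transformed, spec-aligned columns, so each direction of the conversion has a single owner.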