diff --git a/src/sweets/_missing_data.py b/src/sweets/_missing_data.py index 7083f12..11e667b 100644 --- a/src/sweets/_missing_data.py +++ b/src/sweets/_missing_data.py @@ -25,7 +25,10 @@ def get_geodataframe( - gslc_files: Iterable[Filename], max_workers: int = 5, one_per_burst: bool = True + gslc_files: Iterable[Filename], + max_workers: int = 5, + one_per_burst: bool = True, + polygons: Optional[Sequence[geometry.Polygon]] = None, ) -> gpd.GeoDataFrame: """Get a GeoDataFrame of the CSLC footprints. @@ -37,24 +40,29 @@ def get_geodataframe( Number of threads to use. one_per_burst : bool, default=True If True, only keep one footprint per burst ID. + polygons : Sequence[shapely.geometry.Polygon], optional + If provided, skips computing them from the CSLCs. + Otherwise, will read them in with `get_cslc_polygon`. """ gslc_files = list(gslc_files) # make sure generator doesn't deplete after first run - if one_per_burst: - from opera_utils import group_by_burst - - burst_to_file_list = group_by_burst(gslc_files) - slc_files = [file_list[0] for file_list in burst_to_file_list.values()] - unique_polygons = thread_map( - get_cslc_polygon, slc_files, max_workers=max_workers - ) - assert len(unique_polygons) == len(burst_to_file_list) - # Repeat the polygons for each burst - polygons: list[geometry.Polygon] = [] - for burst_id, p in zip(burst_to_file_list, unique_polygons): - for _ in range(len(burst_to_file_list[burst_id])): - polygons.append(p) - else: - polygons = thread_map(get_cslc_polygon, gslc_files, max_workers=max_workers) + if not polygons: + polys: list[geometry.Polygon] = [] + if one_per_burst: + from opera_utils import group_by_burst + + burst_to_file_list = group_by_burst(gslc_files) + slc_files = [file_list[0] for file_list in burst_to_file_list.values()] + unique_polygons = thread_map( + get_cslc_polygon, slc_files, max_workers=max_workers + ) + assert len(unique_polygons) == len(burst_to_file_list) + # Repeat the polys for each burst + for burst_id, p in zip(burst_to_file_list, unique_polygons): + for _ in range(len(burst_to_file_list[burst_id])): + polys.append(p) + else: + polys = thread_map(get_cslc_polygon, gslc_files, max_workers=max_workers) + polygons = polys gdf = gpd.GeoDataFrame(geometry=polygons, crs="EPSG:4326") gdf["count"] = 1