Skip to content

Commit

Permalink
Use dataclasses instead of NamedTuple for displacement, stitched …
Browse files Browse the repository at this point in the history
…outputs (#449)

* Use `dataclasses` instead of `NamedTuple` for displacement, stitched outputs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
scottstanie and pre-commit-ci[bot] authored Oct 14, 2024
1 parent 3571819 commit 215929b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/dolphin/unwrap/_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_2pi_ambiguities(


def interpolate_masked_gaps(
unw: NDArray[np.float_], ifg: NDArray[np.complex64]
unw: NDArray[np.float64], ifg: NDArray[np.complex64]
) -> None:
"""Perform phase unwrapping using nearest neighbor interpolation of ambiguities.
Expand Down
2 changes: 1 addition & 1 deletion src/dolphin/unwrap/_unwrap_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def filled_masked_unw_regions(

def _reform_wrapped_phase(
unw_filename: PathOrStr, ifg_filenames: Sequence[PathOrStr]
) -> tuple[NDArray[np.float_], NDArray[np.complex64]]:
) -> tuple[NDArray[np.float64], NDArray[np.complex64]]:
"""Load unwrapped phase, and re-calculate the corresponding wrapped phase.
Finds the matching ifg to `unw_filename`, or uses 2 to compute the correct
Expand Down
55 changes: 24 additions & 31 deletions src/dolphin/workflows/displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import multiprocessing as mp
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

from opera_utils import group_by_burst, group_by_date # , get_dates
from tqdm.auto import tqdm
Expand All @@ -25,8 +25,9 @@
logger = logging.getLogger(__name__)


class OutputPaths(NamedTuple):
"""Named tuple of `DisplacementWorkflow` outputs."""
@dataclass
class OutputPaths:
"""Output files of the `DisplacementWorkflow`."""

comp_slc_dict: dict[str, list[Path]]
stitched_ifg_paths: list[Path]
Expand Down Expand Up @@ -188,15 +189,7 @@ def run(
# Is there one best size? dependent on `half_window` or resolution?
# For now, just pick a reasonable size
corr_window_size = (11, 11)
(
stitched_ifg_paths,
stitched_cor_paths,
stitched_temp_coh_file,
stitched_ps_file,
stitched_amp_dispersion_file,
stitched_shp_count_file,
stitched_similarity_file,
) = stitching_bursts.run(
stitched_paths = stitching_bursts.run(
ifg_file_list=ifg_file_list,
temp_coh_file_list=temp_coh_file_list,
ps_file_list=ps_file_list,
Expand All @@ -217,13 +210,13 @@ def run(
_print_summary(cfg)
return OutputPaths(
comp_slc_dict=comp_slc_dict,
stitched_ifg_paths=stitched_ifg_paths,
stitched_cor_paths=stitched_cor_paths,
stitched_temp_coh_file=stitched_temp_coh_file,
stitched_ps_file=stitched_ps_file,
stitched_amp_dispersion_file=stitched_amp_dispersion_file,
stitched_shp_count_file=stitched_shp_count_file,
stitched_similarity_file=stitched_similarity_file,
stitched_ifg_paths=stitched_paths.ifg_paths,
stitched_cor_paths=stitched_paths.interferometric_corr_paths,
stitched_temp_coh_file=stitched_paths.temp_coh_file,
stitched_ps_file=stitched_paths.ps_file,
stitched_amp_dispersion_file=stitched_paths.amp_dispersion_file,
stitched_shp_count_file=stitched_paths.shp_count_file,
stitched_similarity_file=stitched_paths.similarity_file,
unwrapped_paths=None,
conncomp_paths=None,
timeseries_paths=None,
Expand All @@ -235,9 +228,9 @@ def run(
row_looks, col_looks = cfg.phase_linking.half_window.to_looks()
nlooks = row_looks * col_looks
unwrapped_paths, conncomp_paths = unwrapping.run(
ifg_file_list=stitched_ifg_paths,
cor_file_list=stitched_cor_paths,
temporal_coherence_file=stitched_temp_coh_file,
ifg_file_list=stitched_paths.ifg_paths,
cor_file_list=stitched_paths.interferometric_corr_paths,
temporal_coherence_file=stitched_paths.temp_coh_file,
nlooks=nlooks,
unwrap_options=cfg.unwrap_options,
mask_file=cfg.mask_file,
Expand All @@ -258,8 +251,8 @@ def run(
timeseries_paths, reference_point = timeseries.run(
unwrapped_paths=unwrapped_paths,
conncomp_paths=conncomp_paths,
corr_paths=stitched_cor_paths,
condition_file=stitched_temp_coh_file,
corr_paths=stitched_paths.interferometric_corr_paths,
condition_file=stitched_paths.temp_coh_file,
condition=CallFunc.MAX,
output_dir=ts_opts._directory,
method=timeseries.InversionMethod(ts_opts.method),
Expand Down Expand Up @@ -361,13 +354,13 @@ def run(
_print_summary(cfg)
return OutputPaths(
comp_slc_dict=comp_slc_dict,
stitched_ifg_paths=stitched_ifg_paths,
stitched_cor_paths=stitched_cor_paths,
stitched_temp_coh_file=stitched_temp_coh_file,
stitched_ps_file=stitched_ps_file,
stitched_amp_dispersion_file=stitched_amp_dispersion_file,
stitched_shp_count_file=stitched_shp_count_file,
stitched_similarity_file=stitched_similarity_file,
stitched_ifg_paths=stitched_paths.ifg_paths,
stitched_cor_paths=stitched_paths.interferometric_corr_paths,
stitched_temp_coh_file=stitched_paths.temp_coh_file,
stitched_ps_file=stitched_paths.ps_file,
stitched_amp_dispersion_file=stitched_paths.amp_dispersion_file,
stitched_shp_count_file=stitched_paths.shp_count_file,
stitched_similarity_file=stitched_paths.similarity_file,
unwrapped_paths=unwrapped_paths,
# TODO: Let's keep the unwrapped_paths since all the outputs are
# corresponding to those and if we have a network unwrapping, the
Expand Down
18 changes: 10 additions & 8 deletions src/dolphin/workflows/stitching_bursts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple, Sequence
from typing import Sequence

from dolphin import stitching
from dolphin._log import log_runtime
Expand All @@ -18,22 +19,23 @@
logger = logging.getLogger(__name__)


class StitchedOutputs(NamedTuple):
@dataclass
class StitchedOutputs:
"""Output rasters from stitching step."""

stitched_ifg_paths: list[Path]
ifg_paths: list[Path]
"""List of Paths to the stitched interferograms."""
interferometric_corr_paths: list[Path]
"""List of Paths to interferometric correlation files created."""
stitched_temp_coh_file: Path
temp_coh_file: Path
"""Path to temporal correlation file created."""
stitched_ps_file: Path
ps_file: Path
"""Path to ps mask file created."""
stitched_amp_disp_file: Path
amp_dispersion_file: Path
"""Path to amplitude dispersion file created."""
stitched_shp_count_file: Path
shp_count_file: Path
"""Path to SHP count file created."""
stitched_similarity_file: Path
similarity_file: Path
"""Path to cosine similarity file created."""


Expand Down

0 comments on commit 215929b

Please sign in to comment.