Skip to content

Commit

Permalink
Add "depth" column to the discarded_edgelist
Browse files Browse the repository at this point in the history
  • Loading branch information
ptajvar committed Oct 22, 2024
1 parent 98da20f commit d080994
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 52 deletions.
50 changes: 14 additions & 36 deletions src/pixelator/graph/community_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,28 +148,23 @@ def connect_components(
logger.debug("Calculating raw edgelist metrics")

if multiplet_recovery:
recovered_node_component_map, component_refinement_history = (
recover_technical_multiplets(
edgelist=edges,
node_component_map=node_component_map.astype(np.int64),
max_refinement_recursion_depth=max_refinement_recursion_depth,
max_edges_to_split=max_edges_to_split,
)
)

# save the recovered components info to a file
component_refinement_history.to_csv(
Path(output) / f"{sample_name}.components_recovered.csv"
recovered_node_component_map, node_depth_map = recover_technical_multiplets(
edgelist=edges,
node_component_map=node_component_map.astype(np.int64),
max_refinement_recursion_depth=max_refinement_recursion_depth,
max_edges_to_split=max_edges_to_split,
)
else:
recovered_node_component_map = node_component_map
node_depth_map = pd.Series(index=node_component_map.index, data=0)

del edges

# assign component column to edge list
edgelist_with_component_info = map_upis_to_components(
edgelist=edgelist,
node_component_map=recovered_node_component_map.astype(np.int64),
node_depth_map=node_depth_map.astype(np.int64),
)
remaining_edgelist, removed_edgelist = split_remaining_and_removed_edgelist(
edgelist_with_component_info
Expand Down Expand Up @@ -249,7 +244,7 @@ def recover_technical_multiplets(
node_component_map: pd.Series,
max_refinement_recursion_depth: int = 5,
max_edges_to_split: int = 5,
) -> Tuple[pd.Series, pd.DataFrame]:
) -> Tuple[pd.Series, pd.Series]:
"""Perform component recovery by deleting spurious edges.
The molecular pixelation assay may under some conditions introduce spurious
Expand Down Expand Up @@ -277,9 +272,9 @@ def recover_technical_multiplets(
into smaller components during the recovery process.
:param max_edges_to_split: The maximum number of edges between the product components
when splitting during multiplet recovery.
:return: A tuple with the updated node component map and the history of component
breakdowns.
:rtype: Tuple[pd.Series, pd.DataFrame]
:return: A tuple with the updated node component map and the iteration depth at which each
node is re-assigned to a component.
:rtype: Tuple[pd.Series, pd.Series]
"""
logger.debug(
"Starting multiplets recovery in edge list with %i rows",
Expand All @@ -296,8 +291,8 @@ def id_generator(start=0):
comp_sizes = node_component_map.groupby(node_component_map).count()

n_edges_to_remove = 0
community_annotation_history = []
to_be_refined_next = comp_sizes[comp_sizes > MIN_PIXELS_TO_REFINE].index
node_depth_map = pd.Series(index=node_component_map.index, data=0)
for depth in range(max_refinement_recursion_depth):
edgelist["component_a"] = node_component_map[edgelist["upia"]].values
edgelist["component_b"] = node_component_map[edgelist["upib"]].values
Expand All @@ -312,10 +307,6 @@ def id_generator(start=0):
component_edgelist["component_b"] == component
].sort_values(["upia", "upib"])

component_nodes = list(
set(component_edgelist["upia"]).union(set(component_edgelist["upib"]))
)

edgelist_tuple = list(
map(tuple, np.array(component_edgelist[["upia", "upib", "len"]]))
)
Expand Down Expand Up @@ -343,6 +334,7 @@ def id_generator(start=0):
!= component_edgelist["upib_community"]
).sum()
community_size_map = community_serie.groupby(community_serie).count()
node_depth_map[community_serie.index] = depth + 1

if (community_size_map > MIN_PIXELS_TO_REFINE).sum() > 1:
further_refinement = True
Expand All @@ -354,32 +346,18 @@ def id_generator(start=0):
node_component_map[
community_serie[community_serie == new_community].index
] = new_id
community_annotation_history.append(
(
component,
new_id,
len(component_nodes),
community_size_map[new_community],
depth,
)
)
if (
further_refinement
and community_size_map[new_community] > MIN_PIXELS_TO_REFINE
):
to_be_refined_next.append(new_id)

component_refinement_history = pd.DataFrame(
community_annotation_history,
columns=["old", "new", "old_size", "new_size", "depth"],
).astype(int)

logger.info(
"Obtained %i components after removing %i edges",
node_component_map.nunique(),
n_edges_to_remove,
)
return node_component_map, component_refinement_history
return node_component_map, node_depth_map


def write_recovered_components(
Expand Down
20 changes: 20 additions & 0 deletions src/pixelator/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def edgelist_metrics(
def map_upis_to_components(
edgelist: pl.LazyFrame,
node_component_map: pd.Series,
node_depth_map: Optional[pd.Series] = None,
) -> pl.LazyFrame:
"""Update the edgelist with component names corresponding to upia/upib.
Expand Down Expand Up @@ -419,6 +420,22 @@ def map_upis_to_components(
default="",
),
)
if node_depth_map is not None:
node_depth_dict = node_depth_map.to_dict()
edgelist_with_component_info = (
edgelist_with_component_info.with_columns(
upia_depth=pl.col("upia")
.cast(pl.String)
.replace_strict(node_depth_dict, default=0),
upib_depth=pl.col("upib")
.cast(pl.String)
.replace_strict(node_depth_dict, default=0),
)
.with_columns(
depth=pl.min_horizontal([pl.col("upia_depth"), pl.col("upib_depth")])
)
.drop(["upia_depth", "upib_depth"])
)

return edgelist_with_component_info

Expand Down Expand Up @@ -507,6 +524,9 @@ def split_remaining_and_removed_edgelist(
.rename({"component_a": "component"})
.drop("component_b")
)
if "depth" in remaining_edgelist.columns:
remaining_edgelist = remaining_edgelist.drop("depth")

removed_edgelist = edgelist.filter(
(pl.col("component_a") == "") | (pl.col("component_a") != pl.col("component_b"))
)
Expand Down
20 changes: 4 additions & 16 deletions tests/graph/test_community_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,12 @@ def test_recovery_technical_multiplets(
.reset_index()
.rename(columns={"count": "len"})
)
result, info = recover_technical_multiplets(
result, depth_info = recover_technical_multiplets(
edgelist=edges,
node_component_map=node_component_map,
)
assert result.nunique() == 2
expected_info = pd.DataFrame(
{
0: {"old": 0, "new": 1, "old_size": 200, "new_size": 100, "depth": 0},
1: {"old": 0, "new": 2, "old_size": 200, "new_size": 100, "depth": 0},
}
).T
assert_frame_equal(info, expected_info)
assert set(depth_info.unique()) == {1}


def test_recovery_technical_multiplets_benchmark(
Expand All @@ -119,16 +113,10 @@ def test_recovery_technical_multiplets_benchmark(
)
)
node_component_map[:] = 0
result, info = benchmark(
result, depth_info = benchmark(
recover_technical_multiplets,
edgelist=edges,
node_component_map=node_component_map,
)
assert result.nunique() == 2
expected_info = pd.DataFrame(
{
0: {"old": 0, "new": 1, "old_size": 200, "new_size": 100, "depth": 0},
1: {"old": 0, "new": 2, "old_size": 200, "new_size": 100, "depth": 0},
}
).T
assert_frame_equal(info, expected_info)
assert set(depth_info.unique()) == {1}

0 comments on commit d080994

Please sign in to comment.