Skip to content

Commit

Permalink
added support for multiple targets in plot_colocalization_diff_heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
ptajvar committed Aug 22, 2024
1 parent a6d7db5 commit 4812fba
Showing 1 changed file with 51 additions and 35 deletions.
86 changes: 51 additions & 35 deletions src/pixelator/plot/spatial_analysis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def _get_top_marker_pairs(

def plot_colocalization_diff_heatmap(
colocalization_data: pd.DataFrame,
target: str,
reference: str,
targets: str | list[str] | None = None,
contrast_column: str = "sample",
markers: Union[list, None] = None,
n_top_marker_pairs: Union[int, None] = None,
Expand All @@ -130,8 +130,8 @@ def plot_colocalization_diff_heatmap(
Example usage: plot_colocalization_diff_heatmap(pxl.colocalization, target:"stimulated", reference:"control", contrast_column="sample").
:param colocalization_data: The colocalization data frame that can be found in a pixel variable "pxl" through pxl.colocalization. The data frame should contain the columns "marker_1", "marker_2", "pearson", "pearson_z", and the contrast_column.
:param target: The label for target components in the contrast_column.
:param reference: The label for reference components in the contrast_column.
:param targets: label or list of labels for target components in the contrast_column.
:param contrast_column: The column to use for the contrast. Defaults to "sample".
:param markers: The markers to include in the heatmap. Defaults to None. At most only one of n_top_marker_pairs or markers should be provided.
:param n_top_marker_pairs: The number of top marker pairs to include in the heatmap. Defaults to None. At most only one of n_top_marker_pairs or markers should be provided.
Expand All @@ -150,6 +150,18 @@ def plot_colocalization_diff_heatmap(
else:
value_col = "pearson"

if isinstance(targets, str):
targets = [targets]
elif targets is None:
targets = colocalization_data[contrast_column].unique()
targets = list(set(targets) - {reference})

if len(targets) > 5:
raise ValueError(
"Only up to 5 target components can be visualized. "
"Number of requested targets is {len(targets)}"
)

if markers is not None:
filter_mask = (colocalization_data["marker_1"].isin(markers)) & (
colocalization_data["marker_2"].isin(markers)
Expand All @@ -158,45 +170,49 @@ def plot_colocalization_diff_heatmap(

differential_colocalization = get_differential_colocalization(
colocalization_data,
target=target,
targets=targets,
reference=reference,
contrast_column=contrast_column,
use_z_score=use_z_score,
)

differential_colocalization = differential_colocalization.fillna(0).reset_index()

if n_top_marker_pairs is not None:
top_markers = _get_top_marker_pairs(
differential_colocalization, n_top_marker_pairs, "median_difference"
if len(targets) == 1:
differential_colocalization = {targets[0]: differential_colocalization}
fig, axes = plt.subplots(1, len(targets), figsize=(5 * len(targets), 5))
for ind, target in enumerate(targets):
ax = axes if len(targets) == 1 else axes[ind]
target_diff = differential_colocalization[target]
target_diff = target_diff.fillna(0).reset_index()

if n_top_marker_pairs is not None:
top_markers = _get_top_marker_pairs(
target_diff, n_top_marker_pairs, "median_difference"
)
else:
top_markers = None

# Making the differential colocalization symmetric
target_diff = _make_colocalization_symmetric(target_diff, "median_difference")

pivoted_target_diff = _pivot_colocalization_data(
target_diff,
"median_difference",
markers=top_markers,
)
else:
top_markers = None

# Making the differential colocalization symmetric
differential_colocalization = _make_colocalization_symmetric(
differential_colocalization, "median_difference"
)

pivoted_differential_colocalization = _pivot_colocalization_data(
differential_colocalization,
"median_difference",
markers=top_markers,
)

max_value = np.max(np.abs(pivoted_differential_colocalization.to_numpy().flatten()))
sns.clustermap(
pivoted_differential_colocalization,
yticklabels=True,
xticklabels=True,
method="complete",
linewidths=0.1,
vmin=-max_value,
vmax=max_value,
cmap=cmap,
)
max_value = np.max(np.abs(pivoted_target_diff.to_numpy().flatten()))
sns.clustermap(
pivoted_target_diff,
yticklabels=True,
xticklabels=True,
method="complete",
linewidths=0.1,
vmin=-max_value,
vmax=max_value,
cmap=cmap,
ax=ax,
)

return plt.gcf(), plt.gca()
return fig, axes


def _add_top_marker_labels(
Expand Down Expand Up @@ -286,7 +302,7 @@ def plot_colocalization_diff_volcano(

differential_colocalization = get_differential_colocalization(
colocalization_data,
target=target,
targets=target,
reference=reference,
contrast_column=contrast_column,
use_z_score=use_z_score,
Expand Down

0 comments on commit 4812fba

Please sign in to comment.