From 4812fbae9c661be06c4fb2d38518a1a6d92be574 Mon Sep 17 00:00:00 2001 From: ptajvar Date: Thu, 22 Aug 2024 15:19:01 +0200 Subject: [PATCH] added support for multiple targets in plot_colocalization_diff_heatmap --- src/pixelator/plot/spatial_analysis_plots.py | 86 ++++++++++++-------- 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/src/pixelator/plot/spatial_analysis_plots.py b/src/pixelator/plot/spatial_analysis_plots.py index 16296fb8..73eb82d5 100644 --- a/src/pixelator/plot/spatial_analysis_plots.py +++ b/src/pixelator/plot/spatial_analysis_plots.py @@ -117,8 +117,8 @@ def _get_top_marker_pairs( def plot_colocalization_diff_heatmap( colocalization_data: pd.DataFrame, - target: str, reference: str, + targets: str | list[str] | None = None, contrast_column: str = "sample", markers: Union[list, None] = None, n_top_marker_pairs: Union[int, None] = None, @@ -130,8 +130,8 @@ def plot_colocalization_diff_heatmap( Example usage: plot_colocalization_diff_heatmap(pxl.colocalization, target:"stimulated", reference:"control", contrast_column="sample"). :param colocalization_data: The colocalization data frame that can be found in a pixel variable "pxl" through pxl.colocalization. The data frame should contain the columns "marker_1", "marker_2", "pearson", "pearson_z", and the contrast_column. - :param target: The label for target components in the contrast_column. :param reference: The label for reference components in the contrast_column. + :param targets: label or list of labels for target components in the contrast_column. :param contrast_column: The column to use for the contrast. Defaults to "sample". :param markers: The markers to include in the heatmap. Defaults to None. At most only one of n_top_marker_pairs or markers should be provided. :param n_top_marker_pairs: The number of top marker pairs to include in the heatmap. Defaults to None. At most only one of n_top_marker_pairs or markers should be provided. @@ -150,6 +150,18 @@ def plot_colocalization_diff_heatmap( else: value_col = "pearson" + if isinstance(targets, str): + targets = [targets] + elif targets is None: + targets = colocalization_data[contrast_column].unique() + targets = list(set(targets) - {reference}) + + if len(targets) > 5: + raise ValueError( + "Only up to 5 target components can be visualized. " + "Number of requested targets is {len(targets)}" + ) + if markers is not None: filter_mask = (colocalization_data["marker_1"].isin(markers)) & ( colocalization_data["marker_2"].isin(markers) @@ -158,45 +170,49 @@ def plot_colocalization_diff_heatmap( differential_colocalization = get_differential_colocalization( colocalization_data, - target=target, + targets=targets, reference=reference, contrast_column=contrast_column, use_z_score=use_z_score, ) - - differential_colocalization = differential_colocalization.fillna(0).reset_index() - - if n_top_marker_pairs is not None: - top_markers = _get_top_marker_pairs( - differential_colocalization, n_top_marker_pairs, "median_difference" + if len(targets) == 1: + differential_colocalization = {targets[0]: differential_colocalization} + fig, axes = plt.subplots(1, len(targets), figsize=(5 * len(targets), 5)) + for ind, target in enumerate(targets): + ax = axes if len(targets) == 1 else axes[ind] + target_diff = differential_colocalization[target] + target_diff = target_diff.fillna(0).reset_index() + + if n_top_marker_pairs is not None: + top_markers = _get_top_marker_pairs( + target_diff, n_top_marker_pairs, "median_difference" + ) + else: + top_markers = None + + # Making the differential colocalization symmetric + target_diff = _make_colocalization_symmetric(target_diff, "median_difference") + + pivoted_target_diff = _pivot_colocalization_data( + target_diff, + "median_difference", + markers=top_markers, ) - else: - top_markers = None - # Making the differential colocalization symmetric - differential_colocalization = _make_colocalization_symmetric( - differential_colocalization, "median_difference" - ) - - pivoted_differential_colocalization = _pivot_colocalization_data( - differential_colocalization, - "median_difference", - markers=top_markers, - ) - - max_value = np.max(np.abs(pivoted_differential_colocalization.to_numpy().flatten())) - sns.clustermap( - pivoted_differential_colocalization, - yticklabels=True, - xticklabels=True, - method="complete", - linewidths=0.1, - vmin=-max_value, - vmax=max_value, - cmap=cmap, - ) + max_value = np.max(np.abs(pivoted_target_diff.to_numpy().flatten())) + sns.clustermap( + pivoted_target_diff, + yticklabels=True, + xticklabels=True, + method="complete", + linewidths=0.1, + vmin=-max_value, + vmax=max_value, + cmap=cmap, + ax=ax, + ) - return plt.gcf(), plt.gca() + return fig, axes def _add_top_marker_labels( @@ -286,7 +302,7 @@ def plot_colocalization_diff_volcano( differential_colocalization = get_differential_colocalization( colocalization_data, - target=target, + targets=target, reference=reference, contrast_column=contrast_column, use_z_score=use_z_score,