Skip to content

Commit

Permalink
Replace clip_at with breakpoints in clonal_expansion. (#439)
Browse files Browse the repository at this point in the history
* Define API for breakpoints argument in clonal_expansion.

* update changelog

* update plotting module

* Add testcase

* Implement breakpoints in clip_and_count

* Update workflow tests

* Fix test

* Update tutorial

* Fix docs
  • Loading branch information
grst authored Nov 11, 2023
1 parent 3120a9d commit d862cf3
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 245 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [Unreleased]
## v0.14.0

### Breaking changes

Expand All @@ -19,6 +19,11 @@ and this project adheres to [Semantic Versioning][].
`lambda x: ~ak.is_none(x["junction_aa"], axis=-1)`. To learn more about native awkward array functions, please
refer to the [awkward array documentation](https://awkward-array.org/doc/main/reference/index.html). ([#444](https://github.com/scverse/scirpy/pull/444))

### Additions

- The `clonal_expansion` function now supports a `breakpoints` argument for more flexible "expansion categories".
The `breakpoints` argument supersedes the `clip_at` parameter, which is now deprecated. ([#439](https://github.com/scverse/scirpy/pull/439))

### Fixes

- Fix that `define_clonotype_clusters` could not retrieve `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459))
Expand Down
368 changes: 196 additions & 172 deletions docs/tutorials/tutorial_3k_tcr.ipynb

Large diffs are not rendered by default.

24 changes: 18 additions & 6 deletions src/scirpy/pl/_clonal_expansion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Literal, Union
from collections.abc import Sequence
from typing import Literal, Optional, Union

from scirpy import tl
from scirpy.util import DataHandler
Expand All @@ -12,8 +13,9 @@ def clonal_expansion(
groupby: str,
*,
target_col: str = "clone_id",
clip_at: int = 3,
expanded_in: Union[str, None] = None,
breakpoints: Sequence[int] = (1, 2),
clip_at: Optional[int] = None,
summarize_by: Literal["cell", "clone_id"] = "cell",
normalize: bool = True,
show_nonexpanded: bool = True,
Expand All @@ -39,14 +41,23 @@ def clonal_expansion(
Group by this categorical variable in `adata.obs`.
target_col
Column in `adata.obs` containing the clonotype information.
clip_at
All entries in `target_col` with more copies than `clip_at`
will be summarized into a single group.
expanded_in
Calculate clonal expansion within groups. To calculate expansion
within patients, set this to the column containing patient annotation.
If set to None, a clonotype counts as expanded if there's any cell of the
same clonotype across the entire dataset. See also :term:`Public clonotype`.
breakpoints
Summarize clonotypes with a size smaller than or equal to the specified numbers
into groups. For instance, if this is (1, 2, 5), there will be four categories:
* all clonotypes with a size of 1 (singletons)
* all clonotypes with a size of 2
* all clonotypes with a size between 3 and 5 (inclusive)
* all clonotypes with a size > 5
clip_at
This argument is superseded by `breakpoints` and is only kept for backwards-compatibility.
Specifying a value of `clip_at = N` is equivalent to specifying `breakpoints = (1, 2, ..., N - 1)`.
If `clip_at` is specified, it overrides `breakpoints`.
summarize_by
Can be either `cell` to count cells belonging to a clonotype (the default),
or `clone_id` to count clonotypes. The former leads to an over-representation
Expand All @@ -70,9 +81,10 @@ def clonal_expansion(
summarize_by=summarize_by,
normalize=normalize,
expanded_in=expanded_in,
breakpoints=breakpoints,
clip_at=clip_at,
)
if not show_nonexpanded:
plot_df.drop("1", axis="columns", inplace=True)
plot_df.drop("<= 1", axis="columns", inplace=True)

return {"bar": base.bar, "barh": base.barh}[viztype](plot_df, **kwargs)
Binary file not shown.
Binary file not shown.
7 changes: 7 additions & 0 deletions src/scirpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def test_clonal_expansion(adata_clonotype):
assert isinstance(p, plt.Axes)


@pytest.mark.parametrize("adata_clonotype", [True], indirect=["adata_clonotype"], ids=["MuData"])
def test_clonal_expansion_mudata_prefix(adata_clonotype):
"""Regression test for #445"""
p = pl.clonal_expansion(adata_clonotype, groupby="group", target_col="airr:clone_id")
assert isinstance(p, plt.Axes)


def test_alpha_diversity(adata_diversity):
p = pl.alpha_diversity(adata_diversity, groupby="group", target_col="clonotype_")
assert isinstance(p, plt.Axes)
Expand Down
77 changes: 28 additions & 49 deletions src/scirpy/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import scanpy as sc
from mudata import MuData
from pytest import approx

import scirpy as ir
from scirpy.util import DataHandler
Expand Down Expand Up @@ -106,13 +107,13 @@ def test_clip_and_count_clonotypes(adata_clonotype):
adata = adata_clonotype

res = ir.tl._clonal_expansion._clip_and_count(
adata, groupby="group", target_col="clone_id", clip_at=2, inplace=False
adata, groupby="group", target_col="clone_id", breakpoints=(1,), inplace=False
)
npt.assert_equal(res, np.array([">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2))
npt.assert_equal(res, np.array(["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2))

# check without group
res = ir.tl._clonal_expansion._clip_and_count(adata, target_col="clone_id", clip_at=5, inplace=False)
npt.assert_equal(res, np.array(["4"] * 3 + ["nan"] * 2 + ["4"] + ["1"] * 2 + ["2"] * 2))
res = ir.tl._clonal_expansion._clip_and_count(adata, target_col="clone_id", breakpoints=(1, 2, 4), inplace=False)
npt.assert_equal(res, np.array(["<= 4"] * 3 + ["nan"] * 2 + ["<= 4"] + ["<= 1"] * 2 + ["<= 2"] * 2))

# check if target_col works
params = DataHandler.default(adata)
Expand All @@ -123,45 +124,35 @@ def test_clip_and_count_clonotypes(adata_clonotype):
adata,
groupby="group",
target_col="new_col",
clip_at=2,
breakpoints=(1,),
)
npt.assert_equal(
params.adata.obs["new_col_clipped_count"],
np.array([">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2),
np.array(["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2),
)

# check if it raises value error if target_col does not exist
with pytest.raises(ValueError):
ir.tl._clonal_expansion._clip_and_count(
adata,
groupby="group",
target_col="clone_id",
clip_at=2,
fraction=False,
)


@pytest.mark.parametrize(
"expanded_in,expected",
[
("group", [">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2),
(None, [">= 2"] * 3 + ["nan"] * 2 + [">= 2"] + ["1"] * 2 + [">= 2"] * 2),
("group", ["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2),
(None, ["> 1"] * 3 + ["nan"] * 2 + ["> 1"] + ["<= 1"] * 2 + ["> 1"] * 2),
],
)
def test_clonal_expansion(adata_clonotype, expanded_in, expected):
res = ir.tl.clonal_expansion(adata_clonotype, expanded_in=expanded_in, clip_at=2, inplace=False)
res = ir.tl.clonal_expansion(adata_clonotype, expanded_in=expanded_in, breakpoints=(1,), inplace=False)
npt.assert_equal(res, np.array(expected))


def test_clonal_expansion_summary(adata_clonotype):
res = ir.tl.summarize_clonal_expansion(adata_clonotype, "group", target_col="clone_id", clip_at=2, normalize=True)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2 / 5], ">= 2": [1.0, 3 / 5]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
res = ir.tl.summarize_clonal_expansion(
adata_clonotype, "group", target_col="clone_id", breakpoints=(1,), normalize=True
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.4)],
"> 1": [1.0, approx(0.6)],
}

# test the `expanded_in` parameter.
res = ir.tl.summarize_clonal_expansion(
Expand All @@ -172,13 +163,11 @@ def test_clonal_expansion_summary(adata_clonotype):
normalize=True,
expanded_in="group",
)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 3 / 5], ">= 2": [1.0, 2 / 5]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.6)],
"> 1": [1.0, approx(0.4)],
}

# test the `summarize_by` parameter.
res = ir.tl.summarize_clonal_expansion(
Expand All @@ -189,26 +178,16 @@ def test_clonal_expansion_summary(adata_clonotype):
normalize=True,
summarize_by="clone_id",
)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2 / 4], ">= 2": [1.0, 2 / 4]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.5)],
"> 1": [1.0, approx(0.5)],
}

res_counts = ir.tl.summarize_clonal_expansion(
adata_clonotype, "group", target_col="clone_id", clip_at=2, normalize=False
)
print(res_counts)
pdt.assert_frame_equal(
res_counts,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2], ">= 2": [3, 3]}).set_index("group"),
check_names=False,
check_dtype=False,
check_index_type=False,
check_categorical=False,
)
assert res_counts.reset_index().to_dict(orient="list") == {"group": ["A", "B"], "<= 1": [0, 2], "> 1": [3, 3]}


@pytest.mark.extra
Expand Down
59 changes: 42 additions & 17 deletions src/scirpy/tl/_clonal_expansion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Literal, Union
import warnings
from collections.abc import Sequence
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd

from scirpy.util import DataHandler, _is_na, _normalize_counts
Expand All @@ -10,10 +13,9 @@ def _clip_and_count(
target_col: str,
*,
groupby: Union[str, None, list[str]] = None,
clip_at: int = 3,
breakpoints: Sequence[int] = (1, 2, 3),
inplace: bool = True,
key_added: Union[str, None] = None,
fraction: bool = True,
airr_mod="airr",
) -> Union[None, pd.Series]:
"""Counts the number of identical entries in `target_col`
Expand All @@ -22,22 +24,32 @@ def _clip_and_count(
`nan`s in the input remain `nan` in the output.
"""
params = DataHandler(adata, airr_mod)
if target_col not in params.adata.obs.columns:
raise ValueError("`target_col` not found in obs.")
if not len(breakpoints):
raise ValueError("Need to specify at least one breakpoint.")

categories = [f"<= {b}" for b in breakpoints] + [f"> {breakpoints[-1]}", "nan"]

@np.vectorize
def _get_interval(value: int) -> str:
"""Return the interval of `value`, given breakpoints."""
for b in breakpoints:
if value <= b:
return f"<= {b}"
return f"> {b}"

groupby = [groupby] if isinstance(groupby, str) else groupby
groupby_cols = [target_col] if groupby is None else groupby + [target_col]
obs = params.get_obs(groupby_cols)

clonotype_counts = (
params.adata.obs.groupby(groupby_cols, observed=True)
obs.groupby(groupby_cols, observed=True)
.size()
.reset_index(name="tmp_count")
.assign(
tmp_count=lambda X: [f">= {min(n, clip_at)}" if n >= clip_at else str(n) for n in X["tmp_count"].values]
)
.assign(tmp_count=lambda X: pd.Categorical(_get_interval(X["tmp_count"].values), categories=categories))
)
clipped_count = params.adata.obs.merge(clonotype_counts, how="left", on=groupby_cols)["tmp_count"]
clipped_count[_is_na(params.adata.obs[target_col])] = "nan"
clipped_count.index = params.adata.obs.index
clipped_count = obs.merge(clonotype_counts, how="left", on=groupby_cols)["tmp_count"]
clipped_count[_is_na(obs[target_col])] = "nan"
clipped_count.index = obs.index

if inplace:
key_added = f"{target_col}_clipped_count" if key_added is None else key_added
Expand All @@ -52,7 +64,8 @@ def clonal_expansion(
*,
target_col: str = "clone_id",
expanded_in: Union[str, None] = None,
clip_at: int = 3,
breakpoints: Sequence[int] = (1, 2),
clip_at: Optional[int] = None,
key_added: str = "clonal_expansion",
inplace: bool = True,
**kwargs,
Expand All @@ -72,9 +85,18 @@ def clonal_expansion(
this to the column containing sample annotation. If set to None,
a clonotype counts as expanded if there's any cell of the same clonotype
across the entire dataset.
clip_at:
All clonotypes with more than `clip_at` clones will be summarized into
a single category
breakpoints
Summarize clonotypes with a size smaller than or equal to the specified numbers
into groups. For instance, if this is (1, 2, 5), there will be four categories:
* all clonotypes with a size of 1 (singletons)
* all clonotypes with a size of 2
* all clonotypes with a size between 3 and 5 (inclusive)
* all clonotypes with a size > 5
clip_at
This argument is superseded by `breakpoints` and is only kept for backwards-compatibility.
Specifying a value of `clip_at = N` is equivalent to specifying `breakpoints = (1, 2, ..., N - 1)`.
If `clip_at` is specified, it overrides `breakpoints`.
{key_added}
{inplace}
{airr_mod}
Expand All @@ -84,11 +106,14 @@ def clonal_expansion(
Depending on the value of inplace, adds a column to adata or returns
a Series with the clipped count per cell.
"""
if clip_at is not None:
breakpoints = list(range(1, clip_at))
warnings.warn("The argument `clip_at` is deprecated. Please use `brekpoints` instead.", category=FutureWarning)
return _clip_and_count(
adata,
target_col,
groupby=expanded_in,
clip_at=clip_at,
breakpoints=breakpoints,
key_added=key_added,
inplace=inplace,
**kwargs,
Expand Down

0 comments on commit d862cf3

Please sign in to comment.