Skip to content

Commit

Permalink
Use umi_count instead of duplicate_count in compliance with AIRR v1.4 (
Browse files Browse the repository at this point in the history
…#487)

* Remove index_chains_legacy because the numba implementation seems to work fine

* Use 'umi_count' instead of 'duplicate_count'

* Fix test
  • Loading branch information
grst authored Jan 29, 2024
1 parent 84f390f commit d776d01
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 168 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## Unreleased

### Backwards-incompatible changes

- Use the `umi_count` field instead of `duplicate_count` to store UMI counts. The field `umi_count` has been added to
the AIRR Rearrangement standard in [version 1.4](https://docs.airr-community.org/en/latest/news.html#version-1-4-1-august-27-2022).
Use of `duplicate_count` for UMI counts is now discouraged. Scirpy will use `umi_count` in all `scirpy.io` functions.
It will _not_ change AIRR data that is read through `scirpy.io.read_airr` that still uses the `duplicate_count` column.
Scirpy remains compatible with datasets that still use `duplicate_count`. You can update your dataset using

```python
adata.obsm["airr"]["umi_count"] = adata.obsm["airr"]["duplicate_count"]
```

## v0.15.0

### Fixes
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
'igraph != 0.10.0,!=0.10.1',
'networkx>=2.5',
'squarify',
'airr>=1.2',
'airr>=1.4.1',
'tqdm>=4.63', # https://github.com/tqdm/tqdm/issues/1082
'adjustText>=0.7',
'numba>=0.41.0',
Expand Down
30 changes: 8 additions & 22 deletions src/scirpy/io/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Collection, Iterable, Sequence
from glob import iglob
from pathlib import Path
from typing import Any, Literal, Union
from typing import Any, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -125,7 +125,7 @@ def _read_10x_vdj_json(
chain["locus"] = chain_type
chain["junction"] = contig["cdr3_seq"]
chain["junction_aa"] = contig["cdr3"]
chain["duplicate_count"] = contig["umi_count"]
chain["umi_count"] = contig["umi_count"]
chain["consensus_count"] = contig["read_count"]
chain["productive"] = contig["productive"]
chain["is_cell"] = contig["is_cell"]
Expand Down Expand Up @@ -166,7 +166,7 @@ def _read_10x_vdj_csv(
locus=chain_series["chain"],
junction_aa=chain_series["cdr3"],
junction=chain_series["cdr3_nt"],
duplicate_count=chain_series["umis"],
umi_count=chain_series["umis"],
consensus_count=chain_series["reads"],
productive=_is_true2(chain_series["productive"]),
v_call=chain_series["v_gene"],
Expand Down Expand Up @@ -352,7 +352,7 @@ def _process_chains(chains, chain_type):
)
def read_airr(
path: Union[str, Sequence[str], Path, Sequence[Path], pd.DataFrame, Sequence[pd.DataFrame]],
use_umi_count_col: Union[bool, Literal["auto"]] = "auto",
use_umi_count_col: None = None, # deprecated, kept for backwards-compatibility
infer_locus: bool = True,
cell_attributes: Collection[str] = DEFAULT_AIRR_CELL_ATTRIBUTES,
include_fields: Any = None,
Expand Down Expand Up @@ -380,10 +380,9 @@ def read_airr(
as a List, e.g. `["path/to/tcr_alpha.tsv", "path/to/tcr_beta.tsv"]`.
Alternatively, this can be a pandas data frame.
use_umi_count_col
Whether to add UMI counts from the non-standard (but common) `umi_count`
column. When this column is used, the UMI counts are moved over to the
standard `duplicate_count` column. Default: Use `umi_count` if there is
no `duplicate_count` column present.
Deprecated, has no effect as of v0.16. Since v1.4 of the AIRR standard, `umi_count`
is an official field in the Rearrangement schema and preferred over `duplicate_count`.
`umi_count` now always takes precedence over `duplicate_count`.
infer_locus
Try to infer the `locus` column from gene names, in case it is not specified.
cell_attributes
Expand All @@ -409,16 +408,6 @@ def read_airr(
if isinstance(path, (str, Path, pd.DataFrame)):
path: list[Union[str, Path, pd.DataFrame]] = [path] # type: ignore

def _decide_use_umi_count_col(chain_dict):
    """Decide whether counts from the `umi_count` column should be used.

    Returns True when explicitly requested via `use_umi_count_col=True`, or in
    "auto" mode when the chain carries a `umi_count` entry but no
    `duplicate_count` entry (with a warning about the rename in that case).
    """
    # Explicit opt-in always wins; `True` and `"auto"` are mutually exclusive
    # values of `use_umi_count_col`, so checking it first cannot change behavior.
    if use_umi_count_col is True:
        return True
    auto_detected = (
        use_umi_count_col == "auto"
        and "umi_count" in chain_dict
        and "duplicate_count" not in chain_dict
    )
    if auto_detected:
        logger.warning("Renaming the non-standard `umi_count` column to `duplicate_count`. ")  # type: ignore
        return True
    return False

for tmp_path_or_df in path:
if isinstance(tmp_path_or_df, pd.DataFrame):
iterator = _read_airr_rearrangement_df(tmp_path_or_df)
Expand All @@ -438,9 +427,6 @@ def _decide_use_umi_count_col(chain_dict):
)
airr_cells[cell_id] = tmp_cell

if _decide_use_umi_count_col(chain_dict):
chain_dict["duplicate_count"] = get_rearrangement_schema().to_int(chain_dict.pop("umi_count"))

if infer_locus and "locus" not in chain_dict:
logger.warning(
"`locus` column not found in input data. The locus is being inferred from the {v,d,j,c}_call columns."
Expand Down Expand Up @@ -742,7 +728,7 @@ def _get(row, field):
"junction_aa": _get(row, "CDR3_Translation"),
"productive": row["Productive"],
"consensus_count": row["Read_Count"],
"duplicate_count": row["Molecule_Count"],
"umi_count": row["Molecule_Count"],
}
)
tmp_cell.add_chain(tmp_chain)
Expand Down
139 changes: 11 additions & 128 deletions src/scirpy/pp/_index_chains.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import operator
from collections.abc import Mapping, Sequence
from functools import partial, reduce
from functools import reduce
from types import MappingProxyType
from typing import Any, Callable, Union

Expand All @@ -10,137 +10,14 @@
from scanpy import logging

from scirpy.io._datastructures import AirrCell
from scirpy.util import DataHandler, _is_na2, tqdm
from scirpy.util import DataHandler

SCIRPY_DUAL_IR_MODEL = "scirpy_dual_ir_v0.13"
# make these constants available to numba
_VJ_LOCI = tuple(AirrCell.VJ_LOCI)
_VDJ_LOCI = tuple(AirrCell.VDJ_LOCI)


@DataHandler.inject_param_docs()
def index_chains_legacy(
    adata: DataHandler.TYPE,
    *,
    filter: Union[Callable[[Mapping], bool], Sequence[Union[str, Callable[[Mapping], bool]]]] = (
        "productive",
        "require_junction_aa",
    ),
    sort_chains_by: Mapping[str, Any] = MappingProxyType(
        {"duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""}
    ),
    airr_mod: str = "airr",
    airr_key: str = "airr",
    key_added: str = "chain_indices",
) -> None:
    """\
    Selects primary/secondary VJ/VDJ cells per chain according to the :ref:`receptor-model`.

    This function iterates through all chains stored in the :term:`awkward array` in
    `adata.obsm[airr_key]` and

    * labels chains as primary/secondary VJ/VDJ chains
    * labels cells as multichain cells

    based on the expression level of the chains and the specified filtering option.
    By default, non-productive chains and chains without a valid CDR3 amino acid sequence are filtered out.

    Additionally, chains without a valid IMGT locus are always filtered out.

    For more details, please refer to the :ref:`receptor-model` and the :ref:`data structure <data-structure>`.

    Parameters
    ----------
    {adata}
    filter
        Option to filter chains. Can be either

        * a callback function that takes a chain-dictionary as input and returns a boolean (True to keep, False to discard)
        * a list of "filtering presets". Possible values are `"productive"` and `"require_junction_aa"`.
          `"productive"` removes non-productive chains and `"require_junction_aa"` removes chains that don't have
          a CDR3 amino acid sequence.
        * a list with a combination of both.

        Multiple presets/functions are combined using `and`. Filtered chains do not count towards calling "multichain" cells.
    sort_chains_by
        A list of sort keys used to determine an ordering of chains. The chain with the highest value
        of this tuple will be the primary chain, second-highest the secondary chain. If there are more chains, they
        will not be indexed, and the cell receives the "multichain" flag.
    {airr_mod}
    {airr_key}
    key_added
        Key under which the chain indices will be stored in `adata.obsm` and metadata will be stored in `adata.uns`.

    Returns
    -------
    Nothing, but adds an awkward array of chain indices to `adata.obsm[key_added]`
    and metadata to `adata.uns[key_added]`.
    """
    chain_index_list = []
    params = DataHandler(adata, airr_mod, airr_key)

    # prepare filter functions: normalize `filter` to a list of callables,
    # resolving string presets via the lookup table below.
    if isinstance(filter, Callable):
        filter = [filter]
    filter_presets = {
        "productive": lambda x: x["productive"],
        "require_junction_aa": lambda x: not _is_na2(x["junction_aa"]),
    }
    filter = [filter_presets[f] if isinstance(f, str) else f for f in filter]

    # only warn if those fields are in the key (i.e. this should give a warning if those are missing with
    # default settings. If the user specifies their own dictionary, they are on their own)
    if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by:
        if "duplicate_count" not in params.airr.fields and "consensus_count" not in params.airr.fields:
            logging.warning("No expression information available. Cannot rank chains by expression. ")  # type: ignore

    # in chunks of 5000-10000 this is fastest. Not sure why there is additional
    # overhead when running `to_list` on the full array. It's anyway friendlier to memory this way.
    CHUNKSIZE = 5000
    # NOTE(review): the inner `enumerate` loop below reuses the loop variable `i`.
    # This is harmless in Python (the outer `for` re-binds `i` from its range each
    # iteration, and `i` is not read after the slice), but it is confusing to read.
    for i in tqdm(range(0, len(params.airr), CHUNKSIZE)):
        cells = ak.to_list(params.airr[i : i + CHUNKSIZE])
        for cell_chains in cells:
            # cell_chains = cast(List[ak.Record], cell_chains)

            # Split chains into VJ and VDJ chains: collect the positions (within the
            # cell) of chains that pass all filters, bucketed by locus category.
            # Chains without a recognized IMGT locus fall into neither bucket.
            chain_indices: dict[str, Any] = {"VJ": [], "VDJ": []}
            for i, tmp_chain in enumerate(cell_chains):
                if all(f(tmp_chain) for f in filter) and "locus" in params.airr.fields:
                    if tmp_chain["locus"] in AirrCell.VJ_LOCI:
                        chain_indices["VJ"].append(i)
                    elif tmp_chain["locus"] in AirrCell.VDJ_LOCI:
                        chain_indices["VDJ"].append(i)

            # Order chains by expression (or whatever was specified in sort_chains_by);
            # highest value first, so index 0 is the primary chain.
            for junction_type in ["VJ", "VDJ"]:
                chain_indices[junction_type] = sorted(
                    chain_indices[junction_type],
                    key=partial(_key_sort_chains, cell_chains, sort_chains_by),  # type: ignore
                    reverse=True,
                )

            # More than two chains in either category marks the cell as "multichain".
            chain_indices["multichain"] = len(chain_indices["VJ"]) > 2 or len(chain_indices["VDJ"]) > 2
            chain_index_list.append(chain_indices)

    chain_index_awk = ak.Array(chain_index_list)
    for k in ["VJ", "VDJ"]:
        # ensure the length for VJ and VDJ is exactly 2 (such that it can be sliced later)
        # and ensure that the type is always ?int (important if all values are None)
        chain_index_awk[k] = ak.values_astype(
            ak.pad_none(chain_index_awk[k], 2, axis=1, clip=True),
            int,
            including_unknown=True,
        )

    params.adata.obsm[key_added] = chain_index_awk  # type: ignore

    # store metadata in .uns
    params.adata.uns[key_added] = {
        "model": SCIRPY_DUAL_IR_MODEL,  # can be used to distinguish different receptor models that may be added in the future.
        "filter": str(filter),
        "airr_key": airr_key,
        "sort_chains_by": str(sort_chains_by),
    }


@DataHandler.inject_param_docs()
def index_chains(
adata: DataHandler.TYPE,
Expand All @@ -150,7 +27,9 @@ def index_chains(
"require_junction_aa",
),
sort_chains_by: Mapping[str, Any] = MappingProxyType(
{"duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""}
# Since AIRR version v1.4.1, `duplicate_count` is deprecated in favor of `umi_count`.
# We still keep it as sort key for backwards compatibility
{"umi_count": 0, "duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""}
),
airr_mod: str = "airr",
airr_key: str = "airr",
Expand Down Expand Up @@ -212,8 +91,12 @@ def index_chains(

# only warn if those fields are in the key (i.e. this should give a warning if those are missing with
# default settings. If the user specifies their own dictionary, they are on their own)
if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by:
if "duplicate_count" not in params.airr.fields and "consensus_count" not in params.airr.fields:
if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by and "umi_count" in sort_chains_by:
if (
"duplicate_count" not in params.airr.fields
and "consensus_count" not in params.airr.fields
and "umi_count" not in sort_chains_by
):
logging.warning("No expression information available. Cannot rank chains by expression. ") # type: ignore

if "locus" not in params.airr.fields:
Expand Down
Loading

0 comments on commit d776d01

Please sign in to comment.