From 31589c677bc7cb982467159d275436a35cb69904 Mon Sep 17 00:00:00 2001 From: Gregor Sturm Date: Wed, 1 Nov 2023 18:21:38 +0100 Subject: [PATCH] Speed up index_chains (#444) * Add MWE for index_chains in numba * Try different approach, but it seems worse * Another (buggy) MWE for a better implementation * Apparently fully working vectorized implementation of index_chains * Tolerate missing sort key * Document code * fix tests * Update changelog * Update changelog --- CHANGELOG.md | 9 ++ docs/tutorials/tutorial_3k_tcr.ipynb | 4 +- src/scirpy/pp/_index_chains.py | 173 ++++++++++++++++++++++++++- 3 files changed, 181 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51d4cdf8f..ca342e0c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning][]. ## [Unreleased] +### Breaking changes + +- Reimplement `pp.index_chains` using numba and awkward array functions, achieving a significant speedup. This function + behaves exactly like the previous version _except_ that callback functions passed to the `filter` arguments + must now be vectorized over an awkward array, e.g. to check if a `junction_aa` field is present you could + previously pass `lambda x: x['junction_aa'] is not None`, now an accepted version would be + `lambda x: ~ak.is_none(x["junction_aa"], axis=-1)`. To learn more about native awkward array functions, please + refer to the [awkward array documentation](https://awkward-array.org/doc/main/reference/index.html). ([#444](https://github.com/scverse/scirpy/pull/444)) + ### Fixes - Fix that `define_clonotype_clusters` could not retreive `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459)) diff --git a/docs/tutorials/tutorial_3k_tcr.ipynb b/docs/tutorials/tutorial_3k_tcr.ipynb index 9d90bb352..13f61a577 100644 --- a/docs/tutorials/tutorial_3k_tcr.ipynb +++ b/docs/tutorials/tutorial_3k_tcr.ipynb @@ -3537,7 +3537,7 @@ "notebook_metadata_filter": "-kernelspec" }, "kernelspec": { - "display_name": "scirpy_dev2", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -3551,7 +3551,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/src/scirpy/pp/_index_chains.py b/src/scirpy/pp/_index_chains.py index 143b678e7..d4fc1e4e9 100644 --- a/src/scirpy/pp/_index_chains.py +++ b/src/scirpy/pp/_index_chains.py @@ -1,19 +1,25 @@ +import operator from collections.abc import Mapping, Sequence -from functools import partial +from functools import partial, reduce from types import MappingProxyType from typing import Any, Callable, Union import awkward as ak +import numba as nb +import numpy as np from scanpy import logging from scirpy.io._datastructures import AirrCell from scirpy.util import DataHandler, _is_na2, tqdm SCIRPY_DUAL_IR_MODEL = "scirpy_dual_ir_v0.13" +# make these constants available to numba +_VJ_LOCI = tuple(AirrCell.VJ_LOCI) +_VDJ_LOCI = tuple(AirrCell.VDJ_LOCI) @DataHandler.inject_param_docs() -def index_chains( +def index_chains_legacy( adata: DataHandler.TYPE, *, filter: Union[Callable[[Mapping], bool], Sequence[Union[str, Callable[[Mapping], bool]]]] = ( @@ -135,7 +141,168 @@ def index_chains( } -def _key_sort_chains(chains: list[Mapping], sort_chains_by: Mapping[str, Any], idx: int) -> Sequence: +@DataHandler.inject_param_docs() +def index_chains( + adata: DataHandler.TYPE, + *, + filter: Union[Callable[[ak.Array], bool], Sequence[Union[str, Callable[[ak.Array], bool]]]] = ( + "productive", + "require_junction_aa", + ), + sort_chains_by: Mapping[str, Any] = MappingProxyType( + {"duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""} + ), + airr_mod: str = "airr", + airr_key: str = "airr", + key_added: str = "chain_indices", +) -> None: + """\ + Selects primary/secondary VJ/VDJ cells per chain according to the :ref:`receptor-model`. + + This function iterates through all chains stored in the :term:`awkward array` in + `adata.obsm[airr_key]` and + + * labels chains as primary/secondary VJ/VDJ chains + * labels cells as multichain cells + + based on the expression level of the chains and the specified filtering option. + By default, non-productive chains and chains without a valid CDR3 amino acid sequence are filtered out. + + Additionally, chains without a valid IMGT locus are always filtered out. + + For more details, please refer to the :ref:`receptor-model` and the :ref:`data structure `. + + Parameters + ---------- + {adata} + filter + Option to filter chains. Can be either + * a callback function that takes the full awkward array with AIRR chains as input and returns + another awkward array that is a boolean mask which can be used to index the former. + (True to keep, False to discard) + * a list of "filtering presets". Possible values are `"productive"` and `"require_junction_aa"`. + `"productive"` removes non-productive chains and `"require_junction_aa"` removes chains that don't have + a CDR3 amino acid sequence. + * a list with a combination of both. + + Multiple presets/functions are combined using `and`. Filtered chains do not count towards calling "multichain" cells. + sort_chains_by + A list of sort keys used to determine an ordering of chains. The chain with the highest value + of this tuple will be the primary chain, second-highest the secondary chain. If there are more chains, they + will not be indexed, and the cell receives the "multichain" flag. + {airr_mod} + {airr_key} + key_added + Key under which the chain indicies will be stored in `adata.obsm` and metadata will be stored in `adata.uns`. + + Returns + ------- + Nothing, but adds a dataframe to `adata.obsm[chain_indices]` + """ + params = DataHandler(adata, airr_mod, airr_key) + + # prepare filter functions + if isinstance(filter, Callable): + filter = [filter] + filter_presets = { + "productive": lambda x: x["productive"], + "require_junction_aa": lambda x: ~ak.is_none(x["junction_aa"], axis=-1), + } + filter = [filter_presets[f] if isinstance(f, str) else f for f in filter] + + # only warn if those fields are in the key (i.e. this should give a warning if those are missing with + # default settings. If the user specifies their own dictionary, they are on their own) + if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by: + if "duplicate_count" not in params.airr.fields and "consensus_count" not in params.airr.fields: + logging.warning("No expression information available. Cannot rank chains by expression. ") # type: ignore + + if "locus" not in params.airr.fields: + raise ValueError("The scirpy receptor model requires a `locus` field to be specified in the AIRR data.") + + airr = params.airr + logging.info("Filtering chains...") + # Get the numeric indices pre-filtering - these are the indices we need in the final output as + # .obsm["airr"] is and remains unfiltered. + airr_idx = ak.local_index(airr, axis=1) + # Filter out chains that do not match the filter criteria + # we need an initial value that selects all chains in case filter is an empty list + airr_idx = airr_idx[reduce(operator.and_, (f(airr) for f in filter), ak.ones_like(airr_idx, dtype=bool))] + + res = {} + is_multichain = np.zeros(len(airr), dtype=bool) + for chain_type, locus_names in {"VJ": AirrCell.VJ_LOCI, "VDJ": AirrCell.VDJ_LOCI}.items(): + logging.info(f"Indexing {chain_type} chains...") + # get the indices for all VJ / VDJ chains, respectively + idx = airr_idx[_awkward_isin(airr["locus"][airr_idx], locus_names)] + + # Now we need to sort the chains by the keys specified in `sort_chains_by`. + # since `argsort` doesn't support composite keys, we take advantage of the + # fact that the sorting algorithm is stable and sort the same array several times, + # starting with the lowest priority key up to the highest priority key. + for k, default in reversed(sort_chains_by.items()): + # skip this round of sorting altogether if field not present + if k in airr.fields: + logging.debug(f"Sorting chains by {k}") + tmp_idx = ak.argsort(ak.fill_none(airr[k][idx], default), stable=True, axis=-1, ascending=False) + idx = idx[tmp_idx] + else: + logging.debug(f"Skip sorting by {k} because field not present") + + # We want the result to be lists of exactly 2 - clip if longer, pad with None if shorter. + res[chain_type] = ak.pad_none(idx, 2, axis=1, clip=True) + is_multichain |= ak.to_numpy(_awkward_len(idx)) > 2 + + # build results + logging.info("build result array") + res["multichain"] = is_multichain + + params.adata.obsm[key_added] = ak.zip(res, depth_limit=1) # type: ignore + + # store metadata in .uns + params.adata.uns[key_added] = { + "model": SCIRPY_DUAL_IR_MODEL, # can be used to distinguish different receptor models that may be added in the future. + "filter": str(filter), + "airr_key": airr_key, + "sort_chains_by": str(sort_chains_by), + } + + +@nb.njit +def _awkward_len_inner(arr, ab): + for row in arr: + ab.append(len(row)) + return ab + + +def _awkward_len(arr): + return _awkward_len_inner(arr, ak.ArrayBuilder()).snapshot() + + +@nb.njit() +def _awkward_isin_inner(arr, haystack, ab): + for row in arr: + ab.begin_list() + for v in row: + ab.append(v in haystack) + ab.end_list() + return ab + + +def _awkward_isin(arr, haystack): + haystack = tuple(haystack) + return _awkward_isin_inner(arr, haystack, ak.ArrayBuilder()).snapshot() + + +# For future reference, here would be two alternative implementations that are a bit +# slower, but work without the need for numba. +# def _awkward_len(arr): +# return ak.max(ak.local_index(arr, axis=1), axis=1) +# +# def _awkward_isin(arr, haystack): +# return reduce(operator.or_, (arr == el for el in haystack)) + + +def _key_sort_chains(chains, sort_chains_by: Mapping[str, Any], idx: int) -> Sequence: """Get key to sort chains by expression. Parameters