Skip to content

Commit

Permalink
(fix): correct default fill values for dask-sparse (#1719)
Browse files Browse the repository at this point in the history
* (fix): correct default fill values for dask-sparse

* (chore): release note

* (chore): link to `concat`

* (fix): further clarify message
  • Loading branch information
ilan-gold authored Oct 17, 2024
1 parent 097377c commit 96ccce7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/release-notes/1719.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure {func}`anndata.concat` of {class}`~anndata.AnnData` objects with {class}`scipy.sparse.spmatrix` and {class}`scipy.sparse.sparray` dask arrays uses the correct fill value of 0. {user}`ilan-gold`
9 changes: 8 additions & 1 deletion src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,14 @@ def default_fill_value(els):
This is largely due to backwards compat, and might not be the ideal solution.
"""
if any(isinstance(el, sparse.spmatrix | SpArray) for el in els):
if any(
isinstance(el, sparse.spmatrix | SpArray)
or (
isinstance(el, DaskArray)
and isinstance(el._meta, sparse.spmatrix | SpArray)
)
for el in els
):
return 0
else:
return np.nan
Expand Down
20 changes: 20 additions & 0 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,3 +1635,23 @@ def test_concat_on_var_outer_join(array_type):
# This shouldn't error
# TODO: specify expected result while accounting for null value
_ = concat([a, b], join="outer", axis=1)


def test_concat_dask_sparse_matches_memory(join_type, merge_strategy):
    """Check that concatenating dask-wrapped sparse ``X`` matches in-memory sparse.

    Two pairs of ``AnnData`` objects are built from the same sparse matrix:
    one pair holds it directly, the other holds it wrapped in a dask array.
    Both pairs are concatenated with the same ``join`` and ``merge`` arguments
    and the results are compared with ``assert_equal``.
    """
    import dask.array as da

    # Seed the generator so that a failing run can be reproduced exactly.
    X = sparse.random(50, 20, density=0.5, format="csr", random_state=0)
    X_dask = da.from_array(X, chunks=(5, 20))

    # NOTE(review): the two name sets are disjoint (indices 0-19 vs 20-39), so
    # the "_foo" suffix does not influence which names overlap -- confirm that
    # a partial overlap was not intended here.
    var_names_1 = [f"gene_{i}" for i in range(20)]
    var_names_2 = [f"gene_{i}{'_foo' if (i%2) else ''}" for i in range(20, 40)]

    ad1 = AnnData(X=X, var=pd.DataFrame(index=var_names_1))
    ad2 = AnnData(X=X, var=pd.DataFrame(index=var_names_2))

    ad1_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_1))
    ad2_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_2))

    res_in_memory = concat([ad1, ad2], join=join_type, merge=merge_strategy)
    res_dask = concat([ad1_dask, ad2_dask], join=join_type, merge=merge_strategy)

    assert_equal(res_in_memory, res_dask)

0 comments on commit 96ccce7

Please sign in to comment.