Skip to content

Commit

Permalink
Update var_by_distance (#868)
Browse files Browse the repository at this point in the history
* Fix docstring; Update helper function

* Update var_by_distance plot to work with .obs columns

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove print

* Update docstring

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
LLehner and pre-commit-ci[bot] authored Aug 14, 2024
1 parent 93ee854 commit 787b90e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
18 changes: 13 additions & 5 deletions src/squidpy/pl/_var_by_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,6 @@ def var_by_distance(
regplot_kwargs = dict(regplot_kwargs)
scatterplot_kwargs = dict(scatterplot_kwargs)

df = adata.obsm[design_matrix_key] # get design matrix
df[var] = (
np.array(adata[:, var].X.toarray()) if issparse(adata[:, var].X) else np.array(adata[:, var].X)
) # add var column

# if several variables are plotted, make a panel grid
if isinstance(var, list):
fig, grid = _panel_grid(
Expand All @@ -111,6 +106,19 @@ def var_by_distance(
else:
var = [var]

df = adata.obsm[design_matrix_key] # get design matrix

# add var column to design matrix
for name in var:
if name in adata.var_names:
df[name] = (
np.array(adata[:, name].X.toarray()) if issparse(adata[:, name].X) else np.array(adata[:, name].X)
)
elif name in adata.obs:
df[name] = adata.obs[name].values
else:
raise ValueError(f"Variable {name} not found in `adata.var` or `adata.obs`.")

# iterate over the variables to plot
for i, v in enumerate(var):
if len(var) > 1:
Expand Down
13 changes: 6 additions & 7 deletions src/squidpy/tl/_var_by_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def var_by_distance(
----------
%(adata)s
groups
Anchor points to calculate distances from, can be a single gene,
a list of genes or a set of coordinates.
Anchor point(s) to calculate distances to. Can be a single or multiple observations as long as they are annotated in an .obs column with name given by `cluster_key`.
Alternatively, a numpy array of coordinates can be passed.
cluster_key
Annotation column in `.obs` that is used as anchor.
Name of annotation column in `.obs` where the observation used as anchor points are located.
%(library_key)s
design_matrix_key
Name of the design matrix saved to `.obsm`.
Expand Down Expand Up @@ -238,10 +238,9 @@ def _get_coordinates(adata: AnnData, anchor: str, annotation: str, spatial_key:

if isinstance(anchor, np.ndarray):
anchor_coord = anchor[~np.isnan(anchor).any(axis=1)]
return anchor_coord, batch_coord, nan_ids

anchor_arr = np.array(adata[adata.obs[annotation] == anchor].obsm["spatial"])
anchor_coord = anchor_arr[~np.isnan(anchor_arr).any(axis=1)]
else:
anchor_arr = np.array(adata[adata.obs[annotation] == anchor].obsm["spatial"])
anchor_coord = anchor_arr[~np.isnan(anchor_arr).any(axis=1)]
return anchor_coord, batch_coord, nan_ids


Expand Down

0 comments on commit 787b90e

Please sign in to comment.