
Change default latent distribution to logit normal #4

Merged (7 commits) on Mar 29, 2024
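
For orientation (not part of the diff below): per the PR title, the latent distribution now defaults to the logistic-normal ("logit normal") `ln` option. A minimal usage sketch, assuming the `SCCORAL` constructor and `setup_anndata` call shown in the updated tests, an already-prepared AnnData object `adata`, and an import path that is not taken from this diff:

```python
from sccoral.model import SCCORAL  # assumed import path; see tests/test_model.py below

# `adata` is assumed to be a prepared AnnData object with the covariate columns used in the tests.
SCCORAL.setup_anndata(
    adata,
    categorical_covariates="categorical_covariate",
    continuous_covariates="continuous_covariate",
)

model = SCCORAL(adata, n_latent=5)  # per the PR title, latent_distribution now defaults to "ln"
model_normal = SCCORAL(adata, n_latent=5, latent_distribution="normal")  # restores the previous default
model.train(max_epochs=20, accelerator="cpu")
```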
433 changes: 0 additions & 433 deletions docs/_static/model.svg

This file was deleted.

1 change: 1 addition & 0 deletions docs/conf.py
@@ -108,6 +108,7 @@
# a list of builtin themes.
#
html_theme = "sphinx_book_theme"
+ html_logo = "_static/img/logo.png"
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]

2 changes: 1 addition & 1 deletion docs/contributing.md
@@ -103,7 +103,7 @@ Please write documentation for new or changed features and use-cases. This proje
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
+ - Citations (like {cite:p}`virshup2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)

See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
6 changes: 4 additions & 2 deletions sccoral/module/_module.py
@@ -200,7 +200,9 @@ def __init__(
n_levels = 1

name = f"encoder_{cat_name}"
- model = LinearEncoder(n_levels, 1, mean_bias=True, var_bias=True)
+ model = LinearEncoder(
+     n_levels, 1, latent_distribution=latent_distribution, mean_bias=True, var_bias=True
+ )

# Register encoder in class
setattr(self, name, model)
@@ -215,7 +217,7 @@ def __init__(
if continuous_names is not None:
for con_name, dim in zip(continuous_names, range(n_latent + n_cat, n_latent + n_cat + n_con)):
name = f"encoder_{con_name}"
- model = LinearEncoder(1, 1)
+ model = LinearEncoder(1, 1, latent_distribution=latent_distribution)

# Register encoder in class
setattr(self, name, model)
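
An informal sketch (not repository code) of what the two edits above achieve: every per-covariate encoder is now constructed with the module-wide `latent_distribution` instead of falling back to `LinearEncoder`'s own default, so its `z_transformation` matches the chosen distribution. The import path is assumed from the file layout in this diff:

```python
import torch.nn as nn
from sccoral.nn import LinearEncoder  # assumed import path (defined in sccoral/nn/_components.py)

latent_distribution = "ln"  # the new default

# A 3-level categorical covariate and a continuous covariate, mirroring the two loops above.
enc_cat = LinearEncoder(3, 1, latent_distribution=latent_distribution, mean_bias=True, var_bias=True)
enc_con = LinearEncoder(1, 1, latent_distribution=latent_distribution)

# With "ln" both encoders apply a softmax to their latent samples;
# with "normal" they would use the identity transformation instead.
assert isinstance(enc_cat.z_transformation, nn.Softmax)
assert isinstance(enc_con.z_transformation, nn.Softmax)
```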
6 changes: 3 additions & 3 deletions sccoral/nn/_components.py
@@ -16,7 +16,7 @@ class LinearEncoder(nn.Module):
Number of input dimensions
n_output
Number of output dimensions
- distributions
+ latent_distribution
Normal distribution `normal` or logit-normal (logistic normal) `ln` (:cite:`Svensson2020`)
return_dist
Whether to return the distribution or samples
@@ -30,7 +30,7 @@ def __init__(
self,
n_input: int,
n_output: int,
- distribution: Literal["ln", "normal"] = "normal",
+ latent_distribution: Literal["ln", "normal"] = "ln",
return_dist: bool = False,
mean_bias: bool = True,
var_bias: bool = True,
@@ -43,7 +43,7 @@ def __init__(

self.var_eps = var_eps

- if distribution == "ln":
+ if latent_distribution == "ln":
self.z_transformation = nn.Softmax(dim=-1)
else:
# Identity function
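
To make the `ln` option concrete, here is a standalone sketch in plain PyTorch (not the repository's forward pass, and the dimensions are made up): a draw from a Normal distribution is pushed through a softmax, which yields a logistic-normal (logit-normal) latent variable on the simplex.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: a batch of 4 cells with 5 latent dimensions.
mu, sigma = torch.zeros(4, 5), torch.ones(4, 5)
z_raw = torch.distributions.Normal(mu, sigma).rsample()  # ordinary Gaussian sample

latent_distribution = "ln"  # the new default
z_transformation = nn.Softmax(dim=-1) if latent_distribution == "ln" else nn.Identity()

z = z_transformation(z_raw)
print(z.sum(dim=-1))  # for "ln", every row sums to 1
```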
6 changes: 3 additions & 3 deletions tests/test_model.py
@@ -78,12 +78,12 @@ def test_pretraining_early_stopping(adata, pretraining_early_stopping, requires_
assert model.module.z_encoder.encoder.fc_layers[0][0].weight.requires_grad == requires_grad


- @pytest.fixture(scope="module")
- def basic_train(adata):
+ @pytest.fixture(scope="module", params=["normal", "ln"])
+ def basic_train(adata, request):
SCCORAL.setup_anndata(
adata, categorical_covariates="categorical_covariate", continuous_covariates="continuous_covariate"
)
- model = SCCORAL(adata, n_latent=5)
+ model = SCCORAL(adata, n_latent=5, latent_distribution=request.param)
model.train(max_epochs=20, accelerator="cpu")

return model
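
A sketch of how the parametrized fixture is consumed: any test that requests `basic_train` now runs once per latent distribution ("normal" and "ln"). The test below is illustrative only and assumes `SCCORAL` exposes the usual scvi-tools `get_latent_representation` method; it is not copied from the repository's test suite.

```python
def test_latent_representation_shape(basic_train, adata):
    # Executed twice: once with latent_distribution="normal", once with "ln".
    z = basic_train.get_latent_representation(adata)
    assert z.shape[0] == adata.n_obs
```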