Commit

Merge pull request #14 from lucas-diedrich/latent-distribution
[REVERSE] Reverse to logistic-normal distribution in latent space
lucas-diedrich authored May 17, 2024
2 parents 78c5144 + 58460b1 commit 496b4d7
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions sccoral/module/_module.py
@@ -171,7 +171,7 @@ def __init__(
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
distribution=latent_distribution,
distribution="normal",
use_batch_norm=self.use_batch_norm_encoder,
use_layer_norm=self.use_layer_norm_encoder,
return_dist=False,
@@ -201,7 +201,9 @@ def __init__(

name = f"encoder_{cat_name}"

model = LinearEncoder(n_levels, 1, latent_distribution="normal", mean_bias=True, var_bias=True)
model = LinearEncoder(
n_levels, 1, latent_distribution="normal", mean_bias=True, var_bias=True
)

# Register encoder in class
setattr(self, name, model)
@@ -336,17 +338,16 @@ def inference(self, x, batch_index, continuous_covariates, categorical_covariate
var_z = torch.cat([var_counts, *var_ca, *var_cc], dim=1)
z = torch.cat([latent_counts, *latent_ca, *latent_cc], dim=1)

if self.latent_distribution == "ln":
z = F.softmax(z, dim=-1)

qz = Normal(loc=mean_z, scale=torch.sqrt(var_z))

if n_samples > 1:
# Sample n samples from normal distribution
# if logistic normal, transform z
# if lognormal, transform z
z_untransformed = qz.sample((n_samples,))
z_untransformed_counts = z_untransformed[:, : len(mean_z)]
z_covariates = z_untransformed[:, len(mean_z) :]

z_counts = self.z_encoder.z_transformation(z_untransformed_counts)
z = torch.cat([z_counts, z_covariates], dim=1)
z = self.z_encoder.z_transformation(z_untransformed)

if self.use_observed_lib_size:
library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
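For context on the `@@ -201,7 +201,9 @@` hunk: the model builds one small linear encoder per categorical covariate and attaches it to the module under a predictable attribute name. The sketch below illustrates that `setattr`/`getattr` registration pattern with a hypothetical `ToyLinearEncoder` and made-up covariate names; it is not sccoral's `LinearEncoder`, just a minimal stand-in.

```python
import torch
from torch import nn


class ToyLinearEncoder(nn.Module):
    """Hypothetical stand-in for sccoral's LinearEncoder (illustration only)."""

    def __init__(self, n_input: int, n_latent: int = 1):
        super().__init__()
        self.mean = nn.Linear(n_input, n_latent)
        self.log_var = nn.Linear(n_input, n_latent)

    def forward(self, x: torch.Tensor):
        # Return a mean and a positive variance for a one-dimensional latent factor.
        return self.mean(x), self.log_var(x).exp()


class CovariateEncoders(nn.Module):
    def __init__(self, categorical_levels: dict):
        super().__init__()
        for cat_name, n_levels in categorical_levels.items():
            name = f"encoder_{cat_name}"
            # Register encoder in class: nn.Module records it as a submodule,
            # so its parameters train with the rest of the model.
            setattr(self, name, ToyLinearEncoder(n_levels, 1))

    def encode(self, cat_name: str, one_hot: torch.Tensor):
        return getattr(self, f"encoder_{cat_name}")(one_hot)


model = CovariateEncoders({"batch": 3, "donor": 5})
mean, var = model.encode("batch", torch.eye(3))  # one latent value per level
print(mean.shape, var.shape)  # torch.Size([3, 1]) torch.Size([3, 1])
```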
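The `@@ -336,17 +338,16 @@` hunk is where the restored logistic-normal ("ln") behavior lives: the latent code is sampled from a Normal posterior and, when `latent_distribution == "ln"`, pushed through a softmax so each cell's latent vector lies on the probability simplex. Below is a minimal sketch of that transform, assuming PyTorch and placeholder shapes rather than sccoral's actual `inference` method.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal

# Placeholder posterior parameters (4 cells, 10 latent dimensions); in the diff
# these correspond to the concatenated mean_z / var_z tensors.
mean_z = torch.zeros(4, 10)
var_z = torch.ones(4, 10)

qz = Normal(loc=mean_z, scale=torch.sqrt(var_z))  # variational posterior
z_untransformed = qz.rsample()                    # reparameterized sample

latent_distribution = "ln"
if latent_distribution == "ln":
    # Logistic-normal: a softmax maps the Gaussian sample onto the simplex
    # (non-negative entries that sum to 1 for each cell).
    z = F.softmax(z_untransformed, dim=-1)
else:
    z = z_untransformed

print(torch.allclose(z.sum(dim=-1), torch.ones(4)))  # True for "ln"
```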
