
Resolve Issue 161 #165

Draft · wants to merge 1 commit into base: main

Conversation

pawel-czyz (Member) commented Jun 28, 2024

This PR aims to resolve #161.

Tasks:

  • Update the JointDistribution class.
  • Update multivariate normal constructor.
  • Update multivariate Student constructor.
  • Update product distribution constructor.
  • Check if BMMSampler has to be adjusted.
  • Update the transform (bending) function.
  • Update the mixture function.
  • Update the tutorial using discrete-continuous variables. Casting the Bernoulli distribution to float is no longer needed.
  • Additionally, check whether we create JointDistribution manually anywhere else, and update those places.
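The dtype point in the tutorial task above can be illustrated with a minimal numpy sketch (hypothetical helper, not the project's actual JointDistribution): once joint samples are tuples, the discrete Bernoulli part can keep its integer dtype instead of being cast to float.

```python
import numpy as np

def sample_discrete_continuous(rng, p, loc, scale, n):
    """Joint sample of X ~ Bernoulli(p) and Y ~ Normal(loc, scale).

    Minimal sketch: with tuple-valued joint samples, the discrete
    component can stay integer-valued; no cast to float is required.
    """
    x = rng.binomial(1, p, size=n)      # integer-valued Bernoulli draws
    y = rng.normal(loc, scale, size=n)  # continuous Normal draws
    return x, y

rng = np.random.default_rng(0)
x, y = sample_discrete_continuous(rng, 0.5, 0.0, 1.0, 4)
```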

Help highly appreciated! 🙂

pawel-czyz (Member, Author) commented Jun 28, 2024

I've got stuck on creating mixtures of joint distributions (whose event shapes, recall, are now tuples/lists rather than single integers):

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([2])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'

To reproduce, use the following code:

import tensorflow_probability as tfp
import jax
import jax.numpy as jnp

tfd = tfp.substrates.jax.distributions
tfb = tfp.substrates.jax.bijectors

key = jax.random.PRNGKey(42)

# This works
mix = tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[tfd.Normal(0.0, 1.0), tfd.Normal(1., 2.)],
)
print("Here it works. This is a sample: ", mix.sample(3, key))


mean = jnp.zeros(5)
covariance_matrix = jnp.eye(5)
dist = tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance_matrix)

split_dist = tfd.TransformedDistribution(
    distribution=dist,
    bijector=tfb.Split((2, 3)),
)

bijectors = [tfb.Exp(), tfb.Sigmoid()]  # Note: must be a list; a tuple doesn't work

bij = tfb.JointMap(bijectors)
tr_dist = tfd.TransformedDistribution(distribution=split_dist, bijector=bij)

print("There is an error:")
tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[split_dist, tr_dist],
)

I'm feeling a bit lost here. I've submitted an issue to the TFP repository.

Perhaps a workaround would be to define a custom mixture distribution that performs fewer checks at initialisation – in our case the allowed array shapes can be made stricter.
Some links which seem useful if we are going to use this approach:
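As a starting point, such a custom mixture could be sketched in plain numpy (a hypothetical function, assuming each component is a callable returning a tuple of arrays, mirroring the tuple-valued samples of split distributions):

```python
import numpy as np

def sample_tuple_mixture(rng, probs, samplers, n):
    """Sample from a mixture whose components produce tuples of arrays.

    Sketch of the proposed workaround: unlike tfd.Mixture, no event-shape
    checks are performed at initialisation; we simply pick a component
    index per sample and gather the corresponding draws.
    """
    idx = rng.choice(len(probs), size=n, p=probs)  # component index per sample
    draws = [s(rng, n) for s in samplers]          # n draws from each component
    parts = []
    for part in range(len(draws[0])):              # iterate over tuple slots
        stacked = np.stack([d[part] for d in draws])  # shape (K, n, ...)
        parts.append(stacked[idx, np.arange(n)])      # per-sample selection
    return tuple(parts)

# Example: two components, each returning a (2-dim, 3-dim) pair of arrays.
rng = np.random.default_rng(0)
comp0 = lambda rng, n: (np.zeros((n, 2)), np.zeros((n, 3)))
comp1 = lambda rng, n: (np.ones((n, 2)), np.ones((n, 3)))
a, b = sample_tuple_mixture(rng, [1.0, 0.0], [comp0, comp1], 5)
```

A real implementation would additionally handle log-probabilities and batch shapes, but the point is that the shape validation can be tailored to our tuple-valued case.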

Successfully merging this pull request may close these issues.

JointDistribution wraps and unwraps X and Y