Add discrete-continuous samplers (#124)
pawel-czyz authored Oct 7, 2023
1 parent 8d4800f commit 970788a
Showing 7 changed files with 350 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/api/samplers.md
@@ -10,3 +10,11 @@ Samplers represent probability distributions with known mutual information.

::: bmi.samplers.BivariateNormalSampler

::: bmi.samplers.IndependentConcatenationSampler

::: bmi.samplers.DiscreteUniformMixtureSampler

::: bmi.samplers.MultivariateDiscreteUniformMixtureSampler

::: bmi.samplers.ZeroInflatedPoissonizationSampler

2 changes: 2 additions & 0 deletions src/bmi/__init__.py
@@ -3,6 +3,7 @@
import bmi.estimators as estimators
import bmi.samplers as samplers
import bmi.transforms as transforms
import bmi.utils as utils

# ISort doesn't want to split these into several lines, conflicting with Black
# isort: off
@@ -31,6 +32,7 @@
"estimators",
"samplers",
"transforms",
"utils",
"IMutualInformationPointEstimator",
"ISampler",
"Pathlike",
14 changes: 14 additions & 0 deletions src/bmi/samplers/__init__.py
@@ -13,11 +13,21 @@

# isort: on
import bmi.samplers._tfp as fine
from bmi.samplers._independent_coordinates import IndependentConcatenationSampler
from bmi.samplers._split_student_t import SplitStudentT
from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
from bmi.samplers._transformed import TransformedSampler
from bmi.samplers.base import BaseSampler

# isort: off
from bmi.samplers._discrete_continuous import (
DiscreteUniformMixtureSampler,
MultivariateDiscreteUniformMixtureSampler,
ZeroInflatedPoissonizationSampler,
)

# isort: on

__all__ = [
"AdditiveUniformSampler",
"BaseSampler",
@@ -31,4 +41,8 @@
"DenseLVMParametrization",
"SparseLVMParametrization",
"GaussianLVMParametrization",
"IndependentConcatenationSampler",
"DiscreteUniformMixtureSampler",
"MultivariateDiscreteUniformMixtureSampler",
"ZeroInflatedPoissonizationSampler",
]
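With these re-exports in place, the new samplers are importable directly from `bmi.samplers`. A minimal usage sketch (illustrative, not part of the commit), assuming the package is installed as `bmi`:

```python
import bmi

# The sampler is re-exported at the package level by this change.
sampler = bmi.samplers.DiscreteUniformMixtureSampler(n_discrete=5)

# Closed-form ground truth: log(5) - (4/5) * log(2), roughly 1.055 nats.
print(float(sampler.mutual_information()))
```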
192 changes: 192 additions & 0 deletions src/bmi/samplers/_discrete_continuous.py
@@ -0,0 +1,192 @@
from typing import Optional, Sequence, Union

import jax.numpy as jnp
import numpy as np
from jax import random

from bmi.interface import KeyArray
from bmi.samplers._independent_coordinates import IndependentConcatenationSampler
from bmi.samplers.base import BaseSampler, cast_to_rng


def _cite_gao_2017() -> str:
return (
"@inproceedings{Gao2017-DiscreteContinuousMI,\n"
+ " author = {Gao, Weihao and Kannan, Sreeram and Oh, Sewoong and Viswanath, Pramod},\n"
+ " booktitle = {Advances in Neural Information Processing Systems},\n"
+ " editor = {I. Guyon and U. Von Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},\n" # noqa: E501
+ " pages = {},\n"
+ " publisher = {Curran Associates, Inc.},\n"
+ " title = {Estimating Mutual Information for Discrete-Continuous Mixtures},\n"
+ " url = {https://proceedings.neurips.cc/paper_files/paper/2017/file/ef72d53990bc4805684c9b61fa64a102-Paper.pdf},\n" # noqa: E501
+ " volume = {30},\n"
+ " year = {2017}n"
+ "}"
)


class DiscreteUniformMixtureSampler(BaseSampler):
"""Sampler from Gao et al. (2017) for the discrete-continuous mixture model:
$$X \\sim \\mathrm{Categorical}(1/m, ..., 1/m)$$
is sampled from the set $\\{0, ..., m-1\\}$.
Then,
$$Y | X \\sim \\mathrm{Uniform}(X, X+2).$$
It holds that
$$I(X; Y) = \\log m - \\frac{m-1}{m} \\log 2.$$
"""

def __init__(self, *, n_discrete: int = 5, use_discrete_x: bool = True) -> None:
"""
Args:
n_discrete: number of discrete values to sample X from
use_discrete_x: if True, X will be an integer. If False, X will be cast to a float.
"""
super().__init__(dim_x=1, dim_y=1)

if n_discrete <= 0:
raise ValueError(f"n_discrete must be positive, was {n_discrete}.")

self._n_discrete = n_discrete

if use_discrete_x:
self._x_factor = 1
else:
self._x_factor = 1.0

def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
rng = cast_to_rng(rng)
key_x, key_y = random.split(rng)
xs = random.randint(key_x, shape=(n_points, 1), minval=0, maxval=self._n_discrete)
uniforms = random.uniform(key_y, shape=(n_points, 1), minval=0.0, maxval=2.0)
ys = xs + uniforms

xs = xs * self._x_factor
return xs, ys

def mutual_information(self) -> float:
m = self._n_discrete
return jnp.log(m) - (m - 1) * jnp.log(2) / m

@staticmethod
def cite() -> str:
"""Returns the BibTeX citation."""
return _cite_gao_2017()


class MultivariateDiscreteUniformMixtureSampler(IndependentConcatenationSampler):
"""Multivariate alternative for `DiscreteUniformMixtureSampler`,
which is a concatenation of several independent samplers.
Namely, for a sequence of integers $(m_k)$,
the variables $X = (X_1, ..., X_k)$ and $Y = (Y_1, ..., Y_k)$ are sampled coordinate-wise.
Each coordinate
$$X_k \\sim \\mathrm{Categorical}(1/m_k, ..., 1/m_k)$$
is from the set $\\{0, ..., m_k-1\\}$.
Then,
$$Y_k | X_k \\sim \\mathrm{Uniform}(X_k, X_k + 2).$$
Mutual information can be calculated as
$$I(X; Y) = \\sum_k I(X_k; Y_k),$$
where
$$I(X_k; Y_k) = \\log m_k - \\frac{m_k-1}{m_k} \\log 2.$$
"""

def __init__(self, *, ns_discrete: Sequence[int], use_discrete_x: bool = True) -> None:
samplers = [
DiscreteUniformMixtureSampler(n_discrete=n, use_discrete_x=use_discrete_x)
for n in ns_discrete
]
super().__init__(samplers=samplers)

@staticmethod
def cite() -> str:
"""Returns the BibTeX citation."""
return _cite_gao_2017()


class ZeroInflatedPoissonizationSampler(BaseSampler):
"""Sampler from Gao et al. (2017) modelling zero-inflated Poissonization
of the exponential distribution.
Let $X \\sim \\mathrm{Exponential}(1)$. Then, $Y$ is sampled from the mixture distribution
$$Y \\mid X \\sim p\\, \\delta_0 + (1-p) \\, \\mathrm{Poisson}(X).$$
"""

def __init__(self, p: float = 0.15, use_discrete_y: bool = True) -> None:
"""
Args:
p: zero-inflation parameter. Must be in [0, 1).
use_discrete_y: if True, Y will be an integer. If False, Y will be cast to a float.
"""
super().__init__(dim_x=1, dim_y=1)

if p < 0 or p >= 1:
raise ValueError(f"p must be in [0, 1), was {p}.")
self._p = float(p)

if use_discrete_y:
self._y_factor = 1
else:
self._y_factor = 1.0

def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
rng = cast_to_rng(rng)
key_x, key_zeros, key_poisson = random.split(rng, 3)
xs = random.exponential(key_x, shape=(n_points, 1))
# With probability p we have 0, with probability 1-p we have 1
zeros = random.bernoulli(key_zeros, p=1 - self._p, shape=(n_points, 1))
poissons = random.poisson(key_poisson, lam=xs)
# Note that this corresponds to a mixture model: with probability p we sample
# from Dirac delta at 0 and with probability 1-p we sample from Poisson
ys = zeros * poissons * self._y_factor

return xs, ys

def mutual_information(self, truncation: Optional[int] = None) -> float:
"""Ground-truth mutual information is equal to
$$I(X; Y) = (1-p) \\cdot (2 \\log 2 - \\gamma - S)$$
where
$$S = \\sum_{k=1}^{\\infty} \\log k \\cdot 2^{-k},$$
so that the approximation
$$I(X; Y) \\approx (1-p) \\cdot 0.3012 $$
holds.
Args:
truncation: if set to None, the above approximation will be used.
Otherwise, the sum will be truncated at the given value.
"""
assert truncation is None or truncation > 0

if truncation is None:
bracket = 0.3012
else:
i_arr = 1.0 * jnp.arange(1, truncation + 1)
s = jnp.sum(jnp.log(i_arr) * jnp.exp2(-i_arr))
bracket = 2 * jnp.log(2) - np.euler_gamma - s
bracket = float(bracket)

return (1 - self._p) * bracket

@staticmethod
def cite() -> str:
"""Returns the BibTeX citation."""
return _cite_gao_2017()
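To illustrate the new module, a short sketch (assuming the re-exports above are in place) comparing the hard-coded series constant against an explicit truncation:

```python
from bmi.samplers import ZeroInflatedPoissonizationSampler

sampler = ZeroInflatedPoissonizationSampler(p=0.15)

# `rng` accepts either an integer seed or a JAX PRNG key.
xs, ys = sampler.sample(n_points=1_000, rng=42)
print(xs.shape, ys.shape)  # (1000, 1) (1000, 1)

# The default path uses the cached constant 0.3012; truncation sums the series.
print(sampler.mutual_information())                  # about (1 - 0.15) * 0.3012 = 0.256
print(sampler.mutual_information(truncation=1_000))  # should agree to ~1e-3
```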
50 changes: 50 additions & 0 deletions src/bmi/samplers/_independent_coordinates.py
@@ -0,0 +1,50 @@
from typing import Sequence, Union

import jax.numpy as jnp
from jax import random

from bmi.interface import ISampler, KeyArray
from bmi.samplers.base import BaseSampler, cast_to_rng


class IndependentConcatenationSampler(BaseSampler):
"""Consider a sequence of samplers $S_k$, where $k \\in \\{1, \\dotsc, m \\}$
and variables
$$(X_k, Y_k) \\sim S_k.$$
If the variables are sampled independently, we can concatenate them
to $X = (X_1, \\dotsc, X_m)$ and $Y = (Y_1, \\dotsc, Y_m)$
and have
$$I(X; Y) = I(X_1; Y_1) + \\dotsc + I(X_m; Y_m).$$
"""

def __init__(self, samplers: Sequence[ISampler]) -> None:
"""
Args:
samplers: sequence of samplers to concatenate
"""
self._samplers = list(samplers)
dim_x = sum(sampler.dim_x for sampler in self._samplers)
dim_y = sum(sampler.dim_y for sampler in self._samplers)

super().__init__(dim_x=dim_x, dim_y=dim_y)

def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
rng = cast_to_rng(rng)
keys = random.split(rng, len(self._samplers))
xs = []
ys = []
for key, sampler in zip(keys, self._samplers):
x, y = sampler.sample(n_points, key)
xs.append(x)
ys.append(y)

return jnp.hstack(xs), jnp.hstack(ys)

def mutual_information(self) -> float:
return float(
jnp.sum(jnp.asarray([sampler.mutual_information() for sampler in self._samplers]))
)
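A sketch of the concatenation mechanism using two of the new mixture samplers; `MultivariateDiscreteUniformMixtureSampler(ns_discrete=[3, 5])` is the equivalent shorthand defined above:

```python
from bmi.samplers import (
    DiscreteUniformMixtureSampler,
    IndependentConcatenationSampler,
)

s1 = DiscreteUniformMixtureSampler(n_discrete=3)
s2 = DiscreteUniformMixtureSampler(n_discrete=5)
joint = IndependentConcatenationSampler(samplers=[s1, s2])

# Coordinates are stacked horizontally, so dim_x == dim_y == 2 here.
xs, ys = joint.sample(n_points=100, rng=0)
print(xs.shape, ys.shape)  # (100, 2) (100, 2)

# Mutual information adds across the independent blocks.
expected = float(s1.mutual_information()) + float(s2.mutual_information())
assert abs(joint.mutual_information() - expected) < 1e-6
```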
15 changes: 15 additions & 0 deletions src/bmi/utils.py
@@ -15,6 +15,21 @@
from bmi.interface import Pathlike


def add_noise(points: ArrayLike, noise_std: float = 1e-5, rng_key: int = 0) -> np.ndarray:
"""Adds small noise.
Function useful when discrete random variables are involved.
Args:
points: array of points, shape (n_points, dim)
noise_std: standard deviation of the noise
rng_key: random number generator seed, used for reproducibility
"""
rng = np.random.default_rng(rng_key)
points = np.asarray(points)
noise = rng.normal(scale=noise_std, size=points.shape)
return points + noise


def save_sample(path: Pathlike, samples_x: ArrayLike, samples_y: ArrayLike):
samples_x = np.asarray(samples_x)
samples_y = np.asarray(samples_y)
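The helper exists because discrete samples contain exact ties, which break nearest-neighbor estimators such as KSG; a tiny jitter removes them. A small sketch:

```python
import numpy as np

from bmi.utils import add_noise

xs = np.array([[0.0], [0.0], [1.0], [1.0], [2.0]])

# Ties make nearest-neighbor distances zero; jittering breaks them.
xs_jittered = add_noise(xs, noise_std=1e-5, rng_key=0)
print(len(np.unique(xs_jittered)))  # 5 distinct values
```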
69 changes: 69 additions & 0 deletions tests/samplers/test_discrete_continuous.py
@@ -0,0 +1,69 @@
import jax.numpy as jnp
import pytest

import bmi
from bmi.utils import add_noise

# isort: off
from bmi.samplers._discrete_continuous import (
DiscreteUniformMixtureSampler,
MultivariateDiscreteUniformMixtureSampler,
ZeroInflatedPoissonizationSampler,
)

# isort: on


@pytest.mark.parametrize("truncation", [100, 1000])
def test_truncation_is_robust(truncation: int) -> None:
sampler = ZeroInflatedPoissonizationSampler()
approximation_default = sampler.mutual_information()
approximation_truncated = sampler.mutual_information(truncation=truncation)

assert pytest.approx(approximation_default, abs=1e-3) == approximation_truncated


def get_samplers():
yield DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=True)
yield DiscreteUniformMixtureSampler(n_discrete=3, use_discrete_x=False)
yield MultivariateDiscreteUniformMixtureSampler(ns_discrete=[4, 5], use_discrete_x=True)
yield ZeroInflatedPoissonizationSampler(p=0)
yield ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=True)
yield ZeroInflatedPoissonizationSampler(p=0.7, use_discrete_y=False)


@pytest.mark.parametrize("sampler", get_samplers())
def test_mutual_information_is_right(sampler) -> None:
xs, ys = sampler.sample(n_points=7_000, rng=0)

xs = add_noise(xs, 1e-7)
ys = add_noise(ys, 1e-7)

mi_estimate = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(5,)).estimate(xs, ys)
mi_ground_truth = sampler.mutual_information()

assert pytest.approx(mi_estimate, abs=0.05, rel=0.05) == mi_ground_truth


def test_discrete_x():
discrete_sampler = DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=True)
continuous_sampler = DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=False)

xs_discrete, _ = discrete_sampler.sample(n_points=10, rng=0)
xs_continuous, _ = continuous_sampler.sample(n_points=10, rng=0)

assert (xs_discrete == xs_continuous).all()
assert xs_discrete.dtype == jnp.int32
assert xs_continuous.dtype == jnp.float32


def test_discrete_y():
discrete_sampler = ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=True)
continuous_sampler = ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=False)

_, ys_discrete = discrete_sampler.sample(n_points=10, rng=0)
_, ys_continuous = continuous_sampler.sample(n_points=10, rng=0)

assert (ys_discrete == ys_continuous).all()
assert ys_discrete.dtype == jnp.int32
assert ys_continuous.dtype == jnp.float32
