-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add discrete-continuous samplers (#124)
- Loading branch information
1 parent
8d4800f
commit 970788a
Showing
7 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
from typing import Optional, Sequence, Union | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import random | ||
|
||
from bmi.interface import KeyArray | ||
from bmi.samplers._independent_coordinates import IndependentConcatenationSampler | ||
from bmi.samplers.base import BaseSampler, cast_to_rng | ||
|
||
|
||
def _cite_gao_2017() -> str: | ||
return ( | ||
"@inproceedings{Gao2017-DiscreteContinuousMI,\n" | ||
+ " author = {Gao, Weihao and Kannan, Sreeram and Oh, Sewoong and Viswanath, Pramod},\n" | ||
+ " booktitle = {Advances in Neural Information Processing Systems},\n" | ||
+ " editor = {I. Guyon and U. Von Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},\n" # noqa: E501 | ||
+ " pages = {},\n" | ||
+ " publisher = {Curran Associates, Inc.},\n" | ||
+ " title = {Estimating Mutual Information for Discrete-Continuous Mixtures},\n" | ||
+ " url = {https://proceedings.neurips.cc/paper_files/paper/2017/file/ef72d53990bc4805684c9b61fa64a102-Paper.pdf},\n" # noqa: E501 | ||
+ " volume = {30},\n" | ||
+ " year = {2017}n" | ||
+ "}" | ||
) | ||
|
||
|
||
class DiscreteUniformMixtureSampler(BaseSampler):
    """Sampler from Gao et al. (2017) for the discrete-continuous mixture model:
    $$X \\sim \\mathrm{Categorical}(1/m, ..., 1/m)$$
    is sampled from the set $\\{0, ..., m-1\\}$.
    Then,
    $$Y | X \\sim \\mathrm{Uniform}(X, X+2).$$
    It holds that
    $$I(X; Y) = \\log m - \\frac{m-1}{m} \\log 2.$$
    """

    def __init__(self, *, n_discrete: int = 5, use_discrete_x: bool = True) -> None:
        """
        Args:
            n_discrete: number of discrete values to sample X from
            use_discrete_x: if True, X will be an integer. If False, X will be casted to a float.

        Raises:
            ValueError: if `n_discrete` is not positive.
        """
        super().__init__(dim_x=1, dim_y=1)

        if n_discrete <= 0:
            raise ValueError(f"n_discrete must be positive, was {n_discrete}.")

        self._n_discrete = n_discrete

        # Multiplying the integer samples by 1 preserves the integer dtype,
        # while multiplying by 1.0 casts them to a floating-point dtype.
        if use_discrete_x:
            self._x_factor = 1
        else:
            self._x_factor = 1.0

    def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Sample `n_points` pairs (X, Y), each of shape `(n_points, 1)`.

        Args:
            n_points: number of points to sample
            rng: either an integer seed or a JAX PRNG key
        """
        rng = cast_to_rng(rng)
        key_x, key_y = random.split(rng)
        xs = random.randint(key_x, shape=(n_points, 1), minval=0, maxval=self._n_discrete)
        # Y | X ~ Uniform(X, X + 2)
        uniforms = random.uniform(key_y, shape=(n_points, 1), minval=0.0, maxval=2.0)
        ys = xs + uniforms

        xs = xs * self._x_factor
        return xs, ys

    def mutual_information(self) -> float:
        """Ground-truth mutual information, $\\log m - \\frac{m-1}{m} \\log 2$."""
        m = self._n_discrete
        # BUG FIX: cast to a Python float so the return value matches the
        # annotated type (previously a 0-dim JAX array was returned).
        return float(jnp.log(m) - (m - 1) * jnp.log(2) / m)

    @staticmethod
    def cite() -> str:
        """Returns the BibTeX citation."""
        return _cite_gao_2017()
|
||
|
||
class MultivariateDiscreteUniformMixtureSampler(IndependentConcatenationSampler):
    """Multivariate alternative for `DiscreteUniformMixtureSampler`,
    which is a concatenation of several independent samplers.
    Namely, for a sequence of integers $(m_k)$,
    the variables $X = (X_1, ..., X_k)$ and $Y = (Y_1, ..., Y_k)$ are sampled coordinate-wise.
    Each coordinate
    $$X_k \\sim \\mathrm{Categorical}(1/m_k, ..., 1/m_k)$$
    is from the set $\\{0, ..., m_k-1\\}$.
    Then,
    $$Y_k | X_k \\sim \\mathrm{Uniform}(X_k, X_k + 2).$$
    Mutual information can be calculated as
    $$I(X; Y) = \\sum_k I(X_k; Y_k),$$
    where
    $$I(X_k; Y_k) = \\log m_k - \\frac{m_k-1}{m_k} \\log 2.$$
    """

    def __init__(self, *, ns_discrete: Sequence[int], use_discrete_x: bool = True) -> None:
        """
        Args:
            ns_discrete: support sizes $m_k$, one per coordinate
            use_discrete_x: whether each X coordinate keeps an integer dtype
        """
        # Build one univariate sampler per coordinate and delegate the
        # concatenation (and MI summation) to the parent class.
        component_samplers = [
            DiscreteUniformMixtureSampler(n_discrete=support_size, use_discrete_x=use_discrete_x)
            for support_size in ns_discrete
        ]
        super().__init__(samplers=component_samplers)

    @staticmethod
    def cite() -> str:
        """Returns the BibTeX citation."""
        return _cite_gao_2017()
|
||
|
||
class ZeroInflatedPoissonizationSampler(BaseSampler):
    """Sampler from Gao et al. (2017) modelling zero-inflated Poissonization
    of the exponential distribution.
    Let $X \\sim \\mathrm{Exponential}(1)$. Then, $Y$ is sampled from the mixture distribution
    $$Y \\mid X \\sim p\\, \\delta_0 + (1-p) \\, \\mathrm{Poisson}(X).$$
    """

    def __init__(self, p: float = 0.15, use_discrete_y: bool = True) -> None:
        """
        Args:
            p: zero-inflation parameter. Must be in [0, 1).
            use_discrete_y: if True, Y will be an integer. If False, Y will be casted to a float.

        Raises:
            ValueError: if `p` is outside [0, 1).
        """
        # BUG FIX: the parent constructor was never called, so the `dim_x`/`dim_y`
        # attributes set by BaseSampler were left undefined on instances
        # (breaking e.g. IndependentConcatenationSampler, which reads them).
        super().__init__(dim_x=1, dim_y=1)

        if p < 0 or p >= 1:
            raise ValueError(f"p must be in [0, 1), was {p}.")
        self._p = float(p)

        # Multiplying by 1 keeps the integer dtype; by 1.0 casts to float.
        if use_discrete_y:
            self._y_factor = 1
        else:
            self._y_factor = 1.0

    def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Sample `n_points` pairs (X, Y), each of shape `(n_points, 1)`.

        Args:
            n_points: number of points to sample
            rng: either an integer seed or a JAX PRNG key
        """
        rng = cast_to_rng(rng)
        key_x, key_zeros, key_poisson = random.split(rng, 3)
        xs = random.exponential(key_x, shape=(n_points, 1))
        # With probability p we have 0, with probability 1-p we have 1
        zeros = random.bernoulli(key_zeros, p=1 - self._p, shape=(n_points, 1))
        poissons = random.poisson(key_poisson, lam=xs)
        # Note that this corresponds to a mixture model: with probability p we sample
        # from Dirac delta at 0 and with probability 1-p we sample from Poisson
        ys = zeros * poissons * self._y_factor

        return xs, ys

    def mutual_information(self, truncation: Optional[int] = None) -> float:
        """Ground-truth mutual information is equal to
        $$I(X; Y) = (1-p) \\cdot (2 \\log 2 - \\gamma - S)$$
        where
        $$S = \\sum_{k=1}^{\\infty} \\log k \\cdot 2^{-k},$$
        so that the approximation
        $$I(X; Y) \\approx (1-p) \\cdot 0.3012 $$
        holds.

        Args:
            truncation: if set to None, the above approximation will be used.
                Otherwise, the sum will be truncated at the given value.

        Raises:
            ValueError: if `truncation` is given and not positive.
        """
        # Use an explicit exception instead of `assert`, which is stripped
        # when Python runs with optimizations enabled (-O).
        if truncation is not None and truncation <= 0:
            raise ValueError(f"truncation must be positive, was {truncation}.")

        if truncation is None:
            # Precomputed value of 2*log(2) - euler_gamma - S.
            bracket = 0.3012
        else:
            i_arr = 1.0 * jnp.arange(1, truncation + 1)
            s = jnp.sum(jnp.log(i_arr) * jnp.exp2(-i_arr))
            bracket = 2 * jnp.log(2) - np.euler_gamma - s
            bracket = float(bracket)

        return (1 - self._p) * bracket

    @staticmethod
    def cite() -> str:
        """Returns the BibTeX citation."""
        return _cite_gao_2017()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Sequence, Union | ||
|
||
import jax.numpy as jnp | ||
from jax import random | ||
|
||
from bmi.interface import ISampler, KeyArray | ||
from bmi.samplers.base import BaseSampler, cast_to_rng | ||
|
||
|
||
class IndependentConcatenationSampler(BaseSampler):
    """Consider a sequence of samplers $S_k$, where $k \\in \\{1, \\dotsc, m \\}$
    and variables
    $$(X_k, Y_k) \\sim S_k.$$
    If the variables are sampled independently, we can concatenate them
    to $X = (X_1, \\dotsc, X_m)$ and $Y = (Y_1, \\dotsc, Y_m)$
    and have
    $$I(X; Y) = I(X_1; Y_1) + \\dotsc + I(X_m; Y_m).$$
    """

    def __init__(self, samplers: Sequence[ISampler]) -> None:
        """
        Args:
            samplers: sequence of samplers to concatenate
        """
        self._samplers = list(samplers)
        # The concatenated dimensions are the sums of the component dimensions.
        total_dim_x = sum(component.dim_x for component in self._samplers)
        total_dim_y = sum(component.dim_y for component in self._samplers)
        super().__init__(dim_x=total_dim_x, dim_y=total_dim_y)

    def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Sample each component independently and concatenate coordinate-wise."""
        rng = cast_to_rng(rng)
        # One independent subkey per component sampler.
        subkeys = random.split(rng, len(self._samplers))
        pairs = [
            component.sample(n_points, subkey)
            for subkey, component in zip(subkeys, self._samplers)
        ]
        xs, ys = zip(*pairs)
        return jnp.hstack(xs), jnp.hstack(ys)

    def mutual_information(self) -> float:
        """Sum of the component mutual informations (components are independent)."""
        component_mis = [component.mutual_information() for component in self._samplers]
        return float(jnp.sum(jnp.asarray(component_mis)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
import bmi | ||
from bmi.utils import add_noise | ||
|
||
# isort: off | ||
from bmi.samplers._discrete_continuous import ( | ||
DiscreteUniformMixtureSampler, | ||
MultivariateDiscreteUniformMixtureSampler, | ||
ZeroInflatedPoissonizationSampler, | ||
) | ||
|
||
# isort: on | ||
|
||
|
||
@pytest.mark.parametrize("truncation", [100, 1000])
def test_truncation_is_robust(truncation: int) -> None:
    """The closed-form approximation and the truncated series should agree."""
    sampler = ZeroInflatedPoissonizationSampler()
    mi_approximation = sampler.mutual_information()
    mi_truncated_series = sampler.mutual_information(truncation=truncation)
    assert pytest.approx(mi_approximation, abs=1e-3) == mi_truncated_series
|
||
|
||
def get_samplers():
    """Yield sampler instances covering every discrete-continuous variant."""
    yield from (
        DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=True),
        DiscreteUniformMixtureSampler(n_discrete=3, use_discrete_x=False),
        MultivariateDiscreteUniformMixtureSampler(ns_discrete=[4, 5], use_discrete_x=True),
        ZeroInflatedPoissonizationSampler(p=0),
        ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=True),
        ZeroInflatedPoissonizationSampler(p=0.7, use_discrete_y=False),
    )
|
||
|
||
@pytest.mark.parametrize("sampler", get_samplers())
def test_mutual_information_is_right(sampler) -> None:
    """A KSG estimate on a large sample should match the analytic ground truth."""
    xs, ys = sampler.sample(n_points=7_000, rng=0)

    # Tiny jitter breaks the exact ties in the (partially) discrete
    # coordinates, which the kNN-based estimator cannot handle.
    xs = add_noise(xs, 1e-7)
    ys = add_noise(ys, 1e-7)

    estimator = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(5,))
    mi_estimate = estimator.estimate(xs, ys)

    assert pytest.approx(mi_estimate, abs=0.05, rel=0.05) == sampler.mutual_information()
|
||
|
||
def test_discrete_x():
    """`use_discrete_x` should only change the dtype, never the sampled values."""
    sampler_int = DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=True)
    sampler_float = DiscreteUniformMixtureSampler(n_discrete=5, use_discrete_x=False)

    # Same seed, so both samplers draw the same underlying integers.
    xs_int, _ = sampler_int.sample(n_points=10, rng=0)
    xs_float, _ = sampler_float.sample(n_points=10, rng=0)

    assert (xs_int == xs_float).all()
    assert xs_int.dtype == jnp.int32
    assert xs_float.dtype == jnp.float32
|
||
|
||
def test_discrete_y():
    """`use_discrete_y` should only change the dtype, never the sampled values."""
    sampler_int = ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=True)
    sampler_float = ZeroInflatedPoissonizationSampler(p=0.15, use_discrete_y=False)

    # Same seed, so both samplers draw the same underlying counts.
    _, ys_int = sampler_int.sample(n_points=10, rng=0)
    _, ys_float = sampler_float.sample(n_points=10, rng=0)

    assert (ys_int == ys_float).all()
    assert ys_int.dtype == jnp.int32
    assert ys_float.dtype == jnp.float32