Commit e4083bd
Merge branch 'main' into pawel/docs-fine-distributions
pawel-czyz committed Jan 14, 2024
2 parents d0e8ffd + 62ddea2 commit e4083bd
Showing 17 changed files with 164 additions and 78 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -65,3 +65,10 @@ If you find this code useful in your research, consider citing [our manuscript](
}
```

## List of estimators

- The neighborhood-based KSG estimator proposed in [Estimating Mutual Information](https://arxiv.org/abs/cond-mat/0305641) by Kraskov et al. (2003).
- Donsker-Varadhan and MINE estimators proposed in [MINE: Mutual Information Neural Estimation](https://arxiv.org/abs/1801.04062) by Belghazi et al. (2018).
- InfoNCE estimator proposed in [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748) by van den Oord et al. (2018).
- NWJ estimator proposed in [Estimating divergence functionals and the likelihood ratio by convex risk minimization](https://arxiv.org/abs/0809.0853) by Nguyen et al. (2008).
- Estimator based on canonical correlation analysis described in [Feature discovery under contextual supervision using mutual information](https://ieeexplore.ieee.org/document/227286) by Kay (1992) and in [Some data analyses using mutual information](https://www.jstor.org/stable/43601047) by Brillinger (2004).
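
For context, a minimal sketch of how such an estimator is used (`SplitMultinormal` and `KDEMutualInformationEstimator` appear elsewhere in this commit, but the exact `sample`/`estimate`/`mutual_information` signatures are assumptions, so consult the package docs):

```python
import numpy as np
import bmi

# Bivariate Gaussian with correlation 0.75; the ground-truth MI is
# -0.5 * log(1 - 0.75**2), roughly 0.41 nats.
sampler = bmi.samplers.SplitMultinormal(
    dim_x=1,
    dim_y=1,
    covariance=np.array([[1.0, 0.75], [0.75, 1.0]]),
)
x, y = sampler.sample(n_points=5_000, rng=0)

estimator = bmi.estimators.KDEMutualInformationEstimator()
print(sampler.mutual_information())  # analytic ground truth
print(estimator.estimate(x, y))      # point estimate from samples
```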
25 changes: 25 additions & 0 deletions docs/api/fine-distributions.md
@@ -0,0 +1,25 @@
# Fine distributions

## Core utilities

::: bmi.samplers._tfp.JointDistribution

::: bmi.samplers._tfp.monte_carlo_mi_estimate

::: bmi.samplers._tfp.pmi_profile

::: bmi.samplers._tfp.transform

::: bmi.samplers._tfp.ProductDistribution

::: bmi.samplers._tfp.FineSampler

## Basic distributions

::: bmi.samplers._tfp.construct_multivariate_normal_distribution

::: bmi.samplers._tfp.MultivariateNormalDistribution

::: bmi.samplers._tfp.construct_multivariate_student_distribution

::: bmi.samplers._tfp.MultivariateStudentDistribution
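
Taken together, these utilities compose as in the following sketch (based on the signatures shown in this commit; treat it as illustrative rather than canonical):

```python
import jax
import jax.numpy as jnp

from bmi.samplers import _tfp as fine

# A 2D Gaussian split into one dimension for X and one for Y.
dist = fine.MultivariateNormalDistribution(
    dim_x=1,
    dim_y=1,
    covariance=jnp.array([[1.0, 0.5], [0.5, 1.0]]),
)

# Monte Carlo estimate of I(X; Y) together with its standard error.
key = jax.random.PRNGKey(0)
mi_estimate, std_err = fine.monte_carlo_mi_estimate(key, dist, n=10_000)
```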
21 changes: 17 additions & 4 deletions docs/api/index.md
@@ -1,6 +1,19 @@
# API

-- [Tasks](tasks.md) represent named probability distributions which are used in the benchmark.
-- [Estimators](estimators.md) are the implemented mutual information estimators.
-- [Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.
-- [Interfaces](interfaces.md) defines the main interfaces used in the package.
+## Tasks
+
+[Tasks](tasks.md) represent named probability distributions which are used in the benchmark.
+
+## Estimators
+
+[Estimators](estimators.md) are the implemented mutual information estimators.
+
+## Samplers
+
+[Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.
+
+### Fine distributions
+A [subpackage](fine-distributions.md) implementing distributions in which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.
+
+## Interfaces
+[Interfaces](interfaces.md) defines the main interfaces used in the package.
3 changes: 3 additions & 0 deletions docs/api/interfaces.md
@@ -2,3 +2,6 @@
This section lists the most important interfaces used in the package.

::: bmi.interface.IMutualInformationPointEstimator


::: bmi.interface.ISampler
15 changes: 15 additions & 0 deletions docs/api/samplers.md
@@ -2,6 +2,8 @@

Samplers represent probability distributions with known mutual information.

## Simple distributions

::: bmi.samplers.SplitMultinormal

::: bmi.samplers.SplitStudentT
@@ -10,11 +12,24 @@ Samplers represent probability distributions with known mutual information.

::: bmi.samplers.BivariateNormalSampler

## Combining and transforming samplers

::: bmi.samplers.IndependentConcatenationSampler

::: bmi.samplers.TransformedSampler

## Discrete random variables

::: bmi.samplers.DiscreteUniformMixtureSampler

::: bmi.samplers.MultivariateDiscreteUniformMixtureSampler

::: bmi.samplers.ZeroInflatedPoissonizationSampler

## Fine distributions

See the [fine distributions subpackage API](fine-distributions.md) for more information.

### Auxiliary

::: bmi.samplers.BaseSampler
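
As a hypothetical composition sketch (the constructor signatures here are assumptions, not taken from this commit):

```python
import numpy as np
from bmi import samplers

base = samplers.SplitMultinormal(
    dim_x=1, dim_y=1, covariance=np.array([[1.0, 0.8], [0.8, 1.0]])
)
# Concatenating independent samplers adds their mutual informations,
# since the two (X, Y) blocks are independent of each other.
combined = samplers.IndependentConcatenationSampler(samplers=[base, base])
```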
5 changes: 4 additions & 1 deletion docs/estimators.md
@@ -75,10 +75,13 @@ The API is [here](api/estimators.md).

### How can I add a new estimator?
Thank you for considering contributing to this project! Please, consult [contributing guidelines](contributing.md) and reach out to us on [GitHub](https://github.com/cbg-ethz/bmi/issues), so we can discuss the best way of adding the estimator to the package.

Generally, the following steps are required:

1. Implement the interface [`IMutualInformationPointEstimator`](api/interfaces.md#bmi.interface.IMutualInformationPointEstimator) in a new file inside `src/bmi/estimators` directory. The unit tests should be added in `tests/estimators` directory.
2. Export the new estimator to the public API by adding an entry in `src/bmi/estimators/__init__.py`.
3. Export the docstring of new estimator to `docs/api/estimators.md`.
-4. Add the estimator to the [list of estimators](#list-of-estimators).
+4. Add the estimator to the [list of estimators](#list-of-estimators) and the [README](index.md#list-of-estimators).


\bibliography
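
As a sketch of step 1 above (the interface name comes from this commit, while the assumption that it amounts to a single `estimate(x, y) -> float` method is hypothetical; consult `src/bmi/interface.py` for the actual contract):

```python
import numpy as np
from bmi.interface import IMutualInformationPointEstimator


class GaussianCorrelationEstimator(IMutualInformationPointEstimator):
    """Toy estimator: assumes (X, Y) are jointly Gaussian and one-dimensional,
    so that I(X; Y) = -0.5 * log(1 - rho^2), with rho the Pearson correlation."""

    def estimate(self, x, y) -> float:
        rho = np.corrcoef(np.ravel(x), np.ravel(y))[0, 1]
        return -0.5 * np.log(1.0 - rho**2)
```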
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -5,7 +5,9 @@ theme:
  name: material
  features:
    - navigation.tabs
    - navigation.indexes
    - navigation.sections
    - navigation.top
    - toc.integrate
    - search.suggest
    - search.highlight
@@ -49,6 +51,7 @@ nav:
  - Contributing: contributing.md
  - API: api/index.md


watch:
  - src/bmi
25 changes: 10 additions & 15 deletions src/bmi/estimators/_kde.py
@@ -42,17 +42,12 @@ class DifferentialEntropies(BaseModel):
class KDEMutualInformationEstimator(IMutualInformationPointEstimator):
    """The kernel density mutual information estimator based on

-    .. math::
-
-       I(X; Y) = h(X) + h(Y) - h(X, Y),
-
-    where :math:`h(X)` is the differential entropy
-
-    .. math::
-
-       h(X) = -\\mathbb{E}[ \\log p(X) ].
-
-    The logarithm of probability density function :math:`\\log p(X)`
+    $I(X; Y) = h(X) + h(Y) - h(X, Y)$,
+
+    where $h(X)$ is the differential entropy
+    $h(X) = -\\mathbb{E}[ \\log p(X) ]$.
+
+    The logarithm of probability density function $\\log p(X)$
    is estimated via a kernel density estimator (KDE) using SciKit-Learn.

    Note:
@@ -74,16 +69,16 @@ def __init__(

        Args:
            kernel_xy: kernel to be used for joint distribution
-                PDF :math:`p_{XY}` estimation.
+                PDF $p_{XY}$ estimation.
                See SciKit-Learn's ``KernelDensity`` object for more information.
            kernel_x: kernel to be used for the :math:`p_X` estimation.
                If ``None`` (default), ``kernel_xy`` will be used.
            kernel_y: similarly to ``kernel_x``.
            bandwidth_xy: kernel bandwidth to be used for joint distribution
                estimation.
            bandwidth_x: kernel bandwidth to be used
-                for :math:`p_X` estimation.
+                for $p_X$ estimation.
                If set to None (default), then ``bandwidth_xy`` is used.
            bandwidth_y: similar to ``bandwidth_x``
            standardize: whether to standardize the data points
        """
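
The decomposition $I(X; Y) = h(X) + h(Y) - h(X, Y)$ can be made concrete with scikit-learn directly; a minimal sketch (a resubstitution estimate, not the package's implementation):

```python
import numpy as np
from sklearn.neighbors import KernelDensity


def differential_entropy(samples: np.ndarray, bandwidth: float = 0.5) -> float:
    """Monte Carlo estimate of h = -E[log p] using a Gaussian-kernel KDE."""
    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(samples)
    return -np.mean(kde.score_samples(samples))


rng = np.random.default_rng(0)
x = rng.normal(size=(2_000, 1))
y = 0.8 * x + 0.6 * rng.normal(size=(2_000, 1))  # correlated with x
xy = np.hstack([x, y])

mi = differential_entropy(x) + differential_entropy(y) - differential_entropy(xy)
# True value for correlation 0.8 is -0.5 * log(1 - 0.64), about 0.51 nats.
```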
8 changes: 4 additions & 4 deletions src/bmi/estimators/ksg.py
@@ -50,17 +50,17 @@ def __init__(

        Args:
            neighborhoods: sequence of positive integers,
                specifying the size of neighborhood for MI calculation
            standardize: whether to standardize the data before MI calculation, by default true
            metric_x: metric on the X space
            metric_y: metric on the Y space. If None, `metric_x` will be used
            n_jobs: number of jobs to be launched to compute distances.
                Use -1 to use all processors.
            chunk_size: internal batch size, used to speed up the computations while fitting
                into the memory

        Note:
-            If you use Chebyshev (l-infinity) distance for both X and Y,
+            If you use Chebyshev ($\\ell_\\infty$) distance for both $X$ and $Y$ spaces,
            `KSGChebyshevEstimator` may be faster.
        """
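
Hypothetical usage of the constructor documented above (the keyword names come from the Args section in this hunk; the class name `KSGEnsembleFirstEstimator` is an assumption about the package's public API):

```python
from bmi.estimators import KSGEnsembleFirstEstimator

estimator = KSGEnsembleFirstEstimator(
    neighborhoods=(5, 10),  # ensemble over k = 5 and k = 10 nearest neighbors
    standardize=True,       # z-score the data before distance computations
    n_jobs=-1,              # use all processors to compute distances
)
# mi_estimate = estimator.estimate(x, y)
```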
2 changes: 1 addition & 1 deletion src/bmi/samplers/_split_student_t.py
@@ -118,7 +118,7 @@ def covariance(self) -> np.ndarray:
            array, shape `(dim_x+dim_y, dim_x+dim_y)`

        Raises:
-            ValueError, if covariance is not defined (for `df` $\\le 2$)
+            ValueError: if covariance is not defined (for `df` $\\le 2$)
        """
        if self.df <= 2:
            raise ValueError(
2 changes: 1 addition & 1 deletion src/bmi/samplers/_splitmultinormal.py
@@ -99,7 +99,7 @@ def __init__(
            dim_x: dimension of the X space
            dim_y: dimension of the Y space
            mean: mean vector, shape `(n,)` where `n = dim_x + dim_y`.
                Default: zero vector
            covariance: covariance matrix, shape (n, n)
        """
        super().__init__(dim_x=dim_x, dim_y=dim_y)
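
The analytic mutual information that `SplitMultinormal` relies on is the standard Gaussian identity (a worked check, not code from this repository):

```python
import numpy as np

# I(X; Y) = 0.5 * log( det(cov_x) * det(cov_y) / det(cov) )
cov = np.array([[1.0, 0.6], [0.6, 1.0]])  # n = dim_x + dim_y = 2
cov_x, cov_y = cov[:1, :1], cov[1:, 1:]

mi = 0.5 * np.log(np.linalg.det(cov_x) * np.linalg.det(cov_y) / np.linalg.det(cov))
# Here this equals -0.5 * log(1 - 0.36), approximately 0.223 nats.
```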
39 changes: 20 additions & 19 deletions src/bmi/samplers/_tfp/_core.py
@@ -11,17 +11,18 @@

@dataclasses.dataclass
class JointDistribution:
-    """Represents a joint distribution of X and Y with known marginals.
-    This is the main object of this package.
+    """The main object of this package.
+
+    Represents a joint distribution $P_{XY}$ together
+    with the marginal distributions $P_X$ and $P_Y$.

    Attributes:
-        dist: tfd.Distribution, joint distribution of X and Y
-        dist_x: tfd.Distribution, marginal distribution of X
-        dist_y: tfd.Distribution, marginal distribution of Y
-        dim_x: dimension of the support of X
-        dim_y: dimension of the support of Y
+        dist: $P_{XY}$
+        dist_x: $P_X$
+        dist_y: $P_Y$
+        dim_x: dimension of the support of $X$
+        dim_y: dimension of the support of $Y$
        analytic_mi: analytical mutual information.
            Use `None` if unknown (in most cases)
    """

    dist_joint: tfd.Distribution
@@ -34,7 +35,7 @@ class JointDistribution:
    def sample(
        self, n_points: int, key: jax.random.PRNGKeyArray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
-        """Sample from the joint distribution.
+        """Sample from the joint distribution $P_{XY}$.

        Args:
            n_points: number of samples to draw
@@ -55,7 +56,7 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        Returns:
            pointwise mutual information evaluated at (x, y) points,
                shape `(n_points,)`

        Note:
            This function is vectorized, i.e. it can calculate PMI for multiple points at once.
@@ -75,8 +76,8 @@ def mixture(
    Args:
        proportions: mixture proportions should be positive and sum up to 1,
            shape `(n_components,)`
-        components: sequence of `JointDistribution` objects which will be mixed
+        components: sequence of distributions to be mixed

    Returns:
        mixture distribution
@@ -120,13 +121,13 @@ def transform(
    x_transform: Optional[tfb.Bijector] = None,
    y_transform: Optional[tfb.Bijector] = None,
) -> JointDistribution:
-    """For given diffeomorphisms `f` and `g` transforms the joint distribution P_{XY}
-    into P_{f(X)g(Y)}.
+    """For given diffeomorphisms $f$ and $g$ transforms the joint distribution $P_{XY}$
+    into $P_{f(X)g(Y)}$.

    Args:
        dist: distribution to be transformed
-        x_transform: diffeomorphism to transform X. Defaults to identity.
-        y_transform: diffeomorphism to transform Y. Defaults to identity.
+        x_transform: diffeomorphism $f$ to transform $X$. Defaults to identity.
+        y_transform: diffeomorphism $g$ to transform $Y$. Defaults to identity.

    Returns:
        transformed distribution
@@ -169,11 +170,11 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -
def monte_carlo_mi_estimate(
    key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int
) -> tuple[float, float]:
-    """Estimates the mutual information between X and Y using Monte Carlo sampling.
+    """Estimates the mutual information $I(X; Y)$ using Monte Carlo sampling.

    Returns:
-        float, mutual information estimate
-        float, standard error estimate
+        mutual information estimate
+        standard error estimate

    Note:
        It is worth to run this procedure multiple times and see whether
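
A sketch of `transform` following the signature in this hunk (assuming `tfb` is `tensorflow_probability.substrates.jax.bijectors`, consistent with the type annotations above):

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

from bmi.samplers import _tfp as fine

tfb = tfp.bijectors

base = fine.MultivariateNormalDistribution(
    dim_x=1, dim_y=1, covariance=jnp.array([[1.0, 0.5], [0.5, 1.0]])
)
# Shifting X by a constant is a diffeomorphism, so I(f(X); Y) = I(X; Y).
shifted = fine.transform(base, x_transform=tfb.Shift(jnp.array([1.0])))
```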
12 changes: 9 additions & 3 deletions src/bmi/samplers/_tfp/_normal.py
@@ -14,6 +14,7 @@
def construct_multivariate_normal_distribution(
    mean: jnp.ndarray, covariance: jnp.ndarray
) -> tfd.MultivariateNormalLinearOperator:
+    """Constructs a multivariate normal distribution."""
    # Lower triangular matrix such that `covariance = scale @ scale^T`
    scale = jnp.linalg.cholesky(covariance)
    return tfd.MultivariateNormalLinearOperator(
@@ -23,16 +24,21 @@ def construct_multivariate_normal_distribution(


class MultivariateNormalDistribution(JointDistribution):
+    """Multivariate normal distribution $P_{XY}$,
+    such that $P_X$ is a multivariate normal distribution on the space
+    of dimension `dim_x` and $P_Y$ is a multivariate normal distribution
+    on the space of dimension `dim_y`."""
+
    def __init__(
        self, *, dim_x: int, dim_y: int, covariance: ArrayLike, mean: Optional[ArrayLike] = None
    ) -> None:
        """
        Args:
-            dim_x: dimension of the X space
-            dim_y: dimension of the Y space
+            dim_x: dimension of the $X$ support
+            dim_y: dimension of the $Y$ support
            mean: mean vector, shape `(n,)` where `n = dim_x + dim_y`.
                Default: zero vector
            covariance: covariance matrix, shape (n, n)
        """
        # The default mean vector is zero
19 changes: 9 additions & 10 deletions src/bmi/samplers/_tfp/_product.py
@@ -7,20 +7,19 @@


class ProductDistribution(JointDistribution):
-    """From distributions P_X and P_Y creates a distribution
-
-    P_{XY} = P_X x P_Y
-
-    in which the variables X and Y are independent.
-
-    In particular,
-
-    I(X; Y) = 0
-
-    under this distribution.
+    """From distributions $P_X$ and $P_Y$ creates a distribution $P_{XY} = P_X \\otimes P_Y$,
+    so that $X$ and $Y$ are independent.
+
+    In particular, $I(X; Y) = 0$.
    """

    def __init__(self, dist_x: tfd.Distribution, dist_y: tfd.Distribution) -> None:
+        """Creates a product distribution.
+
+        Args:
+            dist_x: distribution $P_X$
+            dist_y: distribution $P_Y$
+        """
        dims_x = dist_x.event_shape_tensor()
        dims_y = dist_y.event_shape_tensor()
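
A hypothetical sanity check that the product construction carries no mutual information (`tfd` assumed to be `tensorflow_probability.substrates.jax.distributions`):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

from bmi.samplers import _tfp as fine

tfd = tfp.distributions

dist = fine.ProductDistribution(
    dist_x=tfd.MultivariateNormalDiag(loc=jnp.zeros(2)),
    dist_y=tfd.MultivariateNormalDiag(loc=jnp.zeros(3)),
)
mi, std_err = fine.monte_carlo_mi_estimate(jax.random.PRNGKey(0), dist, n=5_000)
# mi should be approximately 0, up to Monte Carlo noise.
```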