diff --git a/README.md b/README.md
index 790ed165..8829f510 100644
--- a/README.md
+++ b/README.md
@@ -65,3 +65,10 @@ If you find this code useful in your research, consider citing [our manuscript](
 }
 ```
 
+## List of estimators
+
+- The neighborhood-based KSG estimator proposed in [Estimating Mutual Information](https://arxiv.org/abs/cond-mat/0305641) by Kraskov et al. (2003).
+- Donsker-Varadhan and MINE estimators proposed in [MINE: Mutual Information Neural Estimation](https://arxiv.org/abs/1801.04062) by Belghazi et al. (2018).
+- InfoNCE estimator proposed in [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748) by van den Oord et al. (2018).
+- NWJ estimator proposed in [Estimating divergence functionals and the likelihood ratio by convex risk minimization](https://arxiv.org/abs/0809.0853) by Nguyen et al. (2008).
+- Estimator based on canonical correlation analysis, described in [Feature discovery under contextual supervision using mutual information](https://ieeexplore.ieee.org/document/227286) by Kay (1992) and in [Some data analyses using mutual information](https://www.jstor.org/stable/43601047) by Brillinger (2004).
diff --git a/docs/api/fine-distributions.md b/docs/api/fine-distributions.md
new file mode 100644
index 00000000..fca737c0
--- /dev/null
+++ b/docs/api/fine-distributions.md
@@ -0,0 +1,25 @@
+# Fine distributions
+
+## Core utilities
+
+::: bmi.samplers._tfp.JointDistribution
+
+::: bmi.samplers._tfp.monte_carlo_mi_estimate
+
+::: bmi.samplers._tfp.pmi_profile
+
+::: bmi.samplers._tfp.transform
+
+::: bmi.samplers._tfp.ProductDistribution
+
+::: bmi.samplers._tfp.FineSampler
+
+## Basic distributions
+
+::: bmi.samplers._tfp.construct_multivariate_normal_distribution
+
+::: bmi.samplers._tfp.MultivariateNormalDistribution
+
+::: bmi.samplers._tfp.construct_multivariate_student_distribution
+
+::: bmi.samplers._tfp.MultivariateStudentDistribution
\ No newline at end of file
diff --git a/docs/api/index.md b/docs/api/index.md
index f63dccc3..6ccbc15c 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -1,6 +1,19 @@
 # API
 
-- [Tasks](tasks.md) represent named probability distributions which are used in the benchmark.
-- [Estimators](estimators.md) are the implemented mutual information estimators.
-- [Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.
-- [Interfaces](interfaces.md) defines the main interfaces used in the package.
+## Tasks
+
+[Tasks](tasks.md) represent named probability distributions which are used in the benchmark.
+
+## Estimators
+
+[Estimators](estimators.md) are the implemented mutual information estimators.
+
+## Samplers
+
+[Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.
+
+### Fine distributions
+The [fine distributions subpackage](fine-distributions.md) implements distributions in which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.
+
+## Interfaces
+The [interfaces page](interfaces.md) defines the main interfaces used in the package.
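To make the fine-distributions workflow documented above concrete, here is a minimal sketch: it builds two independent Gaussian marginals, combines them with `ProductDistribution`, and approximates the mutual information with `monte_carlo_mi_estimate`. The import path `bmi.samplers._tfp` and the call signatures follow the API entries and docstrings in this changeset; the dimensions, PRNG seed, and sample size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Assumed import path, following the `bmi.samplers._tfp.*` API entries above.
from bmi.samplers import _tfp

# Independent Gaussian marginals: X is 2-dimensional, Y is 1-dimensional.
dist_x = _tfp.construct_multivariate_normal_distribution(
    mean=jnp.zeros(2), covariance=jnp.eye(2)
)
dist_y = _tfp.construct_multivariate_normal_distribution(
    mean=jnp.zeros(1), covariance=jnp.eye(1)
)

# The product distribution of P_X and P_Y, under which I(X; Y) = 0.
joint = _tfp.ProductDistribution(dist_x, dist_y)

# Monte Carlo estimate of the mutual information together with a standard
# error; for this product distribution it should be close to zero.
key = jax.random.PRNGKey(0)
mi_estimate, std_err = _tfp.monte_carlo_mi_estimate(key=key, dist=joint, n=10_000)
print(mi_estimate, std_err)
```

As the `monte_carlo_mi_estimate` docstring below advises, it is worth repeating this with several keys to check the stability of the estimate.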
diff --git a/docs/api/interfaces.md b/docs/api/interfaces.md
index 33a254a1..8808e1dc 100644
--- a/docs/api/interfaces.md
+++ b/docs/api/interfaces.md
@@ -2,3 +2,6 @@
 This section lists the most important interfaces used in the package.
 
 ::: bmi.interface.IMutualInformationPointEstimator
+
+
+::: bmi.interface.ISampler
\ No newline at end of file
diff --git a/docs/api/samplers.md b/docs/api/samplers.md
index 08fe72b2..10c001e9 100644
--- a/docs/api/samplers.md
+++ b/docs/api/samplers.md
@@ -2,6 +2,8 @@
 
 Samplers represent probability distributions with known mutual information.
 
+## Simple distributions
+
 ::: bmi.samplers.SplitMultinormal
 
 ::: bmi.samplers.SplitStudentT
@@ -10,11 +12,24 @@
 
 ::: bmi.samplers.BivariateNormalSampler
 
+## Combining and transforming samplers
+
 ::: bmi.samplers.IndependentConcatenationSampler
 
+::: bmi.samplers.TransformedSampler
+
+## Discrete random variables
+
 ::: bmi.samplers.DiscreteUniformMixtureSampler
 
 ::: bmi.samplers.MultivariateDiscreteUniformMixtureSampler
 
 ::: bmi.samplers.ZeroInflatedPoissonizationSampler
+
+## Fine distributions
+
+See the [fine distributions subpackage API](fine-distributions.md) for more information.
+
+### Auxiliary
+
+::: bmi.samplers.BaseSampler
diff --git a/docs/estimators.md b/docs/estimators.md
index f4859814..2c4a36eb 100644
--- a/docs/estimators.md
+++ b/docs/estimators.md
@@ -75,10 +75,13 @@ The API is [here](api/estimators.md).
 ### How can I add a new estimator?
 
 Thank you for considering contributing to this project! Please, consult [contributing guidelines](contributing.md) and reach out to us on [GitHub](https://github.com/cbg-ethz/bmi/issues), so we can discuss the best way of adding the estimator to the package.
+
 Generally, the following steps are required:
+
 1. Implement the interface [`IMutualInformationPointEstimator`](api/interfaces.md#bmi.interface.IMutualInformationPointEstimator) in a new file inside `src/bmi/estimators` directory. The unit tests should be added in `tests/estimators` directory.
 2. Export the new estimator to the public API by adding an entry in `src/bmi/estimators/__init__.py`.
 3. Export the docstring of new estimator to `docs/api/estimators.md`.
-4. Add the estimator to the [list of estimators](#list-of-estimators).
+4. Add the estimator to the [list of estimators](#list-of-estimators) and to the [README](index.md#list-of-estimators).
+
 
 \bibliography
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 3c3d54b2..85b1c694 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -5,7 +5,9 @@ theme:
   name: material
   features:
     - navigation.tabs
+    - navigation.indexes
     - navigation.sections
+    - navigation.top
     - toc.integrate
     - search.suggest
     - search.highlight
@@ -49,6 +51,7 @@ nav:
   - Contributing: contributing.md
   - API: api/index.md
 
+
 watch:
   - src/bmi
diff --git a/src/bmi/estimators/_kde.py b/src/bmi/estimators/_kde.py
index e7b5abb8..088d7cc6 100644
--- a/src/bmi/estimators/_kde.py
+++ b/src/bmi/estimators/_kde.py
@@ -42,17 +42,12 @@ class DifferentialEntropies(BaseModel):
 
 class KDEMutualInformationEstimator(IMutualInformationPointEstimator):
     """The kernel density mutual information estimator based on
-    .. math::
+    $I(X; Y) = h(X) + h(Y) - h(X, Y)$,
 
-       I(X; Y) = h(X) + h(Y) - h(X, Y),
+    where $h(X)$ is the differential entropy
+    $h(X) = -\\mathbb{E}[ \\log p(X) ]$.
 
-
-    where :math:`h(X)` is the differential entropy
-
-    .. math::
-
-       h(X) = -\\mathbb{E}[ \\log p(X) ].
-
-    The logarithm of probability density function :math:`\\log p(X)`
+    The logarithm of the probability density function $\\log p(X)$
     is estimated via a kernel density estimator (KDE) using SciKit-Learn.
 
     Note:
@@ -74,16 +69,16 @@ def __init__(
 
         Args:
             kernel_xy: kernel to be used for joint distribution
-                PDF :math:`p_{XY}` estimation.
-            See SciKit-Learn's ``KernelDensity`` object for more information.
+                PDF $p_{XY}$ estimation.
+                See SciKit-Learn's ``KernelDensity`` object for more information.
-            kernel_x: kernel to be used for the :math:`p_X` estimation.
-            If ``None`` (default), ``kernel_xy`` will be used.
+            kernel_x: kernel to be used for the $p_X$ estimation.
+                If ``None`` (default), ``kernel_xy`` will be used.
             kernel_y: similarly to ``kernel_x``.
             bandwidth_xy: kernel bandwidth to be used for joint distribution
-            estimation.
+                estimation.
             bandwidth_x: kernel bandwidth to be used
-            for :math:`p_X` estimation.
-            If set to None (default), then ``bandwidth_xy`` is used.
+                for $p_X$ estimation.
+                If set to None (default), then ``bandwidth_xy`` is used.
             bandwidth_y: similar to ``bandwidth_x``
             standardize: whether to standardize the data points
         """
diff --git a/src/bmi/estimators/ksg.py b/src/bmi/estimators/ksg.py
index 601ad760..a722b3b1 100644
--- a/src/bmi/estimators/ksg.py
+++ b/src/bmi/estimators/ksg.py
@@ -50,17 +50,17 @@ def __init__(
 
         Args:
             neighborhoods: sequence of positive integers,
-            specifying the size of neighborhood for MI calculation
+                specifying the size of neighborhood for MI calculation
             standardize: whether to standardize the data before MI calculation, by default true
             metric_x: metric on the X space
             metric_y: metric on the Y space. If None, `metric_x` will be used
             n_jobs: number of jobs to be launched to compute distances.
-            Use -1 to use all processors.
+                Use -1 to use all processors.
             chunk_size: internal batch size, used to speed up the computations while fitting
-            into the memory
+                into the memory
 
         Note:
-            If you use Chebyshev (l-infinity) distance for both X and Y,
+            If you use Chebyshev ($\\ell_\\infty$) distance for both $X$ and $Y$ spaces,
             `KSGChebyshevEstimator` may be faster.
         """
diff --git a/src/bmi/samplers/_split_student_t.py b/src/bmi/samplers/_split_student_t.py
index f62f80e6..acb8ff1d 100644
--- a/src/bmi/samplers/_split_student_t.py
+++ b/src/bmi/samplers/_split_student_t.py
@@ -118,7 +118,7 @@ def covariance(self) -> np.ndarray:
             array, shape `(dim_x+dim_y, dim_x+dim_y)`
 
         Raises:
-            ValueError, if covariance is not defined (for `df` $\\le 2$)
+            ValueError: if covariance is not defined (for `df` $\\le 2$)
         """
         if self.df <= 2:
             raise ValueError(
diff --git a/src/bmi/samplers/_splitmultinormal.py b/src/bmi/samplers/_splitmultinormal.py
index 7bbf6d11..14659231 100644
--- a/src/bmi/samplers/_splitmultinormal.py
+++ b/src/bmi/samplers/_splitmultinormal.py
@@ -99,7 +99,7 @@ def __init__(
             dim_x: dimension of the X space
             dim_y: dimension of the Y space
             mean: mean vector, shape `(n,)` where `n = dim_x + dim_y`.
-            Default: zero vector
+                Default: zero vector
             covariance: covariance matrix, shape (n, n)
         """
         super().__init__(dim_x=dim_x, dim_y=dim_y)
diff --git a/src/bmi/samplers/_tfp/_core.py b/src/bmi/samplers/_tfp/_core.py
index 9ff6f256..a1f2e6f1 100644
--- a/src/bmi/samplers/_tfp/_core.py
+++ b/src/bmi/samplers/_tfp/_core.py
@@ -11,17 +11,18 @@
 @dataclasses.dataclass
 class JointDistribution:
-    """Represents a joint distribution of X and Y with known marginals.
-    This is the main object of this package.
+    """The main object of this package.
+    Represents a joint distribution $P_{XY}$ together
+    with the marginal distributions $P_X$ and $P_Y$.
 
     Attributes:
-        dist: tfd.Distribution, joint distribution of X and Y
-        dist_x: tfd.Distribution, marginal distribution of X
-        dist_y: tfd.Distribution, marginal distribution of Y
-        dim_x: dimension of the support of X
-        dim_y: dimension of the support of Y
+        dist_joint: $P_{XY}$
+        dist_x: $P_X$
+        dist_y: $P_Y$
+        dim_x: dimension of the support of $X$
+        dim_y: dimension of the support of $Y$
         analytic_mi: analytical mutual information.
-        Use `None` if unknown (in most cases)
+            Use `None` if unknown (as in most cases)
     """
 
     dist_joint: tfd.Distribution
@@ -34,7 +35,7 @@ class JointDistribution:
     def sample(
         self, n_points: int, key: jax.random.PRNGKeyArray
     ) -> tuple[jnp.ndarray, jnp.ndarray]:
-        """Sample from the joint distribution.
+        """Sample from the joint distribution $P_{XY}$.
 
         Args:
             n_points: number of samples to draw
@@ -55,7 +56,7 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
 
         Returns:
             pointwise mutual information evaluated at (x, y) points,
-            shape `(n_points,)`
+                shape `(n_points,)`
 
         Note:
             This function is vectorized, i.e. it can calculate PMI for multiple points at once.
@@ -75,8 +76,8 @@ def mixture(
 
     Args:
         proportions: mixture proportions should be positive and sum up to 1,
-        shape `(n_components,)`
-        components: sequence of `JointDistribution` objects which will be mixed
+            shape `(n_components,)`
+        components: sequence of distributions to be mixed
 
     Returns:
         mixture distribution
@@ -120,13 +121,13 @@ def transform(
     x_transform: Optional[tfb.Bijector] = None,
     y_transform: Optional[tfb.Bijector] = None,
 ) -> JointDistribution:
-    """For given diffeomorphisms `f` and `g` transforms the joint distribution P_{XY}
-    into P_{f(X)g(Y)}.
+    """For given diffeomorphisms $f$ and $g$, transforms the joint distribution $P_{XY}$
+    into $P_{f(X), g(Y)}$.
 
     Args:
         dist: distribution to be transformed
-        x_transform: diffeomorphism to transform X. Defaults to identity.
-        y_transform: diffeomorphism to transform Y. Defaults to identity.
+        x_transform: diffeomorphism $f$ to transform $X$. Defaults to identity.
+        y_transform: diffeomorphism $g$ to transform $Y$. Defaults to identity.
 
     Returns:
         transformed distribution
@@ -169,11 +170,11 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -
 def monte_carlo_mi_estimate(
     key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int
 ) -> tuple[float, float]:
-    """Estimates the mutual information between X and Y using Monte Carlo sampling.
+    """Estimates the mutual information $I(X; Y)$ using Monte Carlo sampling.
 
     Returns:
-        float, mutual information estimate
-        float, standard error estimate
+        mutual information estimate
+        standard error estimate
 
     Note:
-        It is worth to run this procedure multiple times and see whether
+        It is worth running this procedure multiple times and seeing whether
diff --git a/src/bmi/samplers/_tfp/_normal.py b/src/bmi/samplers/_tfp/_normal.py
index 6666013f..a41bfef8 100644
--- a/src/bmi/samplers/_tfp/_normal.py
+++ b/src/bmi/samplers/_tfp/_normal.py
@@ -14,6 +14,7 @@ def construct_multivariate_normal_distribution(
     mean: jnp.ndarray, covariance: jnp.ndarray
 ) -> tfd.MultivariateNormalLinearOperator:
+    """Constructs a multivariate normal distribution."""
     # Lower triangular matrix such that `covariance = scale @ scale^T`
     scale = jnp.linalg.cholesky(covariance)
     return tfd.MultivariateNormalLinearOperator(
@@ -23,16 +24,21 @@
 
 class MultivariateNormalDistribution(JointDistribution):
+    """Multivariate normal distribution $P_{XY}$,
+    such that $P_X$ is a multivariate normal distribution on the space
+    of dimension `dim_x` and $P_Y$ is a multivariate normal distribution
+    on the space of dimension `dim_y`."""
+
     def __init__(
         self, *, dim_x: int, dim_y: int, covariance: ArrayLike, mean: Optional[ArrayLike] = None
     ) -> None:
         """
         Args:
-            dim_x: dimension of the X space
-            dim_y: dimension of the Y space
+            dim_x: dimension of the $X$ support
+            dim_y: dimension of the $Y$ support
             mean: mean vector, shape `(n,)` where `n = dim_x + dim_y`.
-            Default: zero vector
+                Default: zero vector
             covariance: covariance matrix, shape (n, n)
         """
         # The default mean vector is zero
diff --git a/src/bmi/samplers/_tfp/_product.py b/src/bmi/samplers/_tfp/_product.py
index c4adfef0..0b09d575 100644
--- a/src/bmi/samplers/_tfp/_product.py
+++ b/src/bmi/samplers/_tfp/_product.py
@@ -7,20 +7,19 @@
 
 class ProductDistribution(JointDistribution):
-    """From distributions P_X and P_Y creates a distribution
+    """From distributions $P_X$ and $P_Y$ creates a distribution $P_{XY} = P_X \\otimes P_Y$,
+    so that $X$ and $Y$ are independent.
 
-    P_{XY} = P_X x P_Y
-
-    in which the variables X and Y are independent.
-
-    In particular,
-
-    I(X; Y) = 0
-
-    under this distribution.
+    In particular, $I(X; Y) = 0$.
     """
 
     def __init__(self, dist_x: tfd.Distribution, dist_y: tfd.Distribution) -> None:
+        """Creates a product distribution.
+
+        Args:
+            dist_x: distribution $P_X$
+            dist_y: distribution $P_Y$
+        """
         dims_x = dist_x.event_shape_tensor()
         dims_y = dist_y.event_shape_tensor()
diff --git a/src/bmi/samplers/_tfp/_student.py b/src/bmi/samplers/_tfp/_student.py
index fca47227..31172530 100644
--- a/src/bmi/samplers/_tfp/_student.py
+++ b/src/bmi/samplers/_tfp/_student.py
@@ -16,6 +16,13 @@ def construct_multivariate_student_distribution(
     dispersion: jnp.ndarray,
     df: Union[int, float],
 ) -> tfd.MultivariateStudentTLinearOperator:
+    """Constructs a multivariate Student distribution.
+
+    Args:
+        mean: location vector, shape `(dim,)`
+        dispersion: dispersion matrix, shape `(dim, dim)`
+        df: degrees of freedom
+    """
     # Lower triangular matrix such that `dispersion = scale @ scale^T`
     scale = jnp.linalg.cholesky(dispersion)
     return tfd.MultivariateStudentTLinearOperator(
@@ -26,6 +33,14 @@
 
 class MultivariateStudentDistribution(JointDistribution):
+    """Multivariate Student distribution $P_{XY}$,
+    such that $P_X$ is a multivariate Student distribution on the space
+    of dimension `dim_x` and $P_Y$ is a multivariate Student distribution
+    on the space of dimension `dim_y`.
+
+    Note that the degrees of freedom `df` are shared by $P_{XY}$, $P_X$, and $P_Y$.
+    """
+
     def __init__(
         self,
         *,
@@ -38,12 +53,12 @@ def __init__(
 
         """
         Args:
-            dim_x: dimension of the X space
-            dim_y: dimension of the Y space
+            dim_x: dimension of the $X$ support
+            dim_y: dimension of the $Y$ support
             df: degrees of freedom
             mean: mean vector, shape `(n,)` where `n = dim_x + dim_y`.
-            Default: zero vector
-            dispersion: dispersion matrix, shape (n, n)
+                Default: zero vector
+            dispersion: dispersion matrix, shape `(n, n)`
diff --git a/src/bmi/samplers/_tfp/_wrapper.py b/src/bmi/samplers/_tfp/_wrapper.py
index 3cc386ee..6183d95d 100644
--- a/src/bmi/samplers/_tfp/_wrapper.py
+++ b/src/bmi/samplers/_tfp/_wrapper.py
@@ -8,7 +8,7 @@
 
 class FineSampler(BaseSampler):
-    """Wrapper around a fine distribution."""
+    """Wraps a given fine distribution into a sampler."""
 
     def __init__(
         self,
diff --git a/src/bmi/samplers/_transformed.py b/src/bmi/samplers/_transformed.py
index c746f58b..1d608390 100644
--- a/src/bmi/samplers/_transformed.py
+++ b/src/bmi/samplers/_transformed.py
@@ -19,7 +19,8 @@ def identity(x: _T) -> _T:
 
 class TransformedSampler(base.BaseSampler):
-    """Pushforward of a distribution $P_{XY}$
+    """
+    Pushforward of a distribution $P_{XY}$
     via a product mapping $f \\times g$.
 
@@ -45,21 +46,21 @@ def __init__(
     ) -> None:
         """
         Args:
-            base_sampler: allows sampling from P(X, Y)
-            transform_x: diffeomorphism f, so that we have variable f(X).
-            By default the identity mapping.
-            transform_y: diffeomorphism g, so that we have variable g(Y).
-            By default the identity mapping.
-            add_dim_x: the difference in dimensions of the output of `f` and its input.
-            Note that for any diffeomorphism it must be zero
-            add_dim_y: similarly as `add_dim_x`, but for `g`
+            base_sampler: allows sampling from $P_{XY}$
+            transform_x: diffeomorphism $f$,
+                so that we have variable $f(X)$. By default the identity mapping.
+            transform_y: diffeomorphism $g$,
+                so that we have variable $g(Y)$. By default the identity mapping.
+            add_dim_x: the difference in dimensions of the output of $f$ and its input.
+                Note that for any diffeomorphism it must be zero.
+            add_dim_y: similar to `add_dim_x`, but for $g$.
             vectorise: whether to use `jax.vmap` to vectorise transforms.
                 If False, provided `transform_X` and `transform_Y`
                 need to already be vectorized.
 
         Note:
-            If you don't use diffeomorphisms (in particular,
-            non-default `add_dim_x` or `add_dim_y`), overwrite the
-            `mutual_information()` method
+            If the transforms are not diffeomorphisms (in particular,
+            if you use non-default `add_dim_x` or `add_dim_y`), override the
+            `mutual_information()` method.
         """
         if add_dim_x < 0 or add_dim_y < 0:
             raise ValueError("Transformed samplers cannot decrease dimensionality.")
@@ -92,12 +93,12 @@ def transform(self, x: SomeArray, y: SomeArray) -> tuple[jnp.ndarray, jnp.ndarra
         return self._vectorized_transform_x(x), self._vectorized_transform_y(y)
 
     def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
-        """Samples from P(f(X), g(Y)).
+        """Samples from the distribution $P(f(X), g(Y))$.
 
         Returns:
             paired samples
-            from f(X), shape (n_points, dim(X) + add_dim_x) and
-            from g(Y), shape (n_points, dim(Y) + add_dim_y)
+            from $f(X)$, shape `(n_points, dim(X) + add_dim_x)` and
+            from $g(Y)$, shape `(n_points, dim(Y) + add_dim_y)`
         """
         x, y = self._base_sampler.sample(n_points=n_points, rng=rng)
         return self.transform(x, y)
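To illustrate the `TransformedSampler` docstrings above, here is a hypothetical usage sketch. It relies only on interfaces shown in this diff (`SplitMultinormal`, `TransformedSampler`, `sample(n_points, rng)`, and the `mutual_information()` method mentioned in the Note); the covariance values and the choice of `jnp.arcsinh` (a smooth, strictly increasing bijection of the real line, hence a diffeomorphism that preserves mutual information) are illustrative assumptions.

```python
import jax.numpy as jnp
import numpy as np

from bmi import samplers

# Base distribution: a bivariate normal split into 1 + 1 dimensions,
# with correlation 0.75 between X and Y (illustrative values).
base = samplers.SplitMultinormal(
    dim_x=1,
    dim_y=1,
    covariance=np.array([[1.0, 0.75], [0.75, 1.0]]),
)

# Pushforward via f = arcsinh on X (g defaults to the identity mapping).
# Diffeomorphisms preserve mutual information, so the ground truth carries over.
transformed = samplers.TransformedSampler(
    base_sampler=base,
    transform_x=jnp.arcsinh,
)

x, y = transformed.sample(n_points=1000, rng=0)
print(base.mutual_information(), transformed.mutual_information())  # equal
```

Non-diffeomorphic transforms (for example, with non-default `add_dim_x` or `add_dim_y`) would require overriding `mutual_information()`, as the Note in the docstring above states.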