Skip to content

Commit

Permalink
Fitting Gaussian mixture models and plot styling (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Sep 23, 2023
1 parent c9b117a commit d45509a
Show file tree
Hide file tree
Showing 5 changed files with 532 additions and 132 deletions.
152 changes: 30 additions & 122 deletions workflows/Mixtures/cool_tasks.smk
Original file line number Diff line number Diff line change
Expand Up @@ -10,112 +10,7 @@ import jax.numpy as jnp
import bmi
from bmi.samplers import fine

# --- Define samplers ---

# The X distribution: a balanced two-component Gaussian mixture with 1-D X
# and 1-D Y. The two components have correlation of opposite sign
# (+0.9 / -0.9 via x in [-1, 1]), so the joint scatter forms an "X" shape
# (this feeds the benchmark task named "X" below).
x_dist = fine.mixture(
    proportions=jnp.array([0.5, 0.5]),  # equal weight for both components
    components=[
        fine.MultivariateNormalDistribution(
            # canonical_correlation([r]) presumably builds the 2x2 correlation
            # matrix with off-diagonal r — confirm against bmi docs; the 0.3
            # factor shrinks the overall variance.
            covariance=0.3 * bmi.samplers.canonical_correlation([x * 0.9]),
            mean=jnp.zeros(2),
            dim_x=1, dim_y=1,
        ) for x in [-1, 1]
    ]
)
# Sampler wrapping the mixture distribution.
x_sampler = fine.FineSampler(x_dist)

# The fence distribution: 12 Gaussian components spread along the x_1 axis
# (means 1.5 * [x, 0, x % 4] over the joint (x_1, x_2, y) space);
# X is 2-dimensional, Y is 1-dimensional.
n_components = 12

fence_base_dist = fine.mixture(
    proportions=jnp.ones(n_components) / n_components,  # uniform weights
    components=[
        fine.MultivariateNormalDistribution(
            # Diagonal covariance over (x_1, x_2, y): narrow in x_1 and y,
            # wide in x_2.
            covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])),
            # x % 4 makes the y-mean cycle with period 4 along the fence.
            mean=jnp.array([x, 0, x%4]) * 1.5,
            dim_x=2, dim_y=1,
        ) for x in range(n_components)
    ]
)
base_sampler = fine.FineSampler(fence_base_dist)
# Bend the fence: shift x_1 by a sinusoid of x_2 (here x is the 2-D X sample).
fence_aux_sampler = bmi.samplers.TransformedSampler(
    base_sampler,
    transform_x=lambda x: x + jnp.array([5., 0.]) * jnp.sin(3 * x[1]),
)
# Rescale and recenter X — presumably so the samples fit the plotting window
# used in the figure rules; confirm against the plot limits.
fence_sampler = bmi.samplers.TransformedSampler(
    fence_aux_sampler,
    transform_x=lambda x: jnp.array([0.1 * x[0]-0.8, 0.5 * x[1]])
)

# The AI distribution: a six-component Gaussian mixture (1-D X, 1-D Y) whose
# samples in the (X, Y) plane draw the letters "A" and "I".
corr = 0.95    # correlation magnitude of the two slanted "A" strokes
var_x = 0.04   # X-variance of the slanted "A" strokes

ai_dist = fine.mixture(
    proportions=np.full(6, fill_value=1/6),  # equal weight per stroke
    components=[
        # I components
        # Vertical bar of the "I" at X = 1: tight in X, elongated in Y.
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([1., 0.]),
            covariance=np.diag([0.01, 0.2]),
        ),
        # Top serif of the "I": wide in X, flat in Y.
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([1., 1]),
            covariance=np.diag([0.05, 0.001]),
        ),
        # Bottom serif of the "I".
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([1., -1]),
            covariance=np.diag([0.05, 0.001]),
        ),
        # A components
        # Horizontal crossbar of the "A".
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([-0.8, -0.2]),
            covariance=np.diag([0.03, 0.001]),
        ),
        # Left stroke of the "A": positive X-Y correlation (off-diagonal
        # term sqrt(var_x * 0.2) * corr).
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([-1.2, 0.]),
            covariance=np.array([[var_x, np.sqrt(var_x * 0.2) * corr], [np.sqrt(var_x * 0.2) * corr, 0.2]]),
        ),
        # Right stroke of the "A": same shape, negative correlation.
        fine.MultivariateNormalDistribution(
            dim_x=1, dim_y=1,
            mean=np.array([-0.4, 0.]),
            covariance=np.array([[var_x, -np.sqrt(var_x * 0.2) * corr], [-np.sqrt(var_x * 0.2) * corr, 0.2]]),
        ),
    ]
)
ai_sampler = fine.FineSampler(ai_dist)

# Balls mixed with spiral: two Gaussian "balls" centered at +/-1.5 on the
# diagonal of the joint (x_1, x_2, y) space (X is 2-D, Y is 1-D), with the
# X marginal subsequently warped by a spiral transform.

balls_mixt = fine.mixture(
    proportions=jnp.array([0.5, 0.5]),  # equal weight for both balls
    components=[
        fine.MultivariateNormalDistribution(
            # Zero canonical correlation between X and Y within a component;
            # additional_y=1 presumably appends an extra independent Y
            # dimension — confirm against bmi docs.
            covariance=bmi.samplers.canonical_correlation([0.0], additional_y=1),
            mean=jnp.array([x, x, x]) * 1.5,
            dim_x=2, dim_y=1,
        ) for x in [-1, 1]
    ]
)

base_balls_sampler = fine.FineSampler(balls_mixt)
# 90-degree rotation generator; Spiral presumably rotates each point by an
# angle that grows with its radius at the given speed — confirm in
# bmi.transforms docs.
a = jnp.array([[0, -1], [1, 0]])
spiral = bmi.transforms.Spiral(a, speed=0.5)

# Apply the spiral to the X marginal only (Y is left untouched).
sampler_balls_aux = bmi.samplers.TransformedSampler(
    base_balls_sampler,
    transform_x=spiral
)
# Shrink X so the spiraled samples fit the plotting window.
sampler_balls_transformed = bmi.samplers.TransformedSampler(
    sampler_balls_aux,
    transform_x=lambda x: 0.3 * x,
)
import example_distributions as ed


# Sample sizes at which each benchmark task is evaluated.
N_SAMPLES = [1_000, 5_000]
Expand All @@ -139,6 +34,13 @@ ESTIMATOR_NAMES = {
}
assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys())

# Number of samples passed to example_distributions when building each
# distribution — presumably the Monte Carlo sample count used to estimate the
# ground-truth MI; confirm against example_distributions.
_SAMPLE_ESTIMATE: int = 200_000

# Samplers for the four benchmark tasks, all built from the shared
# example_distributions module so figures stay consistent across workflows.
x_sampler = ed.create_x_distribution(_sample=_SAMPLE_ESTIMATE).sampler
ai_sampler = ed.create_ai_distribution(_sample=_SAMPLE_ESTIMATE).sampler
waves_sampler = ed.create_waves_distribution(_sample=_SAMPLE_ESTIMATE).sampler
galaxy_sampler = ed.create_galaxy_distribution(_sample=_SAMPLE_ESTIMATE).sampler

UNSCALED_TASKS = {
"X": bmi.benchmark.Task(
sampler=x_sampler,
Expand All @@ -151,12 +53,12 @@ UNSCALED_TASKS = {
task_name="AI",
),
"Fence": bmi.benchmark.Task(
sampler=fence_sampler,
sampler=waves_sampler,
task_id="Fence",
task_name="Fence",
),
"Balls": bmi.benchmark.Task(
sampler=sampler_balls_transformed,
sampler=galaxy_sampler,
task_id="Balls",
task_name="Balls",
),
Expand All @@ -170,49 +72,55 @@ rule all:
input:
'cool_tasks.pdf',
'results.csv',
'results.pdf',
'cool_tasks-results.pdf',
'profiles.pdf'

rule plot_distributions:
output: "cool_tasks.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(3, 3))
fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5), wspace=0.4)

# Plot the X distribution
ax = axs[0]
xs, ys = x_sampler.sample(1000, 0)

ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True)
size = 2**2

ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True)
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")

# Plot the AI distribution
ax = axs[1]
xs, ys = ai_sampler.sample(2000, 0)
ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True)
ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True)
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")

# Plot the fence distribution
ax = axs[2]
xs, ys = fence_sampler.sample(2000, 0)
xs, ys = waves_sampler.sample(2000, 0)

ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True)
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")

# Plot transformed balls distribution
ax = axs[3]
xs, ys = sampler_balls_transformed.sample(2000, 0)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True)
xs, ys = galaxy_sampler.sample(2000, 0)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True)
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")

for ax in axs:
ticks = [-1, 0, 1]
ax.set_xticks(ticks, ticks)
ax.set_yticks(ticks, ticks)
ax.set_xlim(-2., 2.)
ax.set_ylim(-2., 2.)
ax.spines[['right', 'top']].set_visible(False)

fig.savefig(str(output))
fig.savefig(str(output), dpi=300)

rule plot_pmi_profiles:
output: "profiles.pdf"
Expand All @@ -232,11 +140,11 @@ rule plot_pmi_profiles:


rule plot_results:
output: 'results.pdf'
output: 'cool_tasks-results.pdf'
input: 'results.csv'
run:
data = pd.read_csv(str(input))
fig, ax = subplots_from_axsize(1, 1, (4, 3))
fig, ax = subplots_from_axsize(1, 1, (2, 1.5), right=1.3)

data_5k = data[data['n_samples'] == 5000]
tasks = ['X', 'AI', 'Fence', 'Balls']
Expand All @@ -247,7 +155,7 @@ rule plot_results:
data_est['task_id'].apply(lambda e: tasks.index(e)) + 0.05 * np.random.normal(size=len(data_est)),
data_est['mi_estimate'],
label=ESTIMATOR_NAMES[estimator_id],
alpha=0.4, s=5**2,
alpha=0.4, s=3**2,
)

for task_id, data_task in data_5k.groupby('task_id'):
Expand All @@ -257,10 +165,10 @@ rule plot_results:

ax.set_xticks(range(len(tasks)), tasks_official)

ax.legend(frameon=False, loc='upper left')
ax.legend(frameon=False, loc='upper left', bbox_to_anchor=(1, 1))
ax.spines[['top', 'right']].set_visible(False)
ax.set_ylim(-0.1, 1.4)
ax.set_ylabel('Mutual information [nats]')
ax.set_ylabel('MI')
fig.savefig(str(output))


Expand Down
12 changes: 9 additions & 3 deletions workflows/Mixtures/distinct_profiles.smk
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import jax.numpy as jnp
import bmi.samplers._tfp as bmi_tfp
from bmi.transforms import invert_cdf, normal_cdf

from subplots_from_axsize import subplots_from_axsize

mpl.use("Agg")


Expand Down Expand Up @@ -114,6 +116,8 @@ def hide_ticks(ax):
ax.set_yticks([])
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")
ax.spines[['right', 'top']].set_visible(False)


rule plot_samples:
input:
Expand All @@ -123,7 +127,7 @@ rule plot_samples:
output:
"figure_distinct_profiles.pdf"
run:
fig, axs = plt.subplots(1, 4, figsize=(8, 2))
fig, axs = plt.subplots(1, 4, figsize=(7, 2))

color1 = "navy"
color2 = "salmon"
Expand Down Expand Up @@ -159,15 +163,17 @@ rule plot_samples:
ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture")
ax.set_title("PMI profiles")
ax.set_xlabel("PMI")
ax.set_ylabel("Density")
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)

mi_1 = jnp.mean(pmi_normal)
mi_2 = jnp.mean(pmi_u)

if abs(mi_1 - mi_2) > 0.01:
raise ValueError(f"MI different: {mi_1:.2f} != {mi_2:.2f}")

ax.axvline(mi_1, c="k", linewidth=0.5, linestyle="--")
ax.axvline(mi_1, c="k", linewidth=1, linestyle="--")

fig.tight_layout()
fig.savefig(str(output))
Loading

0 comments on commit d45509a

Please sign in to comment.