Skip to content

Commit

Permalink
Improve plots (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Sep 24, 2023
1 parent d45509a commit dc6100d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
19 changes: 13 additions & 6 deletions workflows/Mixtures/how_good_integration_is.smk
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ rule estimate_ground_truth_single_seed:
indent=4,
)

def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path) -> None:
def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path, alpha: float = 0.1) -> None:
df = pd.read_csv(estimates_path)
with open(ground_truth_path) as fh:
ground_truth = json.load(fh)
Expand Down Expand Up @@ -233,7 +233,7 @@ def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path) -> None:
color = ESTIMATOR_COLORS[estimator]

ax.plot(points, mean, color=color, label=estimator)
ax.fill_between(points, mean - std, mean + std, alpha=0.1, color=color)
ax.fill_between(points, mean - std, mean + std, alpha=alpha, color=color)


rule plot_performance_all:
Expand All @@ -251,21 +251,28 @@ rule plot_performance_all:
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(2.5, 1.5), right=1.2, top=0.3)

y_min = 0.0
y_max = 1.0
alpha = 0.2

ax = axs[0]
ax.set_title("Mixture")
plot_estimates(ax, input.simple_estimates, input.simple_ground_truth)
ax.set_ylim(y_min, y_max)
plot_estimates(ax, input.simple_estimates, input.simple_ground_truth, alpha=alpha)

ax = axs[1]
ax.set_title("Constant bias")
plot_estimates(ax, input.biased_estimates, input.biased_ground_truth)
ax.set_ylim(y_min, y_max)
plot_estimates(ax, input.biased_estimates, input.biased_ground_truth, alpha=alpha)

ax = axs[2]
ax.set_title("Functional bias")
plot_estimates(ax, input.func_estimates, input.func_ground_truth)
ax.set_ylim(y_min, y_max)
plot_estimates(ax, input.func_estimates, input.func_ground_truth, alpha=alpha)

ax = axs[3]
ax.set_title("High-dimensional")
plot_estimates(ax, input.highdim_estimates, input.highdim_ground_truth)
plot_estimates(ax, input.highdim_estimates, input.highdim_ground_truth, alpha=alpha)


for ax in axs:
Expand Down
23 changes: 21 additions & 2 deletions workflows/Mixtures/outliers.smk
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ class ChangeMixingSetup:
proportions=jnp.asarray([1.0 - alpha, alpha]),
)

signal_cov_parametrization = bmi.samplers.SparseLVMParametrization(dim_x=2, dim_y=2, n_interacting=2, beta=0.1, lambd=2.0)
signal_cov_matrix = signal_cov_parametrization.correlation

dist_signal_gauss = fine.MultivariateNormalDistribution(
dim_x=2,
dim_y=2,
covariance=bmi.samplers.SparseLVMParametrization(dim_x=2, dim_y=2, n_interacting=2, beta=0.1, lambd=2.0).correlation
covariance=signal_cov_matrix,
)

covariance_inlier = jnp.eye(dist_signal_gauss.dim_y)
Expand Down Expand Up @@ -132,7 +134,24 @@ rule all:
input:
mixing_ground_truths = expand("{setup}/mixing_ground_truths.done", setup=CHANGE_MIXING_SETUPS.keys()),
estimates = "results.csv",
outliers_plot = "outliers_plot.pdf"
outliers_plot = "outliers_plot.pdf",
parameters = "parameters.json",
covariance_heatmap = "covariance_heatmap.pdf"

rule plot_parameters:
output:
params_json = "parameters.json",
covariance_heatmap = "covariance_heatmap.pdf"
run:
with open(str(output.params_json), "w") as fh:
json.dump({
"signal_covariance": signal_cov_matrix.tolist(),
"signal_mutual_information": dist_signal_gauss.analytic_mi,
}, fp=fh, indent=4)
fig, ax = plt.subplots()
sns.heatmap(signal_cov_matrix, ax=ax, annot=True, fmt=".2f", cmap="coolwarm")
fig.tight_layout()
fig.savefig(str(output.covariance_heatmap))


def plot_data(ax: plt.Axes, data: pd.DataFrame, key: str = "mixing", use_legend: bool = False):
Expand Down

0 comments on commit dc6100d

Please sign in to comment.