Skip to content

Commit

Permalink
Fix wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas-arvindekar committed Mar 5, 2024
1 parent 3f9c0fc commit eecdd66
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions pyext/src/wrapper_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_curr_processes_and_terminated_runs(processes: dict):
return processes, faulty_runs, successful_runs


def plotter(results: dict):
def plotter(results: dict, h_params):
all_log_z = {}
mean_proc_time = []
mean_per_step_time = []
Expand Down Expand Up @@ -87,7 +87,8 @@ def plotter(results: dict):
plt.ylabel("log(Evidence)")
plt.savefig(
os.path.join(
parent_path, f"trial_{h_params['trial_name']}_evidence_errorbarplot.png"
h_params["parent_path"],
f"trial_{h_params['trial_name']}_evidence_errorbarplot.png",
)
)

Expand All @@ -99,7 +100,9 @@ def plotter(results: dict):
plt.xlabel("Resolutions")
plt.ylabel("Nested sampling process time")
plt.savefig(
os.path.join(parent_path, f"trial_{h_params['trial_name']}_proctime.png")
os.path.join(
h_params["parent_path"], f"trial_{h_params['trial_name']}_proctime.png"
)
)

plt.figure(3)
Expand All @@ -110,7 +113,9 @@ def plotter(results: dict):
plt.xlabel("Resolutions")
plt.ylabel("Mean time per MCMC step")
plt.savefig(
os.path.join(parent_path, f"trial_{h_params['trial_name']}_persteptime.png")
os.path.join(
h_params["parent_path"], f"trial_{h_params['trial_name']}_persteptime.png"
)
)


Expand Down Expand Up @@ -278,19 +283,20 @@ def main(h_param_file, topology=True):
elif sys.argv[2] == "topology":
use_topology = True

with open(h_param_file, "r") as h_paramf:
h_params = yaml.safe_load(h_paramf)
if sys.argv[3] != "skip_calc":
main(h_param_file, use_topology)

else:
with open(h_param_file, "r") as h_paramf:
h_params = yaml.safe_load(h_paramf)

with open(
os.path.join(h_params["parent_path"], "nestor_output.yaml"), "r"
) as outf:
results = yaml.safe_load(outf)

if len(list(results.keys())) > 0:
plotter(results)
else:
print("\nNone of the runs was successful...!")
print("Done...!\n\n")
if len(list(results.keys())) > 0:
plotter(results, h_params)
else:
print("\nNone of the runs was successful...!")
print("Done...!\n\n")

0 comments on commit eecdd66

Please sign in to comment.