diff --git a/pyext/src/wrapper_v6.py b/pyext/src/wrapper_v6.py index 8cf6fc2..47b183c 100644 --- a/pyext/src/wrapper_v6.py +++ b/pyext/src/wrapper_v6.py @@ -54,7 +54,7 @@ def get_curr_processes_and_terminated_runs(processes: dict): return processes, faulty_runs, successful_runs -def plotter(results: dict): +def plotter(results: dict, h_params): all_log_z = {} mean_proc_time = [] mean_per_step_time = [] @@ -87,7 +87,8 @@ def plotter(results: dict): plt.ylabel("log(Evidence)") plt.savefig( os.path.join( - parent_path, f"trial_{h_params['trial_name']}_evidence_errorbarplot.png" + h_params["parent_path"], + f"trial_{h_params['trial_name']}_evidence_errorbarplot.png", ) ) @@ -99,7 +100,9 @@ def plotter(results: dict): plt.xlabel("Resolutions") plt.ylabel("Nested sampling process time") plt.savefig( - os.path.join(parent_path, f"trial_{h_params['trial_name']}_proctime.png") + os.path.join( + h_params["parent_path"], f"trial_{h_params['trial_name']}_proctime.png" + ) ) plt.figure(3) @@ -110,7 +113,9 @@ def plotter(results: dict): plt.xlabel("Resolutions") plt.ylabel("Mean time per MCMC step") plt.savefig( - os.path.join(parent_path, f"trial_{h_params['trial_name']}_persteptime.png") + os.path.join( + h_params["parent_path"], f"trial_{h_params['trial_name']}_persteptime.png" + ) ) @@ -278,19 +283,20 @@ def main(h_param_file, topology=True): elif sys.argv[2] == "topology": use_topology = True + with open(h_param_file, "r") as h_paramf: + h_params = yaml.safe_load(h_paramf) if sys.argv[3] != "skip_calc": main(h_param_file, use_topology) else: - with open(h_param_file, "r") as h_paramf: - h_params = yaml.safe_load(h_paramf) + with open( os.path.join(h_params["parent_path"], "nestor_output.yaml"), "r" ) as outf: results = yaml.safe_load(outf) - if len(list(results.keys())) > 0: - plotter(results) - else: - print("\nNone of the runs was successful...!") - print("Done...!\n\n") + if len(list(results.keys())) > 0: + plotter(results, h_params) + else: + print("\nNone of the runs was successful...!") + print("Done...!\n\n")