diff --git a/hermespy/hardware_loop/hardware_loop.py b/hermespy/hardware_loop/hardware_loop.py index 0b4ddb3d..4b2226bf 100644 --- a/hermespy/hardware_loop/hardware_loop.py +++ b/hermespy/hardware_loop/hardware_loop.py @@ -7,10 +7,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from contextlib import ExitStack +from threading import Event, Thread +from contextlib import AbstractContextManager, ExitStack from os import path +from signal import signal, SIGINT +from types import TracebackType from typing import Any, Generic, List, Mapping, Sequence, Tuple, Type +from warnings import catch_warnings, simplefilter +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn @@ -320,6 +325,111 @@ def _update_plot(self, sample: HardwareLoopSample, visualization: VT) -> None: ... # pragma: no cover +class PlotThread(Thread): + """Thread for parallel plotting during hardware loop runtime.""" + + __plot: HardwareLoopPlot + __sample: HardwareLoopSample | None = None + __update_plot: Event + __alive: bool + __rc_params: mpl.RcParams + + def __init__(self, rc_params: mpl.RcParams, plot: HardwareLoopPlot, **kwargs) -> None: + """ + Args: + + rc_params (mpl.RcParams): Matplotlib style parameters. + plot (HardwareLoopPlot): Plot to be visualized by the thread. + \**kwargs: Additional keyword arguments to be passed to the base class. + """ + + # Initialize class attributes + self.__plot = plot + self.__sample = None + self.__update_plot = Event() + self.__alive = True + self.__rc_params = rc_params + + # Initialize base class + Thread.__init__(self, **kwargs) + + def run(self) -> None: + with mpl.rc_context(self.__rc_params), plt.ion(): + # Prepare the plot + with catch_warnings(): + simplefilter("ignore") + figure, _ = self.__plot.prepare_plot() + + while self.__alive: + if self.__update_plot.wait(0.1): + self.__plot.update_plot(self.__sample) + self.__update_plot.clear() + + # Close the plot + # This is required as to not confuse the matplotlib backend which thread to use + # in upcoming plots + plt.close(figure) + + def update_plot(self, sample: HardwareLoopSample) -> None: + """Update the plot with a new sample. + + Args: + + sample (HardwareLoopSample): Sample to be plotted. + """ + + self.__sample = sample + self.__update_plot.set() + + def stop(self) -> None: + self.__alive = False + + +class ThreadContextManager(AbstractContextManager): + """Context manager for managing threads. + + Entering the context manager starts the threads, exiting stops them. + """ + + __threads: List[PlotThread] # Threads managed by this context manager + + def __init__(self, threads: List[PlotThread]) -> None: + """ + Args: + + threads (List[PlotThread]): Threads to be managed by this context manager. + """ + + # Initialize base class + super().__init__() + + # Initialize class attributes + self.__threads = threads + + def __enter__(self) -> Any: + for thread in self.__threads: + if not thread.is_alive(): + thread.start() + + super().__enter__() + + def __exit__( + self, + __exc_type: Type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + # Stop the threads + for thread in self.__threads: + thread.stop() + + # Wait for the threads to finish + for thread in self.__threads: + thread.join(timeout=1.0) + + return super().__exit__(__exc_type, __exc_value, __traceback) + + class HardwareLoop( Serializable, Generic[PhysicalScenarioType, PDT], Pipeline[PhysicalScenarioType, PDT] ): @@ -344,6 +454,7 @@ class HardwareLoop( __evaluators: List[Evaluator] # Evaluators further processing drop information __plots: List[HardwareLoopPlot] __iteration_priority: IterationPriority + __interrupt_run: bool def __init__( self, @@ -388,6 +499,7 @@ def __init__( self.__dimensions = [] self.__evaluators = [] self.__plots = [] + self.__interrupt_run = False def new_dimension( self, dimension: str, sample_points: List[Any], *args: Tuple[Any] @@ -607,19 +719,25 @@ def __generate_sample( # Return sample return HardwareLoopSample(drop, evaluations, artifacts) + def __sigint_handler(self, signum: int, frame: Any) -> None: + """Signal handler for SIGINT.""" + + # Print a message + if self.console_mode is not ConsoleMode.SILENT: + self.console.log("Received SIGINT, stopping hardware loop", style="bright_red") + + self.__interrupt_run = True + def __run(self) -> None: """Internal run method executing drops""" - # Initialize plots - if self.plot_information: - with plt.ion() and self.style_context(): # pragma: no cover - for plot in self.__plots: - plot.prepare_plot() + # Reset variables + self.__interrupt_run = False - # Tile the generated figures - tile_figures(2, 4) + # Register sigint handler to this instance + signal(SIGINT, self.__sigint_handler) - # runtime = HardwareLoopRuntime(self.__devices, self.__dimensions, self.__evaluators, self.plot_information) + # Initialize the sample grid sample_grid = SampleGrid(self.__dimensions, self.__evaluators) # Print indicator that the simulation is starting @@ -656,15 +774,26 @@ def __run(self) -> None: total_progress = progress.add_task("[red]Progress", total=num_total_drops) with ExitStack() as stack: + # Initialize plots + plot_threads: List[PlotThread] = [] + if self.plot_information: + with self.style_context(): + rc_params = mpl.rcParams.copy() + plot_threads = [ + PlotThread(rc_params, plot, name=f"HardwareLoop-Plot-{p}") + for p, plot in enumerate(self.__plots) + ] + + # Tile the generated figures + tile_figures(2, 4) + + # Add all threads to the context stack + stack.enter_context(ThreadContextManager(plot_threads)) + # Add the progress bar to the context stack if self.console_mode == ConsoleMode.INTERACTIVE: stack.enter_context(progress) - # If the plot information is enabled, - # add the interactive plot to the context stack to enable live updates - if self.plot_information: - stack.enter_context(plt.ion()) - # Start counting the total number of completed drops total = 0 @@ -686,6 +815,10 @@ def __run(self) -> None: raise RuntimeError(f"Invalid iteration priority: {self.iteration_priority}") for indices in np.ndindex(index_grid): + # Abort if sigint is received + if self.__interrupt_run: + break + sample_index = indices[drop_selector] section_indices = indices[grid_selector] @@ -731,8 +864,8 @@ def __run(self) -> None: # Update plots if self.plot_information: - for plot in self.__plots: - plot.update_plot(loop_sample) + for thread in plot_threads: + thread.update_plot(loop_sample) except Exception as e: self._handle_exception(e, confirm=False)