Skip to content

Commit

Permalink
Merge branch '209-hardware-loop-multithreading' into 'master'
Browse files Browse the repository at this point in the history
Hardware Loop Multithread Plotting Support

Closes #209

See merge request barkhauseninstitut/wicon/hermespy!171
  • Loading branch information
adlerjan committed Feb 27, 2024
2 parents c1546e4 + 371bfe1 commit c953ba0
Showing 1 changed file with 149 additions and 16 deletions.
165 changes: 149 additions & 16 deletions hermespy/hardware_loop/hardware_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@

from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import ExitStack
from threading import Event, Thread
from contextlib import AbstractContextManager, ExitStack
from os import path
from signal import signal, SIGINT
from types import TracebackType
from typing import Any, Generic, List, Mapping, Sequence, Tuple, Type
from warnings import catch_warnings, simplefilter

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
Expand Down Expand Up @@ -320,6 +325,111 @@ def _update_plot(self, sample: HardwareLoopSample, visualization: VT) -> None:
... # pragma: no cover


class PlotThread(Thread):
"""Thread for parallel plotting during hardware loop runtime."""

__plot: HardwareLoopPlot
__sample: HardwareLoopSample | None = None
__update_plot: Event
__alive: bool
__rc_params: mpl.RcParams

def __init__(self, rc_params: mpl.RcParams, plot: HardwareLoopPlot, **kwargs) -> None:
"""
Args:
rc_params (mpl.RcParams): Matplotlib style parameters.
plot (HardwareLoopPlot): Plot to be visualized by the thread.
\**kwargs: Additional keyword arguments to be passed to the base class.
"""

# Initialize class attributes
self.__plot = plot
self.__sample = None
self.__update_plot = Event()
self.__alive = True
self.__rc_params = rc_params

# Initialize base class
Thread.__init__(self, **kwargs)

def run(self) -> None:
with mpl.rc_context(self.__rc_params), plt.ion():
# Prepare the plot
with catch_warnings():
simplefilter("ignore")
figure, _ = self.__plot.prepare_plot()

while self.__alive:
if self.__update_plot.wait(0.1):
self.__plot.update_plot(self.__sample)
self.__update_plot.clear()

# Close the plot
# This is required as to not confuse the matplotlib backend which thread to use
# in upcoming plots
plt.close(figure)

def update_plot(self, sample: HardwareLoopSample) -> None:
"""Update the plot with a new sample.
Args:
sample (HardwareLoopSample): Sample to be plotted.
"""

self.__sample = sample
self.__update_plot.set()

def stop(self) -> None:
self.__alive = False


class ThreadContextManager(AbstractContextManager):
"""Context manager for managing threads.
Entering the context manager starts the threads, exiting stops them.
"""

__threads: List[PlotThread] # Threads managed by this context manager

def __init__(self, threads: List[PlotThread]) -> None:
"""
Args:
threads (List[PlotThread]): Threads to be managed by this context manager.
"""

# Initialize base class
super().__init__()

# Initialize class attributes
self.__threads = threads

def __enter__(self) -> Any:
for thread in self.__threads:
if not thread.is_alive():
thread.start()

super().__enter__()

def __exit__(
self,
__exc_type: Type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> bool | None:
# Stop the threads
for thread in self.__threads:
thread.stop()

# Wait for the threads to finish
for thread in self.__threads:
thread.join(timeout=1.0)

return super().__exit__(__exc_type, __exc_value, __traceback)


class HardwareLoop(
Serializable, Generic[PhysicalScenarioType, PDT], Pipeline[PhysicalScenarioType, PDT]
):
Expand All @@ -344,6 +454,7 @@ class HardwareLoop(
__evaluators: List[Evaluator] # Evaluators further processing drop information
__plots: List[HardwareLoopPlot]
__iteration_priority: IterationPriority
__interrupt_run: bool

def __init__(
self,
Expand Down Expand Up @@ -388,6 +499,7 @@ def __init__(
self.__dimensions = []
self.__evaluators = []
self.__plots = []
self.__interrupt_run = False

def new_dimension(
self, dimension: str, sample_points: List[Any], *args: Tuple[Any]
Expand Down Expand Up @@ -607,19 +719,25 @@ def __generate_sample(
# Return sample
return HardwareLoopSample(drop, evaluations, artifacts)

def __sigint_handler(self, signum: int, frame: Any) -> None:
"""Signal handler for SIGINT."""

# Print a message
if self.console_mode is not ConsoleMode.SILENT:
self.console.log("Received SIGINT, stopping hardware loop", style="bright_red")

self.__interrupt_run = True

def __run(self) -> None:
"""Internal run method executing drops"""

# Initialize plots
if self.plot_information:
with plt.ion() and self.style_context(): # pragma: no cover
for plot in self.__plots:
plot.prepare_plot()
# Reset variables
self.__interrupt_run = False

# Tile the generated figures
tile_figures(2, 4)
# Register sigint handler to this instance
signal(SIGINT, self.__sigint_handler)

# runtime = HardwareLoopRuntime(self.__devices, self.__dimensions, self.__evaluators, self.plot_information)
# Initialize the sample grid
sample_grid = SampleGrid(self.__dimensions, self.__evaluators)

# Print indicator that the simulation is starting
Expand Down Expand Up @@ -656,15 +774,26 @@ def __run(self) -> None:
total_progress = progress.add_task("[red]Progress", total=num_total_drops)

with ExitStack() as stack:
# Initialize plots
plot_threads: List[PlotThread] = []
if self.plot_information:
with self.style_context():
rc_params = mpl.rcParams.copy()
plot_threads = [
PlotThread(rc_params, plot, name=f"HardwareLoop-Plot-{p}")
for p, plot in enumerate(self.__plots)
]

# Tile the generated figures
tile_figures(2, 4)

# Add all threads to the context stack
stack.enter_context(ThreadContextManager(plot_threads))

# Add the progress bar to the context stack
if self.console_mode == ConsoleMode.INTERACTIVE:
stack.enter_context(progress)

# If the plot information is enabled,
# add the interactive plot to the context stack to enable live updates
if self.plot_information:
stack.enter_context(plt.ion())

# Start counting the total number of completed drops
total = 0

Expand All @@ -686,6 +815,10 @@ def __run(self) -> None:
raise RuntimeError(f"Invalid iteration priority: {self.iteration_priority}")

for indices in np.ndindex(index_grid):
# Abort if sigint is received
if self.__interrupt_run:
break

sample_index = indices[drop_selector]
section_indices = indices[grid_selector]

Expand Down Expand Up @@ -731,8 +864,8 @@ def __run(self) -> None:

# Update plots
if self.plot_information:
for plot in self.__plots:
plot.update_plot(loop_sample)
for thread in plot_threads:
thread.update_plot(loop_sample)

except Exception as e:
self._handle_exception(e, confirm=False)
Expand Down

0 comments on commit c953ba0

Please sign in to comment.