diff --git a/python/cugraph/cugraph/testing/mg_utils.py b/python/cugraph/cugraph/testing/mg_utils.py
index 32854652f05..980aec7dfbe 100644
--- a/python/cugraph/cugraph/testing/mg_utils.py
+++ b/python/cugraph/cugraph/testing/mg_utils.py
@@ -16,6 +16,9 @@
 from pprint import pformat
 import time
 from dask.distributed import wait, default_client
+import logging
+from distributed.diagnostics.plugin import WorkerPlugin, SchedulerPlugin
+from distributed.scheduler import Scheduler
 from dask import persist
 from dask.distributed import Client
 from dask.base import is_dask_collection
@@ -27,6 +30,32 @@
 import numpy as np
 
 
+class GracefullyRetireWorkers(WorkerPlugin):
+    def __init__(self, logger):
+        self.logger = logger
+        self.count = 0
+        self.key = None
+        self.state = 1
+
+    async def remove_worker(self, scheduler, worker: str, *, stimulus_id, **kwargs):
+        print("a worker is leaving the cluster and state = ", self.state, " count = ", self.count, flush=True)
+        # wait(scheduler.retire_workers())
+        if self.state == -1:
+            self.logger.critical(" Worker %s left the cluster", worker)
+            if self.count == 0:
+                self.logger.critical(" An error occurred: retiring all workers")
+                self.count += 1
+                await scheduler.retire_workers()
+
+    def setup(self, worker):
+        self.worker = worker
+
+    def transition(self, key, start, finish, *args, **kwargs):
+        if finish in ['error', 'erred']:
+            print("transition = ", finish)
+            self.state = -1
+
+
 def start_dask_client(
     protocol=None,
     rmm_async=False,
@@ -157,6 +186,9 @@ def start_dask_client(
         num_workers = len(dask_worker_devices.split(","))
         client.wait_for_workers(num_workers)
+
+    s_plugin = GracefullyRetireWorkers(logging)
+    client.register_plugin(s_plugin)
 
     # Add a reference to tempdir_object to the client to prevent it from
     # being deleted when this function returns. This will be deleted in
     # stop_dask_client()
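
For reference, a minimal sketch (not part of the patch) of how the new plugin can be registered against an already-running Dask cluster, outside of start_dask_client; the scheduler address below is a placeholder:

    import logging

    from dask.distributed import Client

    from cugraph.testing.mg_utils import GracefullyRetireWorkers

    # Connect to an existing scheduler (address is hypothetical).
    client = Client("tcp://127.0.0.1:8786")

    # Client.register_plugin dispatches on the plugin's base class; since
    # GracefullyRetireWorkers subclasses WorkerPlugin, it is installed on
    # every current and future worker in the cluster.
    plugin = GracefullyRetireWorkers(logging)
    client.register_plugin(plugin)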