Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scheduler plugin #4416

Draft
wants to merge 1 commit into
base: branch-24.12
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions python/cugraph/cugraph/testing/mg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from pprint import pformat
import time
from dask.distributed import wait, default_client
import logging
from distributed.diagnostics.plugin import WorkerPlugin, SchedulerPlugin
from distributed.scheduler import Scheduler
from dask import persist
from dask.distributed import Client
from dask.base import is_dask_collection
Expand All @@ -27,6 +30,32 @@
import numpy as np


class GracefullyRetireWorkers(WorkerPlugin):
def __init__(self, logger):
self.logger = logger
self.count = 0
self.key = None
self.state = 1

async def remove_worker(self, scheduler, worker: str, *, stimulus_id, **kwargs) :
print("a worker is leaving the cluster and state = ", self.state, " count = ", self.count, flush=True)
#wait(scheduler.retire_workers())
if self.state == -1:
self.logger.critical(" Worker %s left the cluster", worker)
if self.count == 0:
self.logger.critical(" An error occured: retiring all workers")
self.count += 1
await scheduler.retire_workers()

def setup(self, worker):
self.worker = worker

def transition(self, key, start, finish, *args, **kwargs):
if finish in ['error', 'erred']:
print("transition = ", finish)
self.state = -1


def start_dask_client(
protocol=None,
rmm_async=False,
Expand Down Expand Up @@ -157,6 +186,9 @@ def start_dask_client(
num_workers = len(dask_worker_devices.split(","))

client.wait_for_workers(num_workers)

s_plugin = GracefullyRetireWorkers(logging)
client.register_plugin(s_plugin)
# Add a reference to tempdir_object to the client to prevent it from
# being deleted when this function returns. This will be deleted in
# stop_dask_client()
Expand Down
Loading