diff --git a/ci/test_python.sh b/ci/test_python.sh
index d6e58cc8..6ad39c46 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -42,20 +42,27 @@ rapids-logger "Python Async Tests"
 # run_py_tests_async PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE SKIP
 run_py_tests_async thread 0 0 0
 run_py_tests_async thread 1 1 0
+run_py_tests_async blocking 0 0 0

 rapids-logger "Python Benchmarks"
 # run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
 run_py_benchmark ucxx-core thread 0 0 0 1 0
 run_py_benchmark ucxx-core thread 1 0 0 1 0

-for nbuf in 1 8; do
-  if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
-    # run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
-    run_py_benchmark ucxx-async thread 0 0 0 ${nbuf} 0
-    run_py_benchmark ucxx-async thread 0 0 1 ${nbuf} 0
-    run_py_benchmark ucxx-async thread 0 1 0 ${nbuf} 0
-    run_py_benchmark ucxx-async thread 0 1 1 ${nbuf} 0
-  fi
+for progress_mode in "blocking" "thread"; do
+  for nbuf in 1 8; do
+    if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
+      # run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
+      run_py_benchmark ucxx-async ${progress_mode} 0 0 0 ${nbuf} 0
+      run_py_benchmark ucxx-async ${progress_mode} 0 0 1 ${nbuf} 0
+      if [[ ${progress_mode} != "blocking" ]]; then
+        # Delayed submission isn't supported by the blocking progress mode
+        # run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
+        run_py_benchmark ucxx-async ${progress_mode} 0 1 0 ${nbuf} 0
+        run_py_benchmark ucxx-async ${progress_mode} 0 1 1 ${nbuf} 0
+      fi
+    fi
+  done
 done

 rapids-logger "C++ future -> Python future notifier example"
diff --git a/ci/test_python_distributed.sh b/ci/test_python_distributed.sh
index 9ecbf3e1..d7b7e402 100755
--- a/ci/test_python_distributed.sh
+++ b/ci/test_python_distributed.sh
@@ -37,6 +37,7 @@ print_ucx_config
 rapids-logger "Run distributed-ucxx tests with conda package"
 # run_distributed_ucxx_tests PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
+run_distributed_ucxx_tests blocking 0 0
 run_distributed_ucxx_tests polling 0 0
 run_distributed_ucxx_tests thread 0 0
 run_distributed_ucxx_tests thread 0 1
@@ -46,6 +47,7 @@ run_distributed_ucxx_tests thread 1 1
 install_distributed_dev_mode

 # run_distributed_ucxx_tests_internal PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
+run_distributed_ucxx_tests_internal blocking 0 0
 run_distributed_ucxx_tests_internal polling 0 0
 run_distributed_ucxx_tests_internal thread 0 0
 run_distributed_ucxx_tests_internal thread 0 1
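The new `blocking` entries above correspond to selecting the progress mode at `ucxx` initialization time. A minimal sketch of what those jobs exercise from Python, assuming the `UCXPY_PROGRESS_MODE` environment variable is honored when no explicit mode is given (as the `ApplicationContext.progress_mode` setter later in this diff suggests):

```python
# Sketch only: select the new blocking progress mode for the async API.
# UCXPY_PROGRESS_MODE is assumed to be read by ApplicationContext when the
# progress mode isn't passed explicitly.
import os

os.environ["UCXPY_PROGRESS_MODE"] = "blocking"

import ucxx

# The context is now configured for blocking mode; a BlockingMode progress
# task (added later in this diff) is scheduled on the running event loop.
ucxx.init()
```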
diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h
index 73eb0968..a5fb3c8d 100644
--- a/cpp/include/ucxx/worker.h
+++ b/cpp/include/ucxx/worker.h
@@ -253,6 +253,23 @@ class Worker : public Component {
    */
   void initBlockingProgressMode();

+  /**
+   * @brief Get the epoll file descriptor associated with the worker.
+   *
+   * Get the epoll file descriptor associated with the worker when running in blocking mode.
+   * The worker only has an associated epoll file descriptor after
+   * `initBlockingProgressMode()` is executed.
+   *
+   * The file descriptor is destroyed as part of the `ucxx::Worker` destructor, thus any
+   * reference to it must not be used after that.
+   *
+   * @throws std::runtime_error if `initBlockingProgressMode()` was not executed to run the
+   * worker in blocking progress mode.
+   *
+   * @returns the file descriptor.
+   */
+  int getEpollFileDescriptor();
+
   /**
    * @brief Arm the UCP worker.
    *
diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp
index e227e641..614f5521 100644
--- a/cpp/src/worker.cpp
+++ b/cpp/src/worker.cpp
@@ -220,6 +220,14 @@ void Worker::initBlockingProgressMode()
   }
 }

+int Worker::getEpollFileDescriptor()
+{
+  if (_epollFileDescriptor == 0)
+    throw std::runtime_error("Worker not running in blocking progress mode");
+
+  return _epollFileDescriptor;
+}
+
 bool Worker::arm()
 {
   ucs_status_t status = ucp_worker_arm(_handle);
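The accessor exists so callers can sleep on the worker's epoll file descriptor instead of spinning. A minimal sketch of the arm/wait protocol (per the `ucp_worker_arm()` documentation), written against the Python bindings added later in this diff; the asyncio-integrated version is `BlockingMode` below:

```python
# Illustrative blocking-progress loop; NOT the implementation used by
# BlockingMode. `worker` is assumed to be a ucxx._lib.libucxx.UCXWorker.
import select

def blocking_progress_loop(worker):
    worker.init_blocking_progress_mode()

    epoll = select.epoll()
    epoll.register(worker.epoll_file_descriptor, select.EPOLLIN)

    while True:
        worker.progress()  # drain all outstanding work first
        # arm() returns False while events are still pending, in which case
        # the worker must be progressed again before sleeping.
        if worker.arm():
            epoll.poll()  # block until UCX signals the file descriptor
```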
""" ctx = ucxx.core._get_ctx() - handle = ctx.context.handle - with _resources_lock: - if handle not in _resources: - _resources[handle] = set() - + with ctx._dask_resources_lock: resource_id = next(count) - _resources[handle].add(resource_id) + ctx._dask_resources.add(resource_id) ctx.start_notifier_thread() ctx.continuous_ucx_progress() return resource_id @@ -138,11 +126,11 @@ def register() -> int: def _deregister_dask_resource(resource_id): - """Deregister a Dask resource from the resource tracker. + """Deregister a Dask resource with the UCXX context. - Deregister a Dask resource from the resource tracker with given ID, and if - no resources remain after deregistration, stop the notifier thread and - progress tasks. + Deregister a Dask resource from the UCXX context with given ID, and if no + resources remain after deregistration, stop the notifier thread and progress + tasks. Parameters ---------- @@ -156,22 +144,40 @@ def _deregister_dask_resource(resource_id): return ctx = ucxx.core._get_ctx() - handle = ctx.context.handle # Check if the attribute exists first, in tests the UCXX context may have # been reset before some resources are deregistered. - with _resources_lock: - try: - _resources[handle].remove(resource_id) - except KeyError: - pass + if hasattr(ctx, "_dask_resources_lock"): + with ctx._dask_resources_lock: + try: + ctx._dask_resources.remove(resource_id) + except KeyError: + pass + + # Stop notifier thread and progress tasks if no Dask resources using + # UCXX communicators are running anymore. + if len(ctx._dask_resources) == 0: + ctx.stop_notifier_thread() + ctx.progress_tasks.clear() - # Stop notifier thread and progress tasks if no Dask resources using - # UCXX communicators are running anymore. - if handle in _resources and len(_resources[handle]) == 0: - ctx.stop_notifier_thread() - ctx.progress_tasks.clear() - del _resources[handle] + +def _allocate_dask_resources_tracker() -> None: + """Allocate Dask resources tracker. + + Allocate a Dask resources tracker in the UCXX context. This is useful to + track Distributed communicators so that progress and notifier threads can + be cleanly stopped when no UCXX communicators are alive anymore. + """ + ctx = ucxx.core._get_ctx() + if not hasattr(ctx, "_dask_resources"): + # TODO: Move the `Lock` to a file/module-level variable for true + # lock-safety. The approach implemented below could cause race + # conditions if this function is called simultaneously by multiple + # threads. + from threading import Lock + + ctx._dask_resources = set() + ctx._dask_resources_lock = Lock() def init_once(): @@ -181,6 +187,11 @@ def init_once(): global multi_buffer if ucxx is not None: + # Ensure reallocation of Dask resources tracker if the UCXX context was + # reset since the previous `init_once()` call. This may happen in tests, + # where the `ucxx_loop` fixture will reset the context after each test. + _allocate_dask_resources_tracker() + return # remove/process dask.ucx flags for valid ucx options @@ -243,6 +254,7 @@ def init_once(): # environment, so the user's external environment can safely # override things here. 
diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx
index 81cbfd3b..2455a781 100644
--- a/python/ucxx/ucxx/_lib/libucxx.pyx
+++ b/python/ucxx/ucxx/_lib/libucxx.pyx
@@ -617,6 +617,23 @@ cdef class UCXWorker():
         with nogil:
             self._worker.get().initBlockingProgressMode()

+    def arm(self) -> bool:
+        cdef bint armed
+
+        with nogil:
+            armed = self._worker.get().arm()
+
+        return armed
+
+    @property
+    def epoll_file_descriptor(self) -> int:
+        cdef int epoll_file_descriptor = 0
+
+        with nogil:
+            epoll_file_descriptor = self._worker.get().getEpollFileDescriptor()
+
+        return epoll_file_descriptor
+
     def progress(self) -> None:
         with nogil:
             self._worker.get().progress()
diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd
index 28edc968..9c30a4c3 100644
--- a/python/ucxx/ucxx/_lib/ucxx_api.pxd
+++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd
@@ -229,6 +229,8 @@ cdef extern from "" namespace "ucxx" nogil:
             uint16_t port, ucp_listener_conn_callback_t callback, void *callback_args
         ) except +raise_py_error
         void initBlockingProgressMode() except +raise_py_error
+        int getEpollFileDescriptor() except +raise_py_error
+        bint arm() except +raise_py_error
         void progress()
         bint progressOnce()
         void progressWorkerEvent(int epoll_timeout)
diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py
index e91b91e9..4a488309 100644
--- a/python/ucxx/ucxx/_lib_async/application_context.py
+++ b/python/ucxx/ucxx/_lib_async/application_context.py
@@ -13,7 +13,7 @@ from ucxx.exceptions import UCXMessageTruncatedError
 from ucxx.types import Tag

-from .continuous_ucx_progress import PollingMode, ThreadMode
+from .continuous_ucx_progress import BlockingMode, PollingMode, ThreadMode
 from .endpoint import Endpoint
 from .exchange_peer_info import exchange_peer_info
 from .listener import ActiveClients, Listener, _listener_handler
@@ -56,8 +56,8 @@ def __init__(
         self.context = ucx_api.UCXContext(config_dict)
         self.worker = ucx_api.UCXWorker(
             self.context,
-            enable_delayed_submission=self._enable_delayed_submission,
-            enable_python_future=self._enable_python_future,
+            enable_delayed_submission=self.enable_delayed_submission,
+            enable_python_future=self.enable_python_future,
         )

         self.start_notifier_thread()
@@ -82,12 +82,12 @@ def progress_mode(self, progress_mode):
         else:
             progress_mode = "thread"

-        valid_progress_modes = ["polling", "thread", "thread-polling"]
+        valid_progress_modes = ["blocking", "polling", "thread", "thread-polling"]
         if not isinstance(progress_mode, str) or not any(
             progress_mode == m for m in valid_progress_modes
         ):
             raise ValueError(
-                f"Unknown progress mode {progress_mode}, valid modes are: "
+                f"Unknown progress mode '{progress_mode}', valid modes are: "
                 "'blocking', 'polling', 'thread' or 'thread-polling'"
             )

@@ -121,8 +121,9 @@ def enable_delayed_submission(self, enable_delayed_submission):
             and explicit_enable_delayed_submission
         ):
             raise ValueError(
-                f"Delayed submission requested, but {self.progress_mode} does not "
-                "support it, 'thread' or 'thread-polling' progress mode required."
+                f"Delayed submission requested, but '{self.progress_mode}' does "
+                "not support it, 'thread' or 'thread-polling' progress mode "
+                "required."
             )

         self._enable_delayed_submission = explicit_enable_delayed_submission
@@ -153,7 +154,7 @@ def enable_python_future(self, enable_python_future):
             and explicit_enable_python_future
         ):
             logger.warning(
-                f"Notifier thread requested, but {self.progress_mode} does not "
+                f"Notifier thread requested, but '{self.progress_mode}' does not "
                 "support it, using Python wait_yield()."
             )
             explicit_enable_python_future = False
@@ -464,6 +465,8 @@ def continuous_ucx_progress(self, event_loop=None):
                 task = ThreadMode(self.worker, loop, polling_mode=True)
             elif self.progress_mode == "polling":
                 task = PollingMode(self.worker, loop)
+            elif self.progress_mode == "blocking":
+                task = BlockingMode(self.worker, loop)

             self.progress_tasks[loop] = task
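A quick sketch exercising the new low-level bindings directly. Context/worker construction follows the `_lib` usage in `ApplicationContext` above (the worker's keyword arguments are assumed optional), and `epoll_file_descriptor` raises via `raise_py_error` when blocking progress mode was never initialized:

```python
# Sketch of the API surface declared above in ucxx_api.pxd; illustration
# only, not a test from this PR.
from ucxx._lib.libucxx import UCXContext, UCXWorker

context = UCXContext({})
worker = UCXWorker(context)

worker.init_blocking_progress_mode()
fd = worker.epoll_file_descriptor  # would raise RuntimeError before init
worker.progress()
armed = worker.arm()  # True once no further progress is pending
```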
diff --git a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
index 9763e222..c959f2f4 100644
--- a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
+++ b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
@@ -3,6 +3,12 @@

 import asyncio
+import socket
+import time
+import weakref
+from functools import partial
+
+from ucxx._lib.libucxx import UCXWorker


 class ProgressTask(object):
@@ -24,12 +30,20 @@ def __init__(self, worker, event_loop):
         self.event_loop = event_loop
         self.asyncio_task = None

-    def __del__(self):
-        if self.asyncio_task is not None:
-            # FIXME: This does not work, the cancellation must be awaited.
-            # Running with polling mode will always cause
-            # `Task was destroyed but it is pending!` errors at ucxx.reset().
-            self.asyncio_task.cancel()
+        event_loop_close_original = self.event_loop.close
+
+        def _event_loop_close(event_loop_close_original, *args, **kwargs):
+            if not self.event_loop.is_closed() and self.asyncio_task is not None:
+                try:
+                    self.asyncio_task.cancel()
+                    self.event_loop.run_until_complete(self.asyncio_task)
+                except asyncio.exceptions.CancelledError:
+                    pass
+                finally:
+                    self.asyncio_task = None
+            event_loop_close_original(*args, **kwargs)
+
+        self.event_loop.close = partial(_event_loop_close, event_loop_close_original)

     # Hash and equality is based on the event loop
     def __hash__(self):
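Patching `loop.close` replaces the old `__del__`-based cancellation, which could never await the cancelled task. A standalone repro of the same pattern, with a hypothetical never-ending coroutine standing in for the progress task:

```python
# Minimal repro of the loop.close wrapper above: the pending task is
# cancelled *and awaited* before the loop actually closes.
import asyncio
from functools import partial

async def forever():  # stand-in for a progress task
    while True:
        await asyncio.sleep(0)

loop = asyncio.new_event_loop()
task = loop.create_task(forever())

close_original = loop.close

def _close(close_original, *args, **kwargs):
    if not loop.is_closed() and not task.done():
        task.cancel()
        try:
            loop.run_until_complete(task)  # await the cancellation
        except asyncio.CancelledError:
            pass
    close_original(*args, **kwargs)

loop.close = partial(_close, close_original)
loop.close()  # no "Task was destroyed but it is pending!" warning
```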
@@ -70,3 +84,124 @@ async def _progress_task(self):
             worker.progress()
             # Give other co-routines a chance to run.
             await asyncio.sleep(0)
+
+
+class BlockingMode(ProgressTask):
+    def __init__(
+        self,
+        worker: UCXWorker,
+        event_loop: asyncio.AbstractEventLoop,
+        progress_timeout: float = 1.0,
+    ):
+        """Progress the UCX worker in blocking mode.
+
+        The blocking progress mode ensures the worker is progressed whenever
+        the UCX worker reports an event on its epoll file descriptor. In
+        certain circumstances the epoll file descriptor may not present an
+        event, thus the `progress_timeout` ensures the UCX worker is still
+        progressed, preventing a potential deadlock.
+
+        Parameters
+        ----------
+        worker: UCXWorker
+            Worker object from the UCXX Cython API to progress.
+        event_loop: asyncio.AbstractEventLoop
+            Asynchronous event loop where to schedule async tasks.
+        progress_timeout: float
+            The timeout to sleep before checking again whether the worker
+            should be progressed.
+        """
+        super().__init__(worker, event_loop)
+        self._progress_timeout = progress_timeout
+
+        self.worker.init_blocking_progress_mode()
+
+        # Creating a job that is ready straight away but with low priority.
+        # Calling `await self.event_loop.sock_recv(self.rsock, 1)` will
+        # return when all non-IO tasks are finished.
+        # See .
+        self.rsock, wsock = socket.socketpair()
+        self.rsock.setblocking(0)
+        wsock.setblocking(0)
+        wsock.close()
+
+        epoll_fd = self.worker.epoll_file_descriptor
+
+        # Bind an asyncio reader to a UCX epoll file descriptor
+        event_loop.add_reader(epoll_fd, self._fd_reader_callback)
+
+        # Remove the reader and close socket on finalization
+        weakref.finalize(self, event_loop.remove_reader, epoll_fd)
+        weakref.finalize(self, self.rsock.close)
+
+        self.blocking_asyncio_task = None
+        self.last_progress_time = time.monotonic() - self._progress_timeout
+        self.asyncio_task = event_loop.create_task(self._progress_with_timeout())
+
+    def _fd_reader_callback(self):
+        """Schedule new progress task upon worker event.
+
+        Schedule a new progress task when a new event occurs on the worker's
+        epoll file descriptor.
+        """
+        self.worker.progress()
+
+        # Notice, we can safely overwrite `self.blocking_asyncio_task`
+        # since the previous arm task is finished by now.
+        assert self.blocking_asyncio_task is None or self.blocking_asyncio_task.done()
+        self.blocking_asyncio_task = self.event_loop.create_task(self._arm_worker())
+
+    async def _arm_worker(self):
+        """Progress the worker and rearm.
+
+        Progress and rearm the worker to watch for new events on its epoll
+        file descriptor.
+        """
+        # When arming the worker, the following must be true:
+        #  - No more progress in UCX (see doc of ucp_worker_arm())
+        #  - All asyncio tasks that aren't waiting on UCX must be executed
+        #    so that asyncio's next state is epoll wait.
+        #    See
+        while True:
+            self.last_progress_time = time.monotonic()
+            self.worker.progress()
+
+            # This IO task returns when all non-IO tasks are finished.
+            # Notice, we do NOT hold a reference to `worker` while waiting.
+            await self.event_loop.sock_recv(self.rsock, 1)
+
+            if self.worker.arm():
+                # At this point we know that asyncio's next state is
+                # epoll wait.
+                break
+
+    async def _progress_with_timeout(self):
+        """Protect worker from never progressing again.
+
+        To ensure the worker progresses when no events are raised and the
+        asyncio loop would otherwise get stuck, the worker must be progressed
+        every so often. This method progresses the worker, independently of
+        epoll file descriptor events, whenever more than
+        `self._progress_timeout` has elapsed since the last check, thus
+        preventing a deadlock.
+        """
+        while True:
+            worker = self.worker
+            if worker is None:
+                return
+            if time.monotonic() > self.last_progress_time + self._progress_timeout:
+                self.last_progress_time = time.monotonic()
+
+                # Cancel `_arm_worker` task if available. `loop.sock_recv` does
+                # not seem to respect timeout with `asyncio.wait_for`, thus we
+                # cancel it here instead. It will get recreated after a new
+                # event on `worker.epoll_file_descriptor`.
+                if self.blocking_asyncio_task is not None:
+                    self.blocking_asyncio_task.cancel()
+                    try:
+                        await self.blocking_asyncio_task
+                    except asyncio.exceptions.CancelledError:
+                        pass
+
+                worker.progress()
+            # Give other co-routines a chance to run.
+            await asyncio.sleep(self._progress_timeout)
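End-to-end, this task is normally created by `ApplicationContext.continuous_ucx_progress()` when `progress_mode="blocking"`; a hypothetical manual wiring, for illustration, looks like this:

```python
# Hypothetical manual wiring of the new progress task; worker construction
# follows the _lib usage above and is assumed to accept defaults.
import asyncio

from ucxx._lib.libucxx import UCXContext, UCXWorker
from ucxx._lib_async.continuous_ucx_progress import BlockingMode

async def main():
    context = UCXContext({})
    worker = UCXWorker(context)
    loop = asyncio.get_running_loop()
    # Registers the epoll reader and the watchdog task on this loop;
    # init_blocking_progress_mode() is called by BlockingMode itself.
    BlockingMode(worker, loop, progress_timeout=1.0)
    await asyncio.sleep(2)  # worker progresses on events or on timeout

asyncio.run(main())
```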
diff --git a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
index 35d580f8..5aebedf8 100644
--- a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
+++ b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
@@ -62,6 +62,11 @@ async def run():
             # "Endpoint timeout" after UCX_UD_TIMEOUT seconds have passed.
             # We need to keep progressing ucxx until timeout is raised.
             ep = await ucxx.create_endpoint_from_worker_address(remote_address)
+            while ep.alive:
+                await asyncio.sleep(0)
+                if not ucxx.core._get_ctx().progress_mode.startswith("thread"):
+                    ucxx.progress()
+            ep._ep.raise_on_error()
         else:
             # Create endpoint to remote worker, and:
             #
diff --git a/python/ucxx/ucxx/benchmarks/send_recv.py b/python/ucxx/ucxx/benchmarks/send_recv.py
index cd06abe7..ba07e8f8 100644
--- a/python/ucxx/ucxx/benchmarks/send_recv.py
+++ b/python/ucxx/ucxx/benchmarks/send_recv.py
@@ -294,8 +294,8 @@ def parse_args():
     parser.add_argument(
         "--progress-mode",
         default="thread",
-        help="Progress mode for the UCP worker. Valid options are: "
-        "'thread' (default) and 'blocking'.",
+        help="Progress mode for the UCP worker. Valid options are: 'blocking', "
+        "'polling', 'thread' and 'thread-polling'. (Default: 'thread')",
         type=str,
     )
     parser.add_argument(
@@ -350,8 +350,6 @@ def parse_args():
     if args.progress_mode not in ["blocking", "polling", "thread", "thread-polling"]:
         raise RuntimeError(f"Invalid `--progress-mode`: '{args.progress_mode}'")

-    if args.progress_mode == "blocking" and args.backend == "ucxx-async":
-        raise RuntimeError("Blocking progress mode not supported for ucxx-async yet")
     if args.asyncio_wait and not args.progress_mode.startswith("thread"):
         raise RuntimeError(
             "`--asyncio-wait` requires `--progress-mode=thread` or "