Skip to content

Commit

Permalink
Add blocking progress mode to Python async (#116)
Browse files Browse the repository at this point in the history
Implements the blocking progress mode (UCX-Py default), which was still not implemented in UCXX.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Ray Douglass (https://github.com/raydouglass)

URL: #116
  • Loading branch information
pentschev authored Oct 22, 2024
1 parent cfe9008 commit 122d2f4
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 62 deletions.
23 changes: 15 additions & 8 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,27 @@ rapids-logger "Python Async Tests"
# run_py_tests_async PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE SKIP
run_py_tests_async thread 0 0 0
run_py_tests_async thread 1 1 0
run_py_tests_async blocking 0 0 0

rapids-logger "Python Benchmarks"
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-core thread 0 0 0 1 0
run_py_benchmark ucxx-core thread 1 0 0 1 0

for nbuf in 1 8; do
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async thread 0 0 0 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 0 1 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 1 0 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 1 1 ${nbuf} 0
fi
for progress_mode in "blocking" "thread"; do
for nbuf in 1 8; do
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async ${progress_mode} 0 0 0 ${nbuf} 0
run_py_benchmark ucxx-async ${progress_mode} 0 0 1 ${nbuf} 0
if [[ ${progress_mode} != "blocking" ]]; then
# Delayed submission isn't support by blocking progress mode
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async ${progress_mode} 0 1 0 ${nbuf} 0
run_py_benchmark ucxx-async ${progress_mode} 0 1 1 ${nbuf} 0
fi
fi
done
done

rapids-logger "C++ future -> Python future notifier example"
Expand Down
2 changes: 2 additions & 0 deletions ci/test_python_distributed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ print_ucx_config

rapids-logger "Run distributed-ucxx tests with conda package"
# run_distributed_ucxx_tests PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
run_distributed_ucxx_tests blocking 0 0
run_distributed_ucxx_tests polling 0 0
run_distributed_ucxx_tests thread 0 0
run_distributed_ucxx_tests thread 0 1
Expand All @@ -46,6 +47,7 @@ run_distributed_ucxx_tests thread 1 1
install_distributed_dev_mode

# run_distributed_ucxx_tests_internal PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
run_distributed_ucxx_tests_internal blocking 0 0
run_distributed_ucxx_tests_internal polling 0 0
run_distributed_ucxx_tests_internal thread 0 0
run_distributed_ucxx_tests_internal thread 0 1
Expand Down
17 changes: 17 additions & 0 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,23 @@ class Worker : public Component {
*/
void initBlockingProgressMode();

/**
* @brief Get the epoll file descriptor associated with the worker.
*
* Get the epoll file descriptor associated with the worker when running in blocking mode.
* The worker only has an associated epoll file descriptor after
* `initBlockingProgressMode()` is executed.
*
* The file descriptor is destroyed as part of the `ucxx::Worker` destructor, thus any
* reference to it shall not be used after that.
*
* @throws std::runtime_error if `initBlockingProgressMode()` was not executed to run the
* worker in blocking progress mode.
*
* @returns the file descriptor.
*/
int getEpollFileDescriptor();

/**
* @brief Arm the UCP worker.
*
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ void Worker::initBlockingProgressMode()
}
}

int Worker::getEpollFileDescriptor()
{
if (_epollFileDescriptor == 0)
throw std::runtime_error("Worker not running in blocking progress mode");

return _epollFileDescriptor;
}

bool Worker::arm()
{
ucs_status_t status = ucp_worker_arm(_handle);
Expand Down
3 changes: 3 additions & 0 deletions python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,10 @@ async def test_comm_closed_on_read_error():
with pytest.raises((asyncio.TimeoutError, CommClosedError)):
await wait_for(reader.read(), 0.01)

await writer.close()

assert reader.closed()
assert writer.closed()


@pytest.mark.flaky(
Expand Down
84 changes: 48 additions & 36 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import struct
import weakref
from collections.abc import Awaitable, Callable, Collection
from threading import Lock
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

Expand Down Expand Up @@ -50,13 +49,6 @@
pre_existing_cuda_context = False
cuda_context_created = False
multi_buffer = None
# Lock protecting access to _resources dict
_resources_lock = Lock()
# Mapping from UCXX context handles to sets of registered dask resource IDs
# Used to track when there are no more users of the context, at which point
# its progress task and notification thread can be shut down.
# See _register_dask_resource and _deregister_dask_resource.
_resources = dict()


_warning_suffix = (
Expand Down Expand Up @@ -103,13 +95,13 @@ def make_register():
count = itertools.count()

def register() -> int:
"""Register a Dask resource with the resource tracker.
"""Register a Dask resource with the UCXX context.
Generate a unique ID for the resource and register it with the resource
tracker. The resource ID is later used to deregister the resource from
the tracker calling `_deregister_dask_resource(resource_id)`, which
stops the notifier thread and progress tasks when no more UCXX resources
are alive.
Register a Dask resource with the UCXX context and keep track of it with the
use of a unique ID for the resource. The resource ID is later used to
deregister the resource from the UCXX context calling
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
and progress tasks when no more UCXX resources are alive.
Returns
-------
Expand All @@ -118,13 +110,9 @@ def register() -> int:
`_deregister_dask_resource` during stop/destruction of the resource.
"""
ctx = ucxx.core._get_ctx()
handle = ctx.context.handle
with _resources_lock:
if handle not in _resources:
_resources[handle] = set()

with ctx._dask_resources_lock:
resource_id = next(count)
_resources[handle].add(resource_id)
ctx._dask_resources.add(resource_id)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
return resource_id
Expand All @@ -138,11 +126,11 @@ def register() -> int:


def _deregister_dask_resource(resource_id):
"""Deregister a Dask resource from the resource tracker.
"""Deregister a Dask resource with the UCXX context.
Deregister a Dask resource from the resource tracker with given ID, and if
no resources remain after deregistration, stop the notifier thread and
progress tasks.
Deregister a Dask resource from the UCXX context with given ID, and if no
resources remain after deregistration, stop the notifier thread and progress
tasks.
Parameters
----------
Expand All @@ -156,22 +144,40 @@ def _deregister_dask_resource(resource_id):
return

ctx = ucxx.core._get_ctx()
handle = ctx.context.handle

# Check if the attribute exists first, in tests the UCXX context may have
# been reset before some resources are deregistered.
with _resources_lock:
try:
_resources[handle].remove(resource_id)
except KeyError:
pass
if hasattr(ctx, "_dask_resources_lock"):
with ctx._dask_resources_lock:
try:
ctx._dask_resources.remove(resource_id)
except KeyError:
pass

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if handle in _resources and len(_resources[handle]) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()
del _resources[handle]

def _allocate_dask_resources_tracker() -> None:
"""Allocate Dask resources tracker.
Allocate a Dask resources tracker in the UCXX context. This is useful to
track Distributed communicators so that progress and notifier threads can
be cleanly stopped when no UCXX communicators are alive anymore.
"""
ctx = ucxx.core._get_ctx()
if not hasattr(ctx, "_dask_resources"):
# TODO: Move the `Lock` to a file/module-level variable for true
# lock-safety. The approach implemented below could cause race
# conditions if this function is called simultaneously by multiple
# threads.
from threading import Lock

ctx._dask_resources = set()
ctx._dask_resources_lock = Lock()


def init_once():
Expand All @@ -181,6 +187,11 @@ def init_once():
global multi_buffer

if ucxx is not None:
# Ensure reallocation of Dask resources tracker if the UCXX context was
# reset since the previous `init_once()` call. This may happen in tests,
# where the `ucxx_loop` fixture will reset the context after each test.
_allocate_dask_resources_tracker()

return

# remove/process dask.ucx flags for valid ucx options
Expand Down Expand Up @@ -243,6 +254,7 @@ def init_once():
# environment, so the user's external environment can safely
# override things here.
ucxx.init(options=ucx_config, env_takes_precedence=True)
_allocate_dask_resources_tracker()

pool_size_str = dask.config.get("distributed.rmm.pool-size")

Expand Down
17 changes: 17 additions & 0 deletions python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,23 @@ cdef class UCXWorker():
with nogil:
self._worker.get().initBlockingProgressMode()

def arm(self) -> bool:
cdef bint armed

with nogil:
armed = self._worker.get().arm()

return armed

@property
def epoll_file_descriptor(self) -> int:
cdef int epoll_file_descriptor = 0

with nogil:
epoll_file_descriptor = self._worker.get().getEpollFileDescriptor()

return epoll_file_descriptor

def progress(self) -> None:
with nogil:
self._worker.get().progress()
Expand Down
2 changes: 2 additions & 0 deletions python/ucxx/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
uint16_t port, ucp_listener_conn_callback_t callback, void *callback_args
) except +raise_py_error
void initBlockingProgressMode() except +raise_py_error
int getEpollFileDescriptor()
bint arm() except +raise_py_error
void progress()
bint progressOnce()
void progressWorkerEvent(int epoll_timeout)
Expand Down
19 changes: 11 additions & 8 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ucxx.exceptions import UCXMessageTruncatedError
from ucxx.types import Tag

from .continuous_ucx_progress import PollingMode, ThreadMode
from .continuous_ucx_progress import BlockingMode, PollingMode, ThreadMode
from .endpoint import Endpoint
from .exchange_peer_info import exchange_peer_info
from .listener import ActiveClients, Listener, _listener_handler
Expand Down Expand Up @@ -56,8 +56,8 @@ def __init__(
self.context = ucx_api.UCXContext(config_dict)
self.worker = ucx_api.UCXWorker(
self.context,
enable_delayed_submission=self._enable_delayed_submission,
enable_python_future=self._enable_python_future,
enable_delayed_submission=self.enable_delayed_submission,
enable_python_future=self.enable_python_future,
)

self.start_notifier_thread()
Expand All @@ -82,12 +82,12 @@ def progress_mode(self, progress_mode):
else:
progress_mode = "thread"

valid_progress_modes = ["polling", "thread", "thread-polling"]
valid_progress_modes = ["blocking", "polling", "thread", "thread-polling"]
if not isinstance(progress_mode, str) or not any(
progress_mode == m for m in valid_progress_modes
):
raise ValueError(
f"Unknown progress mode {progress_mode}, valid modes are: "
f"Unknown progress mode '{progress_mode}', valid modes are: "
"'blocking', 'polling', 'thread' or 'thread-polling'"
)

Expand Down Expand Up @@ -121,8 +121,9 @@ def enable_delayed_submission(self, enable_delayed_submission):
and explicit_enable_delayed_submission
):
raise ValueError(
f"Delayed submission requested, but {self.progress_mode} does not "
"support it, 'thread' or 'thread-polling' progress mode required."
f"Delayed submission requested, but '{self.progress_mode}' does "
"not support it, 'thread' or 'thread-polling' progress mode "
"required."
)

self._enable_delayed_submission = explicit_enable_delayed_submission
Expand Down Expand Up @@ -153,7 +154,7 @@ def enable_python_future(self, enable_python_future):
and explicit_enable_python_future
):
logger.warning(
f"Notifier thread requested, but {self.progress_mode} does not "
f"Notifier thread requested, but '{self.progress_mode}' does not "
"support it, using Python wait_yield()."
)
explicit_enable_python_future = False
Expand Down Expand Up @@ -464,6 +465,8 @@ def continuous_ucx_progress(self, event_loop=None):
task = ThreadMode(self.worker, loop, polling_mode=True)
elif self.progress_mode == "polling":
task = PollingMode(self.worker, loop)
elif self.progress_mode == "blocking":
task = BlockingMode(self.worker, loop)

self.progress_tasks[loop] = task

Expand Down
Loading

0 comments on commit 122d2f4

Please sign in to comment.