Skip to content

Commit

Permalink
chore: test python/cpython#124847
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Oct 4, 2024
1 parent 5a3b4cb commit 8f656c9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 151 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.8.0"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"
Expand Down
221 changes: 80 additions & 141 deletions src/aiohappyeyeballs/_staggered.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,13 @@
import asyncio
import contextlib
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

_T = TypeVar("_T")


def _set_result(wait_next: "asyncio.Future[None]") -> None:
"""Set the result of a future if it is not already done."""
if not wait_next.done():
wait_next.set_result(None)


async def _wait_one(
futures: "Iterable[asyncio.Future[Any]]",
loop: asyncio.AbstractEventLoop,
) -> _T:
"""Wait for the first future to complete."""
wait_next = loop.create_future()

def _on_completion(fut: "asyncio.Future[Any]") -> None:
if not wait_next.done():
wait_next.set_result(fut)

for f in futures:
f.add_done_callback(_on_completion)
"""Support for running coroutines in parallel with staggered start times."""

try:
return await wait_next
finally:
for f in futures:
f.remove_done_callback(_on_completion)
__all__ = ("staggered_race",)

import contextlib
from asyncio import events, locks, tasks
from asyncio import exceptions as exceptions_mod


async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
delay: Optional[float],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
async def staggered_race(coro_fns, delay, *, loop=None):
"""
Run coroutines with staggered start times and take the first to finish.
Expand All @@ -75,18 +33,16 @@ async def staggered_race(
raise
Args:
----
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
loop: the event loop to use. If ``None``, the running loop is used.
loop: the event loop to use.
Returns:
-------
tuple *(winner_result, winner_index, exceptions)* where
- *winner_result*: the result of the winning coroutine, or ``None``
Expand All @@ -103,100 +59,83 @@ async def staggered_race(
coroutine's entry is ``None``.
"""
loop = loop or asyncio.get_running_loop()
exceptions: List[Optional[BaseException]] = []
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()

async def run_one_coro(
coro_fn: Callable[[], Awaitable[_T]],
this_index: int,
start_next: "asyncio.Future[None]",
) -> Optional[Tuple[_T, int]]:
"""
Run a single coroutine.
If the coroutine fails, set the exception in the exceptions list and
start the next coroutine by setting the result of the start_next.
If the coroutine succeeds, return the result and the index of the
coroutine in the coro_fns list.
If SystemExit or KeyboardInterrupt is raised, re-raise it.
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
loop = loop or events.get_running_loop()
enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
running_tasks = []

async def run_one_coro(ok_to_start, previous_failed) -> None:
await ok_to_start.wait()
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
# Use asyncio.wait_for() instead of asyncio.wait() here, so
# that if we get cancelled at this point, Event.wait() is also
# cancelled, otherwise there will be a "Task destroyed but it is
# pending" later.
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_ok_to_start = locks.Event()
next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
running_tasks.append(next_task)
next_ok_to_start.set()
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
assert len(exceptions) == this_index + 1

try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
_set_result(start_next) # Kickstart the next coroutine
return None

return result, this_index

start_next_timer: Optional[asyncio.TimerHandle] = None
start_next: Optional[asyncio.Future[None]]
task: asyncio.Task[Optional[Tuple[_T, int]]]
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
coro_iter = iter(coro_fns)
this_index = -1
this_failed.set() # Kickstart the next coroutine
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None
winner_index = this_index
winner_result = result
# Cancel all other tasks. We take care to not cancel the current
# task as well. If we do so, then since there is no `await` after
# here and CancelledError are usually thrown at one, we will
# encounter a curious corner case where the current task will end
# up as done() == True, cancelled() == False, exception() ==
# asyncio.CancelledError. This behavior is specified in
# https://bugs.python.org/issue30048
for i, t in enumerate(running_tasks):
if i != this_index:
t.cancel()

ok_to_start = locks.Event()
first_task = loop.create_task(run_one_coro(ok_to_start, None))
running_tasks.append(first_task)
ok_to_start.set()
try:
while True:
if coro_fn := next(coro_iter, None):
this_index += 1
exceptions.append(None)
start_next = loop.create_future()
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
tasks.add(task)
start_next_timer = (
loop.call_later(delay, _set_result, start_next) if delay else None
)
elif not tasks:
# We exhausted the coro_fns list and no tasks are running
# so we have no winner and all coroutines failed.
break

while tasks:
done = await _wait_one(
[*tasks, start_next] if start_next else tasks, loop
)
if done is start_next:
# The current task has failed or the timer has expired
# so we need to start the next task.
start_next = None
if start_next_timer:
start_next_timer.cancel()
start_next_timer = None

# Break out of the task waiting loop to start the next
# task.
break

if TYPE_CHECKING:
assert isinstance(done, asyncio.Task)

tasks.remove(done)
if winner := done.result():
return *winner, exceptions
# Wait for a growing list of tasks to all finish: poor man's version of
# curio's TaskGroup or trio's nursery
done_count = 0
while done_count != len(running_tasks):
done, _ = await tasks.wait(running_tasks)
done_count = len(done)
# If run_one_coro raises an unhandled exception, it's probably a
# programming error, and I want to see it.
if __debug__:
for d in done:
if d.done() and not d.cancelled() and d.exception():
raise d.exception()

Check warning on line 136 in src/aiohappyeyeballs/_staggered.py

View check run for this annotation

Codecov / codecov/patch

src/aiohappyeyeballs/_staggered.py#L136

Added line #L136 was not covered by tests
return winner_result, winner_index, exceptions
finally:
# We either have:
# - a winner
# - all tasks failed
# - a KeyboardInterrupt or SystemExit.

#
# If the timer is still running, cancel it.
#
if start_next_timer:
start_next_timer.cancel()

#
# If there are any tasks left, cancel them and than
# wait them so they fill the exceptions list.
#
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task

return None, None, exceptions
# Make sure no tasks are left running if we leave this function
for t in running_tasks:
t.cancel()
6 changes: 0 additions & 6 deletions tests/test_staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,8 @@ async def coro(idx):
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]


@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher")
Expand All @@ -77,10 +74,7 @@ async def coro(idx):
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]

loop.run_until_complete(run())

0 comments on commit 8f656c9

Please sign in to comment.