Skip to content

Commit

Permalink
Merge pull request #4 from Quansight-Labs/add_thread_comparator
Browse files Browse the repository at this point in the history
Add ThreadComparator class fixture to compare values across threads
  • Loading branch information
andfoy authored Sep 30, 2024
2 parents 8aeab85 + d41fadf commit 2581720
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 0 deletions.
27 changes: 27 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,33 @@ Both modes of operations are supported simultaneously, i.e.,
# threads; other tests will be run using 5 threads.
pytest -x -v --parallel-threads=5 test_file.py
Additionally, ``pytest-run-parallel`` exposes the ``num_parallel_threads`` fixture
which enable a test to be aware of the number of threads that are being spawned:

.. code-block:: python
# test_file.py
import pytest
def test_skip_if_parallel(num_parallel_threads):
if num_parallel_threads > 1:
pytest.skip(reason='does not work in parallel')
...
Finally, the ``thread_comp`` fixture allows for parallel test debugging, by providing an
instance of ``ThreadComparator``, whose ``__call__`` method allows to check if all the values
produced by all threads during an specific execution step are the same:

.. code-block:: python
# test_file.py
def test_same_execution_values(thread_comp):
a = 2
b = [3, 4, 5]
c = None
# Check that the values for a, b, c are the same across tests
thread_comp(a=a, b=b, c=c)
Contributing
------------
Contributions are very welcome. Tests can be run with `tox`_, please ensure
Expand Down
88 changes: 88 additions & 0 deletions src/pytest_run_parallel/plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import pytest
import threading
import functools
import types

from _pytest.outcomes import Skipped, Failed

try:
import numpy as np
numpy_available = True
except ImportError:
numpy_available = False


def pytest_addoption(parser):
group = parser.getgroup('run-parallel')
Expand Down Expand Up @@ -77,3 +84,84 @@ def pytest_itemcollected(item):
n_workers = int(m.args[0])
if n_workers is not None and n_workers > 1:
item.obj = wrap_function_parallel(item.obj, n_workers)


@pytest.fixture
def num_parallel_threads(request):
node = request.node
n_workers = request.config.option.parallel_threads
m = node.get_closest_marker('parallel_threads')
if m is not None:
n_workers = int(m.args[0])
return n_workers


class ThreadComparator:

def __init__(self, n_threads):
self._barrier = threading.Barrier(n_threads)
self._reset_evt = threading.Event()
self._entry_barrier = threading.Barrier(n_threads)

self._thread_ids = []
self._values = {}
self._entry_lock = threading.Lock()
self._entry_counter = 0

def __call__(self, **values):
"""
Compares a set of values across threads.
For each value, type equality as well as comparison takes place. If any
of the values is a function, then address comparison is performed.
Also, if any of the values is a `numpy.ndarray`, then approximate
numerical comparison is performed.
"""
tid = id(threading.current_thread())
self._entry_barrier.wait()
with self._entry_lock:
if self._entry_counter == 0:
# Reset state before comparison
self._barrier.reset()
self._reset_evt.clear()
self._thread_ids = []
self._values = {}
self._entry_barrier.reset()
self._entry_counter += 1

self._values[tid] = values
self._thread_ids.append(tid)
self._barrier.wait()

if tid == self._thread_ids[0]:
thread_ids = list(self._values)
try:
for value_name in values:
for i in range(1, len(thread_ids)):
tid_a = thread_ids[i - 1]
tid_b = thread_ids[i]
value_a = self._values[tid_a][value_name]
value_b = self._values[tid_b][value_name]
assert type(value_a) is type(value_b)
if numpy_available and isinstance(value_a, np.ndarray):
if len(value_a.shape) == 0:
assert value_a == value_b
else:
assert np.allclose(
value_a, value_b, equal_nan=True)
elif isinstance(value_a, types.FunctionType):
assert id(value_a) == id(value_b)
elif value_a != value_a:
assert value_b != value_b
else:
assert value_a == value_b
finally:
self._entry_counter = 0
self._reset_evt.set()
else:
self._reset_evt.wait()


@pytest.fixture
def thread_comp(num_parallel_threads):
return ThreadComparator(num_parallel_threads)
85 changes: 85 additions & 0 deletions tests/test_run_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,88 @@ def test_should_fail():

# make sure that we get a '0' exit code for the testsuite
assert result.ret != 0


def test_num_parallel_threads_fixture(pytester):
"""Test that the num_parallel_threads fixture works as expected."""

# create a temporary pytest test module
pytester.makepyfile("""
import pytest
def test_should_yield_global_threads(num_parallel_threads):
assert num_parallel_threads == 10
@pytest.mark.parallel_threads(2)
def test_should_yield_marker_threads(num_parallel_threads):
assert num_parallel_threads == 2
""")

# run pytest with the following cmd args
result = pytester.runpytest(
'--parallel-threads=10',
'-v'
)

# fnmatch_lines does an assertion internally
result.stdout.fnmatch_lines([
'*::test_should_yield_global_threads PASSED*',
'*::test_should_yield_marker_threads PASSED*'
])


def test_thread_comp_fixture(pytester):
"""Test that ThreadComparator works as expected."""

# create a temporary pytest test module
pytester.makepyfile("""
import threading
import pytest
class Counter:
def __init__(self):
self._value = 0
self._lock = threading.Lock()
def get_value_and_increment(self):
with self._lock:
value = int(self._value)
self._value += 1
return value
def test_value_comparison(num_parallel_threads, thread_comp):
assert num_parallel_threads == 10
a = 1
b = [2, 'string', 1.0]
c = {'a': -4, 'b': 'str'}
d = float('nan')
e = float('inf')
f = {'a', 'b', '#'}
thread_comp(a=a, b=b, c=c, d=d, e=e, f=f)
# Ensure that the comparator can be used again
thread_comp(g=4)
@pytest.fixture
def counter(num_parallel_threads):
return Counter()
def test_comparison_fail(thread_comp, counter):
a = 4
pos = counter.get_value_and_increment()
if pos % 2 == 0:
a = -1
thread_comp(a=a)
""")

# run pytest with the following cmd args
result = pytester.runpytest(
'--parallel-threads=10',
'-v'
)

# fnmatch_lines does an assertion internally
result.stdout.fnmatch_lines([
'*::test_value_comparison PASSED*',
'*::test_comparison_fail FAILED*'
])

0 comments on commit 2581720

Please sign in to comment.