Merge pull request #3 from Quansight-Labs/support_unittest
Improve unittest collection support
andfoy authored Sep 2, 2024
2 parents 7e61bcd + 2b95eef commit 8aeab85
Showing 3 changed files with 163 additions and 16 deletions.
.gitignore: 1 addition, 0 deletions

@@ -50,6 +50,7 @@ htmlcov/
 nosetests.xml
 coverage.xml
 *.cover
+*.lcov
 *.py,cover
 .hypothesis/
 .pytest_cache/
src/pytest_run_parallel/plugin.py: 25 additions, 15 deletions

@@ -2,6 +2,8 @@
 import threading
 import functools
 
+from _pytest.outcomes import Skipped, Failed
+
 
 def pytest_addoption(parser):
     group = parser.getgroup('run-parallel')
@@ -22,26 +24,27 @@ def pytest_configure(config):
         'using `n` threads.')
 
 
-@pytest.hookimpl(trylast=True)
-def pytest_generate_tests(metafunc):
-    n_workers = metafunc.config.option.parallel_threads
-    m = metafunc.definition.get_closest_marker('parallel_threads')
-    if m is not None:
-        n_workers = int(m.args[0])
-    setattr(metafunc.function, '_n_workers', n_workers)
-
-
 def wrap_function_parallel(fn, n_workers=10):
     barrier = threading.Barrier(n_workers)
     @functools.wraps(fn)
     def inner(*args, **kwargs):
         errors = []
+        skip = None
+        failed = None
         def closure(*args, **kwargs):
             barrier.wait()
             try:
                 fn(*args, **kwargs)
             except Warning as w:
                 pass
             except Exception as e:
                 errors.append(e)
+            except Skipped as s:
+                nonlocal skip
+                skip = s.msg
+            except Failed as f:
+                nonlocal failed
+                failed = f
 
         workers = []
         for _ in range(0, n_workers):
@@ -56,14 +59,21 @@ def closure(*args, **kwargs):
         for worker in workers:
             worker.join()
 
-        if len(errors) > 0:
+        if skip is not None:
+            pytest.skip(skip)
+        elif failed is not None:
+            raise failed
+        elif len(errors) > 0:
             raise errors[0]
 
     return inner
 
 
-@pytest.hookimpl(wrapper=True)
-def pytest_pyfunc_call(pyfuncitem):
-    n_workers = getattr(pyfuncitem.obj, '_n_workers', None)
+@pytest.hookimpl(trylast=True)
+def pytest_itemcollected(item):
+    n_workers = item.config.option.parallel_threads
+    m = item.get_closest_marker('parallel_threads')
+    if m is not None:
+        n_workers = int(m.args[0])
     if n_workers is not None and n_workers > 1:
-        pyfuncitem.obj = wrap_function_parallel(pyfuncitem.obj, n_workers)
-        return (yield)
+        item.obj = wrap_function_parallel(item.obj, n_workers)
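
The net effect of the plugin changes: wrapping now happens once, at collection time, in pytest_itemcollected, which runs for every collected item (including unittest.TestCase methods), replacing the pytest_generate_tests/pytest_pyfunc_call pair that never reached unittest-style tests. Inside wrap_function_parallel, the threading.Barrier holds every worker until the last one arrives, so all copies of the test body start together, and because an exception raised inside a threading.Thread target never propagates to the thread calling join(), closure records skips, failures, and errors into shared state that inner replays on the main thread. A minimal sketch of how a test opts in; the module and test names here are hypothetical, not part of this commit:

    # test_demo.py (illustrative only, not from this repository)
    import pytest

    @pytest.mark.parallel_threads(4)  # overrides the --parallel-threads default
    def test_runs_in_four_threads():
        # this body executes once per worker thread, all released in lockstep
        assert sum([1, 2, 3]) == 6

With pytest --parallel-threads=10 -v, unmarked tests each run in 10 threads, while this one runs in 4.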
tests/test_run_parallel.py: 137 additions, 1 deletion

@@ -84,7 +84,7 @@ def test_check_thread_count(counter):
         @pytest.mark.order(2)
         @pytest.mark.parallel_threads(1)
-        def test_check_thread_count(counter2):
+        def test_check_thread_count2(counter2):
             assert counter2._count == 5
     """)

@@ -97,6 +97,65 @@ def test_check_thread_count(counter2):
     # fnmatch_lines does an assertion internally
     result.stdout.fnmatch_lines([
         '*::test_check_thread_count PASSED*',
+        '*::test_check_thread_count2 PASSED*',
     ])
 
     # make sure that we get a '0' exit code for the testsuite
     assert result.ret == 0
+
+
+def test_unittest_compat(pytester):
+    # create a temporary pytest test module
+    pytester.makepyfile("""
+        import pytest
+        import unittest
+        from threading import Lock
+
+        class Counter:
+            def __init__(self):
+                self._count = 0
+                self._lock = Lock()
+
+            def increase(self):
+                with self._lock:
+                    self._count += 1
+
+        class TestExample(unittest.TestCase):
+            @classmethod
+            def setUpClass(cls):
+                cls.counter = Counter()
+                cls.counter2 = Counter()
+
+            @pytest.mark.order(1)
+            def test_example_1(self):
+                self.counter.increase()
+
+            @pytest.mark.order(1)
+            @pytest.mark.parallel_threads(5)
+            def test_example_2(self):
+                self.counter2.increase()
+
+            @pytest.mark.order(2)
+            @pytest.mark.parallel_threads(1)
+            def test_check_thread_count(self):
+                assert self.counter._count == 10
+
+            @pytest.mark.order(2)
+            @pytest.mark.parallel_threads(1)
+            def test_check_thread_count2(self):
+                assert self.counter2._count == 5
+    """)
+
+    # run pytest with the following cmd args
+    result = pytester.runpytest(
+        '--parallel-threads=10',
+        '-v'
+    )
+
+    # fnmatch_lines does an assertion internally
+    result.stdout.fnmatch_lines([
+        '*::test_check_thread_count PASSED*',
+        '*::test_check_thread_count2 PASSED*',
+    ])
+
+    # make sure that we get a '0' exit code for the testsuite
@@ -114,3 +173,80 @@ def test_help_message(pytester):
         # ' Set the number of threads used to execute each test concurrently.',
     ])
+
+
+def test_skip(pytester):
+    """Make sure that pytest.skip() propagates out of the worker threads."""
+
+    # create a temporary pytest test module
+    pytester.makepyfile("""
+        import pytest
+
+        def test_skipped():
+            pytest.skip('Skip propagation')
+    """)
+
+    # run pytest with the following cmd args
+    result = pytester.runpytest(
+        '--parallel-threads=10',
+        '-v'
+    )
+
+    # fnmatch_lines does an assertion internally
+    result.stdout.fnmatch_lines([
+        '*::test_skipped SKIPPED*',
+    ])
+
+    # make sure that we get a '0' exit code for the testsuite
+    assert result.ret == 0
+
+
+def test_fail(pytester):
+    """Make sure that pytest.fail() propagates out of the worker threads."""
+
+    # create a temporary pytest test module
+    pytester.makepyfile("""
+        import pytest
+
+        def test_should_fail():
+            pytest.fail()
+    """)
+
+    # run pytest with the following cmd args
+    result = pytester.runpytest(
+        '--parallel-threads=10',
+        '-v'
+    )
+
+    # fnmatch_lines does an assertion internally
+    result.stdout.fnmatch_lines([
+        '*::test_should_fail FAILED*',
+    ])
+
+    # make sure that we get a non-zero exit code for the testsuite
+    assert result.ret != 0
+
+
+def test_exception(pytester):
+    """Make sure that exceptions raised in worker threads fail the test."""
+
+    # create a temporary pytest test module
+    pytester.makepyfile("""
+        import pytest
+
+        def test_should_fail():
+            raise ValueError('Should raise')
+    """)
+
+    # run pytest with the following cmd args
+    result = pytester.runpytest(
+        '--parallel-threads=10',
+        '-v'
+    )
+
+    # fnmatch_lines does an assertion internally
+    result.stdout.fnmatch_lines([
+        '*::test_should_fail FAILED*',
+    ])
+
+    # make sure that we get a non-zero exit code for the testsuite
+    assert result.ret != 0

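One detail the test_skip and test_fail cases pin down: pytest.skip() and pytest.fail() raise Skipped and Failed, which modern pytest derives from BaseException rather than Exception, so the wrapper's generic except Exception clause never sees them. That is why closure catches them in dedicated clauses and inner replays them on the main thread (pytest.skip(skip), raise failed) after the workers join; otherwise a skipped or failed test would die silently inside its worker thread and be reported as passing, while test_exception covers the ordinary except Exception path. A standalone sketch of the behaviour, illustrative and not part of the commit:

    # sketch: pytest outcome exceptions bypass "except Exception"
    from _pytest.outcomes import Skipped

    try:
        raise Skipped(msg='demo skip')
    except Exception:
        print('not reached: Skipped does not subclass Exception')
    except Skipped as s:
        print('caught:', s.msg)  # prints: caught: demo skip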