diff --git a/multidb/pinning.py b/multidb/pinning.py index aa529c2..8f554e9 100644 --- a/multidb/pinning.py +++ b/multidb/pinning.py @@ -34,7 +34,7 @@ def unpin_this_thread(): class UsePrimaryDB(object): - """A contextmanager/decorator to use the master database.""" + """A contextmanager/decorator to use the primary database.""" def __call__(self, func): @wraps(func) def decorator(*args, **kw): @@ -43,11 +43,12 @@ def decorator(*args, **kw): return decorator def __enter__(self): - _locals.old = this_thread_is_pinned() + _locals.old = getattr(_locals, 'old', []) + _locals.old.append(this_thread_is_pinned()) pin_this_thread() def __exit__(self, type, value, tb): - if not _locals.old: + if not _locals.old.pop(): unpin_this_thread() diff --git a/multidb/tests.py b/multidb/tests.py index 125cff4..136e63a 100644 --- a/multidb/tests.py +++ b/multidb/tests.py @@ -202,6 +202,22 @@ def check(): check() assert not this_thread_is_pinned() + def test_decorator_nested(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + unpin_this_thread() + assert not this_thread_is_pinned() + check_outer() + assert not this_thread_is_pinned() + def test_decorator_resets(self): @use_primary_db def check(): @@ -211,6 +227,22 @@ def check(): check() assert this_thread_is_pinned() + def test_decorator_resets_nested(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + pin_this_thread() + assert this_thread_is_pinned() + check_outer() + assert this_thread_is_pinned() + def test_context_manager(self): unpin_this_thread() assert not this_thread_is_pinned() @@ -218,6 +250,16 @@ def test_context_manager(self): assert this_thread_is_pinned() assert not this_thread_is_pinned() + def test_context_manager_nested(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert this_thread_is_pinned() + assert not this_thread_is_pinned() + def test_context_manager_resets(self): pin_this_thread() assert this_thread_is_pinned() @@ -225,6 +267,16 @@ def test_context_manager_resets(self): assert this_thread_is_pinned() assert this_thread_is_pinned() + def test_context_manager_resets_nested(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert this_thread_is_pinned() + assert this_thread_is_pinned() + def test_context_manager_exception(self): unpin_this_thread() assert not this_thread_is_pinned()