Skip to content

Commit

Permalink
Allow disabling clusters by returning None when overriding fixtures: …
Browse files Browse the repository at this point in the history
…celery_broker_cluster, celery_backend_cluster
  • Loading branch information
Nusnus committed Nov 28, 2023
1 parent fae39af commit a04cabc
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 29 deletions.
34 changes: 23 additions & 11 deletions src/pytest_celery/api/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ def config(cls, celery_worker_cluster_config: dict) -> dict:
if not celery_worker_cluster_config:
raise ValueError("celery_worker_cluster_config is empty")

celery_broker_cluster_config: dict = celery_worker_cluster_config["celery_broker_cluster_config"]
celery_backend_cluster_config: dict = celery_worker_cluster_config["celery_backend_cluster_config"]
return {
"broker_url": ";".join(celery_broker_cluster_config["local_urls"]),
"result_backend": ";".join(celery_backend_cluster_config["local_urls"]),
}
celery_broker_cluster_config: dict = celery_worker_cluster_config.get("celery_broker_cluster_config", {})
celery_backend_cluster_config: dict = celery_worker_cluster_config.get("celery_backend_cluster_config", {})
config = {}
if celery_broker_cluster_config:
config["broker_url"] = ";".join(celery_broker_cluster_config["local_urls"])
if celery_backend_cluster_config:
config["result_backend"] = ";".join(celery_backend_cluster_config["local_urls"])
return config

@classmethod
def update_app_config(cls, app: Celery) -> None:
Expand All @@ -80,8 +82,8 @@ def update_app_config(cls, app: Celery) -> None:

@classmethod
def create_setup_app(cls, celery_setup_config: dict, celery_setup_app_name: str) -> Celery:
if not celery_setup_config:
raise ValueError("celery_setup_config is empty")
if celery_setup_config is None:
raise ValueError("celery_setup_config is None")

if not celery_setup_app_name:
raise ValueError("celery_setup_app_name is empty")
Expand Down Expand Up @@ -110,8 +112,12 @@ def ready(self, ping: bool = False, control: bool = False, docker: bool = True)
ready = True

if docker and ready:
ready = all([self.broker_cluster.ready(), self.backend_cluster.ready()])
ready = ready and self.worker_cluster.ready()
if self.broker_cluster:
ready = ready and self.broker_cluster.ready()
if self.backend_cluster:
ready = ready and self.backend_cluster.ready()
if self.worker_cluster:
ready = ready and self.worker_cluster.ready()

if control and ready:
r = self.app.control.ping()
Expand All @@ -125,7 +131,13 @@ def ready(self, ping: bool = False, control: bool = False, docker: bool = True)
ready = ready and res.get(timeout=RESULT_TIMEOUT) == "pong"

# Set app for all nodes
nodes = self.broker_cluster.nodes + self.backend_cluster.nodes
nodes: tuple = tuple()
if self.broker_cluster:
nodes += self.broker_cluster.nodes
if self.backend_cluster:
nodes += self.backend_cluster.nodes
if self.worker_cluster:
nodes += self.worker_cluster.nodes
for node in nodes:
node._app = self.app

Expand Down
6 changes: 4 additions & 2 deletions src/pytest_celery/fixtures/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mypy: disable-error-code="misc"

from __future__ import annotations

import pytest

from pytest_celery.api.backend import CeleryBackendCluster
Expand All @@ -23,10 +25,10 @@ def celery_backend_cluster(celery_backend: CeleryTestBackend) -> CeleryBackendCl


@pytest.fixture
def celery_backend_cluster_config(request: pytest.FixtureRequest) -> dict:
def celery_backend_cluster_config(request: pytest.FixtureRequest) -> dict | None:
try:
use_default_config = pytest.fail.Exception
cluster: CeleryBackendCluster = request.getfixturevalue(CELERY_BACKEND_CLUSTER)
return cluster.config()
return cluster.config() if cluster else None
except use_default_config:
return CeleryBackendCluster.default_config()
6 changes: 4 additions & 2 deletions src/pytest_celery/fixtures/broker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mypy: disable-error-code="misc"

from __future__ import annotations

import pytest

from pytest_celery.api.broker import CeleryBrokerCluster
Expand All @@ -23,10 +25,10 @@ def celery_broker_cluster(celery_broker: CeleryTestBroker) -> CeleryBrokerCluste


@pytest.fixture
def celery_broker_cluster_config(request: pytest.FixtureRequest) -> dict:
def celery_broker_cluster_config(request: pytest.FixtureRequest) -> dict | None:
try:
use_default_config = pytest.fail.Exception
cluster: CeleryBrokerCluster = request.getfixturevalue(CELERY_BROKER_CLUSTER)
return cluster.config()
return cluster.config() if cluster else None
except use_default_config:
return CeleryBrokerCluster.default_config()
33 changes: 21 additions & 12 deletions src/pytest_celery/vendors/worker/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from typing import Union

from celery import Celery
from celery.app.base import PendingConfiguration
Expand Down Expand Up @@ -54,22 +55,30 @@ def buildargs(cls) -> dict:
}

@classmethod
def env(cls, celery_worker_cluster_config: dict) -> dict:
celery_broker_cluster_config = celery_worker_cluster_config.get("celery_broker_cluster_config")
celery_backend_cluster_config = celery_worker_cluster_config.get("celery_backend_cluster_config")
env = {}
if celery_broker_cluster_config:
env["CELERY_BROKER_URL"] = ";".join(celery_broker_cluster_config["urls"])
if celery_backend_cluster_config:
env["CELERY_RESULT_BACKEND"] = ";".join(celery_backend_cluster_config["urls"])
return {**DEFAULT_WORKER_ENV, **env}
def env(cls, celery_worker_cluster_config: dict, initial: dict | None = None) -> dict:
env = initial or {}
env = {**env, **DEFAULT_WORKER_ENV.copy()}

config_mappings = [
("celery_broker_cluster_config", "CELERY_BROKER_URL"),
("celery_backend_cluster_config", "CELERY_RESULT_BACKEND"),
]

for config_key, env_key in config_mappings:
cluster_config = celery_worker_cluster_config.get(config_key)
if cluster_config:
env[env_key] = ";".join(cluster_config["urls"])
else:
del env[env_key]

return env

@classmethod
def initial_content(
cls,
worker_tasks: set,
worker_signals: Union[set, None] = None,
worker_app: Union[Celery, None] = None,
worker_signals: set | None = None,
worker_app: Celery | None = None,
) -> dict:
from pytest_celery.vendors.worker import app as app_module

Expand Down
13 changes: 13 additions & 0 deletions tests/integration/vendors/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pytest_lazyfixture import lazy_fixture

from pytest_celery import CELERY_SETUP_WORKER
from pytest_celery import DEFAULT_WORKER_ENV
from pytest_celery import CeleryBackendCluster
from pytest_celery import CeleryTestWorker
from pytest_celery import CeleryWorkerContainer
from tests.defaults import ALL_WORKERS_FIXTURES
Expand All @@ -17,6 +19,17 @@ def test_celeryconfig(self, container: CeleryWorkerContainer):
with pytest.raises(NotImplementedError):
container.celeryconfig

class test_disabling_cluster:
@pytest.fixture
def celery_backend_cluster(self) -> CeleryBackendCluster:
return None

def test_disabling_backend_cluster(self, container: CeleryWorkerContainer):
assert container.logs().count("results: disabled://")

results = DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]
assert container.logs().count(f"transport: {results}")


@pytest.mark.parametrize("worker", [lazy_fixture(CELERY_SETUP_WORKER)])
class test_base_test_worker:
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/api/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ class test_celery_backend_cluster:
def test_default_config_format(self, cluster: CeleryBackendCluster):
assert cluster.default_config()["urls"] == [DEFAULT_WORKER_ENV["CELERY_RESULT_BACKEND"]]
assert cluster.default_config()["local_urls"] == [DEFAULT_WORKER_ENV["CELERY_RESULT_BACKEND"]]

class test_disabling_cluster:
@pytest.fixture
def celery_backend_cluster(self) -> CeleryBackendCluster:
return None

def test_disabling_backend_cluster(self, cluster: CeleryBackendCluster, celery_backend_cluster_config: dict):
assert cluster is None
assert celery_backend_cluster_config is None
9 changes: 9 additions & 0 deletions tests/unit/api/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ class test_celery_broker_cluster:
def test_default_config_format(self, cluster: CeleryBrokerCluster):
assert cluster.default_config()["urls"] == [DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]]
assert cluster.default_config()["local_urls"] == [DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]]

class test_disabling_cluster:
@pytest.fixture
def celery_broker_cluster(self) -> CeleryBrokerCluster:
return None

def test_disabling_broker_cluster(self, cluster: CeleryBrokerCluster, celery_broker_cluster_config: dict):
assert cluster is None
assert celery_broker_cluster_config is None
24 changes: 22 additions & 2 deletions tests/unit/vendors/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import inspect

import pytest

from pytest_celery import DEFAULT_WORKER_ENV
from pytest_celery import DEFAULT_WORKER_LOG_LEVEL
from pytest_celery import DEFAULT_WORKER_NAME
from pytest_celery import DEFAULT_WORKER_QUEUE
from pytest_celery import DEFAULT_WORKER_VERSION
from pytest_celery import CeleryBackendCluster
from pytest_celery import CeleryBrokerCluster
from pytest_celery import CeleryWorkerContainer


Expand Down Expand Up @@ -35,8 +39,24 @@ def test_buildargs(self):
"CELERY_WORKER_QUEUE": DEFAULT_WORKER_QUEUE,
}

def test_env(self, celery_worker_cluster_config: dict):
assert CeleryWorkerContainer.env(celery_worker_cluster_config) == DEFAULT_WORKER_ENV
class test_celery_worker_container_env:
def test_env(self, celery_worker_cluster_config: dict):
assert CeleryWorkerContainer.env(celery_worker_cluster_config) == DEFAULT_WORKER_ENV

class test_disabling_cluster:
@pytest.fixture
def celery_backend_cluster(self) -> CeleryBackendCluster:
return None

@pytest.fixture
def celery_broker_cluster(self) -> CeleryBrokerCluster:
return None

def test_disabling_clusters(self, celery_worker_cluster_config: dict):
expected_env = DEFAULT_WORKER_ENV.copy()
expected_env.pop("CELERY_BROKER_URL")
expected_env.pop("CELERY_RESULT_BACKEND")
assert CeleryWorkerContainer.env(celery_worker_cluster_config) == expected_env

def test_initial_content_default_tasks(self):
from tests import tasks
Expand Down

0 comments on commit a04cabc

Please sign in to comment.