From 088de86f45d6bce58d4306a0aa80601744afff74 Mon Sep 17 00:00:00 2001 From: Tomer Nosrati Date: Tue, 28 Nov 2023 19:33:21 +0200 Subject: [PATCH] Allow disabling clusters by returning None when overriding fixtures: celery_broker_cluster, celery_backend_cluster --- src/pytest_celery/api/setup.py | 30 ++++++++++++----- src/pytest_celery/fixtures/backend.py | 6 ++-- src/pytest_celery/fixtures/broker.py | 6 ++-- src/pytest_celery/vendors/worker/container.py | 33 ++++++++++++------- tests/unit/api/test_backend.py | 9 +++++ tests/unit/api/test_broker.py | 9 +++++ tests/unit/vendors/test_worker.py | 24 ++++++++++++-- 7 files changed, 90 insertions(+), 27 deletions(-) diff --git a/src/pytest_celery/api/setup.py b/src/pytest_celery/api/setup.py index 855bae05e..8de969b16 100644 --- a/src/pytest_celery/api/setup.py +++ b/src/pytest_celery/api/setup.py @@ -66,12 +66,14 @@ def config(cls, celery_worker_cluster_config: dict) -> dict: if not celery_worker_cluster_config: raise ValueError("celery_worker_cluster_config is empty") - celery_broker_cluster_config: dict = celery_worker_cluster_config["celery_broker_cluster_config"] - celery_backend_cluster_config: dict = celery_worker_cluster_config["celery_backend_cluster_config"] - return { - "broker_url": ";".join(celery_broker_cluster_config["local_urls"]), - "result_backend": ";".join(celery_backend_cluster_config["local_urls"]), - } + celery_broker_cluster_config: dict = celery_worker_cluster_config.get("celery_broker_cluster_config", {}) + celery_backend_cluster_config: dict = celery_worker_cluster_config.get("celery_backend_cluster_config", {}) + config = {} + if celery_broker_cluster_config: + config["broker_url"] = ";".join(celery_broker_cluster_config["local_urls"]) + if celery_backend_cluster_config: + config["result_backend"] = ";".join(celery_backend_cluster_config["local_urls"]) + return config @classmethod def update_app_config(cls, app: Celery) -> None: @@ -110,8 +112,12 @@ def ready(self, ping: bool = False, control: bool = False, docker: bool = True) ready = True if docker and ready: - ready = all([self.broker_cluster.ready(), self.backend_cluster.ready()]) - ready = ready and self.worker_cluster.ready() + if self.broker_cluster: + ready = ready and self.broker_cluster.ready() + if self.backend_cluster: + ready = ready and self.backend_cluster.ready() + if self.worker_cluster: + ready = ready and self.worker_cluster.ready() if control and ready: r = self.app.control.ping() @@ -125,7 +131,13 @@ def ready(self, ping: bool = False, control: bool = False, docker: bool = True) ready = ready and res.get(timeout=RESULT_TIMEOUT) == "pong" # Set app for all nodes - nodes = self.broker_cluster.nodes + self.backend_cluster.nodes + nodes: tuple = tuple() + if self.broker_cluster: + nodes += self.broker_cluster.nodes + if self.backend_cluster: + nodes += self.backend_cluster.nodes + if self.worker_cluster: + nodes += self.worker_cluster.nodes for node in nodes: node._app = self.app diff --git a/src/pytest_celery/fixtures/backend.py b/src/pytest_celery/fixtures/backend.py index d6984b916..72ffc542f 100644 --- a/src/pytest_celery/fixtures/backend.py +++ b/src/pytest_celery/fixtures/backend.py @@ -1,5 +1,7 @@ # mypy: disable-error-code="misc" +from __future__ import annotations + import pytest from pytest_celery.api.backend import CeleryBackendCluster @@ -23,10 +25,10 @@ def celery_backend_cluster(celery_backend: CeleryTestBackend) -> CeleryBackendCl @pytest.fixture -def celery_backend_cluster_config(request: pytest.FixtureRequest) -> dict: +def celery_backend_cluster_config(request: pytest.FixtureRequest) -> dict | None: try: use_default_config = pytest.fail.Exception cluster: CeleryBackendCluster = request.getfixturevalue(CELERY_BACKEND_CLUSTER) - return cluster.config() + return cluster.config() if cluster else None except use_default_config: return CeleryBackendCluster.default_config() diff --git a/src/pytest_celery/fixtures/broker.py b/src/pytest_celery/fixtures/broker.py index 86710abf0..5ff036df4 100644 --- a/src/pytest_celery/fixtures/broker.py +++ b/src/pytest_celery/fixtures/broker.py @@ -1,5 +1,7 @@ # mypy: disable-error-code="misc" +from __future__ import annotations + import pytest from pytest_celery.api.broker import CeleryBrokerCluster @@ -23,10 +25,10 @@ def celery_broker_cluster(celery_broker: CeleryTestBroker) -> CeleryBrokerCluste @pytest.fixture -def celery_broker_cluster_config(request: pytest.FixtureRequest) -> dict: +def celery_broker_cluster_config(request: pytest.FixtureRequest) -> dict | None: try: use_default_config = pytest.fail.Exception cluster: CeleryBrokerCluster = request.getfixturevalue(CELERY_BROKER_CLUSTER) - return cluster.config() + return cluster.config() if cluster else None except use_default_config: return CeleryBrokerCluster.default_config() diff --git a/src/pytest_celery/vendors/worker/container.py b/src/pytest_celery/vendors/worker/container.py index b15745ae6..a2a7c410d 100644 --- a/src/pytest_celery/vendors/worker/container.py +++ b/src/pytest_celery/vendors/worker/container.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import inspect -from typing import Union from celery import Celery from celery.app.base import PendingConfiguration @@ -54,22 +55,30 @@ def buildargs(cls) -> dict: } @classmethod - def env(cls, celery_worker_cluster_config: dict) -> dict: - celery_broker_cluster_config = celery_worker_cluster_config.get("celery_broker_cluster_config") - celery_backend_cluster_config = celery_worker_cluster_config.get("celery_backend_cluster_config") - env = {} - if celery_broker_cluster_config: - env["CELERY_BROKER_URL"] = ";".join(celery_broker_cluster_config["urls"]) - if celery_backend_cluster_config: - env["CELERY_RESULT_BACKEND"] = ";".join(celery_backend_cluster_config["urls"]) - return {**DEFAULT_WORKER_ENV, **env} + def env(cls, celery_worker_cluster_config: dict, initial: dict | None = None) -> dict: + env = initial or {} + env = {**env, **DEFAULT_WORKER_ENV.copy()} + + config_mappings = [ + ("celery_broker_cluster_config", "CELERY_BROKER_URL"), + ("celery_backend_cluster_config", "CELERY_RESULT_BACKEND"), + ] + + for config_key, env_key in config_mappings: + cluster_config = celery_worker_cluster_config.get(config_key) + if cluster_config: + env[env_key] = ";".join(cluster_config["urls"]) + else: + del env[env_key] + + return env @classmethod def initial_content( cls, worker_tasks: set, - worker_signals: Union[set, None] = None, - worker_app: Union[Celery, None] = None, + worker_signals: set | None = None, + worker_app: Celery | None = None, ) -> dict: from pytest_celery.vendors.worker import app as app_module diff --git a/tests/unit/api/test_backend.py b/tests/unit/api/test_backend.py index ece24d55f..4d85f7f9c 100644 --- a/tests/unit/api/test_backend.py +++ b/tests/unit/api/test_backend.py @@ -31,3 +31,12 @@ class test_celery_backend_cluster: def test_default_config_format(self, cluster: CeleryBackendCluster): assert cluster.default_config()["urls"] == [DEFAULT_WORKER_ENV["CELERY_RESULT_BACKEND"]] assert cluster.default_config()["local_urls"] == [DEFAULT_WORKER_ENV["CELERY_RESULT_BACKEND"]] + + class test_disabling_cluster: + @pytest.fixture + def celery_backend_cluster(self) -> CeleryBackendCluster: + return None + + def test_disabling_backend_cluster(self, cluster: CeleryBackendCluster, celery_backend_cluster_config: dict): + assert cluster is None + assert celery_backend_cluster_config is None diff --git a/tests/unit/api/test_broker.py b/tests/unit/api/test_broker.py index c6110f204..7c855a6fe 100644 --- a/tests/unit/api/test_broker.py +++ b/tests/unit/api/test_broker.py @@ -31,3 +31,12 @@ class test_celery_broker_cluster: def test_default_config_format(self, cluster: CeleryBrokerCluster): assert cluster.default_config()["urls"] == [DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]] assert cluster.default_config()["local_urls"] == [DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]] + + class test_disabling_cluster: + @pytest.fixture + def celery_broker_cluster(self) -> CeleryBrokerCluster: + return None + + def test_disabling_broker_cluster(self, cluster: CeleryBrokerCluster, celery_broker_cluster_config: dict): + assert cluster is None + assert celery_broker_cluster_config is None diff --git a/tests/unit/vendors/test_worker.py b/tests/unit/vendors/test_worker.py index 34263ed07..eb4ca819d 100644 --- a/tests/unit/vendors/test_worker.py +++ b/tests/unit/vendors/test_worker.py @@ -1,10 +1,14 @@ import inspect +import pytest + from pytest_celery import DEFAULT_WORKER_ENV from pytest_celery import DEFAULT_WORKER_LOG_LEVEL from pytest_celery import DEFAULT_WORKER_NAME from pytest_celery import DEFAULT_WORKER_QUEUE from pytest_celery import DEFAULT_WORKER_VERSION +from pytest_celery import CeleryBackendCluster +from pytest_celery import CeleryBrokerCluster from pytest_celery import CeleryWorkerContainer @@ -35,8 +39,24 @@ def test_buildargs(self): "CELERY_WORKER_QUEUE": DEFAULT_WORKER_QUEUE, } - def test_env(self, celery_worker_cluster_config: dict): - assert CeleryWorkerContainer.env(celery_worker_cluster_config) == DEFAULT_WORKER_ENV + class test_celery_worker_container_env: + def test_env(self, celery_worker_cluster_config: dict): + assert CeleryWorkerContainer.env(celery_worker_cluster_config) == DEFAULT_WORKER_ENV + + class test_disabling_cluster: + @pytest.fixture + def celery_backend_cluster(self) -> CeleryBackendCluster: + return None + + @pytest.fixture + def celery_broker_cluster(self) -> CeleryBrokerCluster: + return None + + def test_disabling_clusters(self, celery_worker_cluster_config: dict): + expected_env = DEFAULT_WORKER_ENV.copy() + expected_env.pop("CELERY_BROKER_URL") + expected_env.pop("CELERY_RESULT_BACKEND") + assert CeleryWorkerContainer.env(celery_worker_cluster_config) == expected_env def test_initial_content_default_tasks(self): from tests import tasks