diff --git a/src/pytest_celery/vendors/worker/app.py b/src/pytest_celery/vendors/worker/app.py index 521fdf657..a3fa26242 100644 --- a/src/pytest_celery/vendors/worker/app.py +++ b/src/pytest_celery/vendors/worker/app.py @@ -1,3 +1,5 @@ +""" Template for Celery worker application. """ + import json import logging import sys @@ -5,19 +7,14 @@ from celery import Celery from celery.signals import after_setup_logger -config_updates = None -name = "celery_test_app" # Default name if not provided by the initial content - -# Will be populated accoring to the initial content -{0} -{1} -app = Celery(name) +imports = None -{2} +app = Celery("celery_test_app") +config = None -if config_updates: - app.config_from_object(config_updates) - print(f"Config updates from default_worker_app fixture: {json.dumps(config_updates, indent=4)}") +if config: + app.config_from_object(config) + print(f"Changed worker configuration: {json.dumps(config, indent=4)}") @after_setup_logger.connect diff --git a/src/pytest_celery/vendors/worker/container.py b/src/pytest_celery/vendors/worker/container.py index a2a7c410d..df41b68f9 100644 --- a/src/pytest_celery/vendors/worker/container.py +++ b/src/pytest_celery/vendors/worker/container.py @@ -1,9 +1,6 @@ from __future__ import annotations -import inspect - from celery import Celery -from celery.app.base import PendingConfiguration from pytest_celery.api.container import CeleryTestContainer from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_ENV @@ -11,6 +8,7 @@ from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_NAME from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_QUEUE from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_VERSION +from pytest_celery.vendors.worker.volume import WorkerInitialContent class CeleryWorkerContainer(CeleryTestContainer): @@ -76,96 +74,23 @@ def env(cls, celery_worker_cluster_config: dict, initial: dict | None = None) -> @classmethod def initial_content( cls, - worker_tasks: set, + worker_tasks: set | None = None, worker_signals: set | None = None, worker_app: Celery | None = None, ) -> dict: from pytest_celery.vendors.worker import app as app_module + from pytest_celery.vendors.worker import tasks as plugin_tasks - app_module_src = inspect.getsource(app_module) + worker_tasks = worker_tasks or set() + worker_tasks.add(plugin_tasks) - imports = dict() - initial_content = cls._initial_content_worker_tasks(worker_tasks) - imports["tasks_imports"] = initial_content.pop("tasks_imports") + content = WorkerInitialContent() + content.set_app_module(app_module) + content.add_modules("tasks", worker_tasks) if worker_signals: - initial_content.update(cls._initial_content_worker_signals(worker_signals)) - imports["signals_imports"] = initial_content.pop("signals_imports") + content.add_modules("signals", worker_signals) if worker_app: - # Accessing the worker_app.conf.changes.data property will trigger the PendingConfiguration to be resolved - # and the changes will be applied to the worker_app.conf, so we make a clone app to avoid affecting the - # original app object. - app = Celery(worker_app.main) - app.conf = worker_app.conf - config_changes_from_defaults = app.conf.changes.copy() - if isinstance(config_changes_from_defaults, PendingConfiguration): - config_changes_from_defaults = config_changes_from_defaults.data.changes - if not isinstance(config_changes_from_defaults, dict): - raise TypeError(f"Unexpected type for config_changes: {type(config_changes_from_defaults)}") - del config_changes_from_defaults["deprecated_settings"] - - name_code = f'name = "{worker_app.main}"' - else: - config_changes_from_defaults = {} - name_code = f'name = "{cls.worker_name()}"' - - imports_format = "{%s}" % "}{".join(imports.keys()) - imports_format = imports_format.format(**imports) - app_module_src = app_module_src.replace("{0}", imports_format) - - app_module_src = app_module_src.replace("{1}", name_code) - - config_items = (f" {repr(key)}: {repr(value)}" for key, value in config_changes_from_defaults.items()) - config_code = ( - "config_updates = {\n" + ",\n".join(config_items) + "\n}" - if config_changes_from_defaults - else "config_updates = {}" - ) - app_module_src = app_module_src.replace("{2}", config_code) - - initial_content["app.py"] = app_module_src.encode() - return initial_content - - @classmethod - def _initial_content_worker_tasks(cls, worker_tasks: set) -> dict: - from pytest_celery.vendors.worker import tasks - - worker_tasks.add(tasks) + content.set_app_name(worker_app.main) + content.set_config_from_object(worker_app) - import_string = "" - - for module in worker_tasks: - import_string += f"from {module.__name__} import *\n" - - initial_content = { - "__init__.py": b"", - "tasks_imports": import_string, - } - if worker_tasks: - default_worker_tasks_src = { - f"{module.__name__.replace('.', '/')}.py": inspect.getsource(module).encode() for module in worker_tasks - } - initial_content.update(default_worker_tasks_src) - else: - print("No tasks found") - return initial_content - - @classmethod - def _initial_content_worker_signals(cls, worker_signals: set) -> dict: - import_string = "" - - for module in worker_signals: - import_string += f"from {module.__name__} import *\n" - - initial_content = { - "__init__.py": b"", - "signals_imports": import_string, - } - if worker_signals: - default_worker_signals_src = { - f"{module.__name__.replace('.', '/')}.py": inspect.getsource(module).encode() - for module in worker_signals - } - initial_content.update(default_worker_signals_src) - else: - print("No signals found") - return initial_content + return content.generate() diff --git a/src/pytest_celery/vendors/worker/volume.py b/src/pytest_celery/vendors/worker/volume.py new file mode 100644 index 000000000..eb1448842 --- /dev/null +++ b/src/pytest_celery/vendors/worker/volume.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import inspect +from types import ModuleType +from typing import Any + +from celery import Celery +from celery.app.base import PendingConfiguration + +from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_APP_NAME + + +class WorkerInitialContent: + class Parser: + def imports_str(self, modules: set[ModuleType]) -> str: + return "".join(f"from {module.__name__} import *\n" for module in modules) + + def imports_src(self, modules: set[ModuleType]) -> dict: + src = dict() + for module in modules: + src[f"{module.__name__.replace('.', '/')}.py"] = inspect.getsource(module).encode() + return src + + def app_name(self, name: str = DEFAULT_WORKER_APP_NAME) -> str: + return f"app = Celery('{name}')" + + def config(self, app: Celery | None = None) -> str: + app = app or Celery(DEFAULT_WORKER_APP_NAME) + + # Accessing the app.conf.changes.data property will trigger the PendingConfiguration to be resolved + # and the changes will be applied to the app.conf, so we make a clone app to avoid affecting the + # original app object. + tmp_app = Celery(app.main) + tmp_app.conf = app.conf + + changes = tmp_app.conf.changes.copy() + if isinstance(changes, PendingConfiguration): + changes = changes.data.changes + if not isinstance(changes, dict): + raise TypeError(f"Unexpected type for app.conf.changes: {type(changes)}") + del changes["deprecated_settings"] + + if changes: + changes = (f"\t{repr(key)}: {repr(value)}" for key, value in changes.items()) + config = "config = {\n" + ",\n".join(changes) + "\n}" if changes else "config = None" + else: + config = "config = None" + return config + + def __init__(self) -> None: + self.parser = self.Parser() + self._initial_content = { + "__init__.py": b"", + "imports": dict(), + } + self.set_app_name() + self.set_config_from_object() + + def set_app_module(self, app_module: ModuleType) -> None: + self._app_module_src = inspect.getsource(app_module) + + def add_modules(self, name: str, modules: set[ModuleType]) -> None: + if not name: + raise ValueError("name cannot be empty") + + if not modules: + raise ValueError("modules cannot be empty") + + self._initial_content["imports"][name] = self.parser.imports_str(modules) # type: ignore + self._initial_content.update(self.parser.imports_src(modules)) + + def set_app_name(self, name: str = DEFAULT_WORKER_APP_NAME) -> None: + self._app = self.parser.app_name(name) + + def set_config_from_object(self, app: Celery | None = None) -> None: + self._config = self.parser.config(app) + + def generate(self) -> dict: + initial_content = self._initial_content.copy() + + _imports: dict | Any = initial_content.pop("imports") + imports = "{%s}" % "}{".join(_imports.keys()) + imports = imports.format(**_imports) + + app, config = self._app, self._config + + replacement_args = { + "imports": "imports = None", + "app": f'app = Celery("{DEFAULT_WORKER_APP_NAME}")', + "config": "config = None", + } + self._app_module_src = self._app_module_src.replace(replacement_args["imports"], imports) + self._app_module_src = self._app_module_src.replace(replacement_args["app"], app) + self._app_module_src = self._app_module_src.replace(replacement_args["config"], config) + + initial_content["app.py"] = self._app_module_src.encode() + + return initial_content