Skip to content

Commit

Permalink
Refactored worker intial content code (#98)
Browse files Browse the repository at this point in the history
* Fixed paths in .github/workflows/docker.yml

* Refactored worker intial content code

* Added unit tests

* Added new fixture: default_worker_app_module()
  • Loading branch information
Nusnus authored Dec 7, 2023
1 parent 3593a02 commit f3e9bb3
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 107 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ on:
pull_request:
branches: [ 'main']
paths:
- '/src/pytest_celery/vendors/worker/**'
- 'src/pytest_celery/vendors/worker/**'
- '.github/workflows/docker.yml'
- 'Dockerfile'
push:
branches: [ 'main']
paths:
- '/src/pytest_celery/vendors/worker/**'
- 'src/pytest_celery/vendors/worker/**'
- '.github/workflows/docker.yml'
- 'Dockerfile'

Expand Down
19 changes: 8 additions & 11 deletions src/pytest_celery/vendors/worker/app.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
""" Template for Celery worker application. """

import json
import logging
import sys

from celery import Celery
from celery.signals import after_setup_logger

config_updates = None
name = "celery_test_app" # Default name if not provided by the initial content

# Will be populated accoring to the initial content
{0}
{1}
app = Celery(name)
imports = None

{2}
app = Celery("celery_test_app")
config = None

if config_updates:
app.config_from_object(config_updates)
print(f"Config updates from default_worker_app fixture: {json.dumps(config_updates, indent=4)}")
if config:
app.config_from_object(config)
print(f"Changed worker configuration: {json.dumps(config, indent=4)}")


@after_setup_logger.connect
Expand Down
112 changes: 24 additions & 88 deletions src/pytest_celery/vendors/worker/container.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

import inspect
from types import ModuleType

from celery import Celery
from celery.app.base import PendingConfiguration

from pytest_celery.api.container import CeleryTestContainer
from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_ENV
from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_LOG_LEVEL
from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_NAME
from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_QUEUE
from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_VERSION
from pytest_celery.vendors.worker.volume import WorkerInitialContent


class CeleryWorkerContainer(CeleryTestContainer):
Expand All @@ -37,9 +37,17 @@ def worker_name(cls) -> str:
def worker_queue(cls) -> str:
return DEFAULT_WORKER_QUEUE

@classmethod
def app_module(cls) -> ModuleType:
from pytest_celery.vendors.worker import app

return app

@classmethod
def tasks_modules(cls) -> set:
return set()
from pytest_celery.vendors.worker import tasks

return {tasks}

@classmethod
def signals_modules(cls) -> set:
Expand Down Expand Up @@ -76,96 +84,24 @@ def env(cls, celery_worker_cluster_config: dict, initial: dict | None = None) ->
@classmethod
def initial_content(
cls,
worker_tasks: set,
worker_tasks: set | None = None,
worker_signals: set | None = None,
worker_app: Celery | None = None,
app_module: ModuleType | None = None,
) -> dict:
from pytest_celery.vendors.worker import app as app_module
if app_module is None:
app_module = cls.app_module()

app_module_src = inspect.getsource(app_module)
if worker_tasks is None:
worker_tasks = cls.tasks_modules()

imports = dict()
initial_content = cls._initial_content_worker_tasks(worker_tasks)
imports["tasks_imports"] = initial_content.pop("tasks_imports")
content = WorkerInitialContent()
content.set_app_module(app_module)
content.add_modules("tasks", worker_tasks)
if worker_signals:
initial_content.update(cls._initial_content_worker_signals(worker_signals))
imports["signals_imports"] = initial_content.pop("signals_imports")
content.add_modules("signals", worker_signals)
if worker_app:
# Accessing the worker_app.conf.changes.data property will trigger the PendingConfiguration to be resolved
# and the changes will be applied to the worker_app.conf, so we make a clone app to avoid affecting the
# original app object.
app = Celery(worker_app.main)
app.conf = worker_app.conf
config_changes_from_defaults = app.conf.changes.copy()
if isinstance(config_changes_from_defaults, PendingConfiguration):
config_changes_from_defaults = config_changes_from_defaults.data.changes
if not isinstance(config_changes_from_defaults, dict):
raise TypeError(f"Unexpected type for config_changes: {type(config_changes_from_defaults)}")
del config_changes_from_defaults["deprecated_settings"]

name_code = f'name = "{worker_app.main}"'
else:
config_changes_from_defaults = {}
name_code = f'name = "{cls.worker_name()}"'

imports_format = "{%s}" % "}{".join(imports.keys())
imports_format = imports_format.format(**imports)
app_module_src = app_module_src.replace("{0}", imports_format)

app_module_src = app_module_src.replace("{1}", name_code)

config_items = (f" {repr(key)}: {repr(value)}" for key, value in config_changes_from_defaults.items())
config_code = (
"config_updates = {\n" + ",\n".join(config_items) + "\n}"
if config_changes_from_defaults
else "config_updates = {}"
)
app_module_src = app_module_src.replace("{2}", config_code)

initial_content["app.py"] = app_module_src.encode()
return initial_content

@classmethod
def _initial_content_worker_tasks(cls, worker_tasks: set) -> dict:
from pytest_celery.vendors.worker import tasks

worker_tasks.add(tasks)
content.set_app_name(worker_app.main)
content.set_config_from_object(worker_app)

import_string = ""

for module in worker_tasks:
import_string += f"from {module.__name__} import *\n"

initial_content = {
"__init__.py": b"",
"tasks_imports": import_string,
}
if worker_tasks:
default_worker_tasks_src = {
f"{module.__name__.replace('.', '/')}.py": inspect.getsource(module).encode() for module in worker_tasks
}
initial_content.update(default_worker_tasks_src)
else:
print("No tasks found")
return initial_content

@classmethod
def _initial_content_worker_signals(cls, worker_signals: set) -> dict:
import_string = ""

for module in worker_signals:
import_string += f"from {module.__name__} import *\n"

initial_content = {
"__init__.py": b"",
"signals_imports": import_string,
}
if worker_signals:
default_worker_signals_src = {
f"{module.__name__.replace('.', '/')}.py": inspect.getsource(module).encode()
for module in worker_signals
}
initial_content.update(default_worker_signals_src)
else:
print("No signals found")
return initial_content
return content.generate()
9 changes: 9 additions & 0 deletions src/pytest_celery/vendors/worker/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from types import ModuleType

import pytest
from celery import Celery
from pytest_docker_tools import build
Expand Down Expand Up @@ -101,17 +103,24 @@ def default_worker_env(
@pytest.fixture
def default_worker_initial_content(
default_worker_container_cls: type[CeleryWorkerContainer],
default_worker_app_module: ModuleType,
default_worker_tasks: set,
default_worker_signals: set,
default_worker_app: Celery,
) -> dict:
yield default_worker_container_cls.initial_content(
app_module=default_worker_app_module,
worker_tasks=default_worker_tasks,
worker_signals=default_worker_signals,
worker_app=default_worker_app,
)


@pytest.fixture
def default_worker_app_module(default_worker_container_cls: type[CeleryWorkerContainer]) -> ModuleType:
yield default_worker_container_cls.app_module()


@pytest.fixture
def default_worker_tasks(default_worker_container_cls: type[CeleryWorkerContainer]) -> set:
yield default_worker_container_cls.tasks_modules()
Expand Down
127 changes: 127 additions & 0 deletions src/pytest_celery/vendors/worker/volume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import inspect
from types import ModuleType
from typing import Any

from celery import Celery
from celery.app.base import PendingConfiguration

from pytest_celery.vendors.worker.defaults import DEFAULT_WORKER_APP_NAME


class WorkerInitialContent:
class Parser:
def imports_str(self, modules: set[ModuleType]) -> str:
return "".join(f"from {module.__name__} import *\n" for module in modules)

def imports_src(self, modules: set[ModuleType]) -> dict:
src = dict()
for module in modules:
src[f"{module.__name__.replace('.', '/')}.py"] = inspect.getsource(module).encode()
return src

def app_name(self, name: str | None = None) -> str:
name = name or DEFAULT_WORKER_APP_NAME
return f"app = Celery('{name}')"

def config(self, app: Celery | None = None) -> str:
app = app or Celery(DEFAULT_WORKER_APP_NAME)

# Accessing the app.conf.changes.data property will trigger the PendingConfiguration to be resolved
# and the changes will be applied to the app.conf, so we make a clone app to avoid affecting the
# original app object.
tmp_app = Celery(app.main)
tmp_app.conf = app.conf

changes = tmp_app.conf.changes.copy()
if isinstance(changes, PendingConfiguration):
changes = changes.data.changes
if not isinstance(changes, dict):
raise TypeError(f"Unexpected type for app.conf.changes: {type(changes)}")
del changes["deprecated_settings"]

if changes:
changes = (f"\t{repr(key)}: {repr(value)}" for key, value in changes.items())
config = "config = {\n" + ",\n".join(changes) + "\n}" if changes else "config = None"
else:
config = "config = None"
return config

def __init__(self, app_module: ModuleType | None = None) -> None:
self.parser = self.Parser()
self._initial_content = {
"__init__.py": b"",
"imports": dict(),
}
self.set_app_module(app_module)
self.set_app_name()
self.set_config_from_object()

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, WorkerInitialContent):
return False
try:
return self.generate() == __value.generate()
except ValueError:
return all(
[
self._app_module_src == __value._app_module_src,
self._initial_content == __value._initial_content,
self._app == __value._app,
self._config == __value._config,
]
)

def set_app_module(self, app_module: ModuleType | None = None) -> None:
self._app_module_src: str | None

if app_module:
self._app_module_src = inspect.getsource(app_module)
else:
self._app_module_src = None

def add_modules(self, name: str, modules: set[ModuleType]) -> None:
if not name:
raise ValueError("name cannot be empty")

if not modules:
raise ValueError("modules cannot be empty")

self._initial_content["imports"][name] = self.parser.imports_str(modules) # type: ignore
self._initial_content.update(self.parser.imports_src(modules))

def set_app_name(self, name: str | None = None) -> None:
name = name or DEFAULT_WORKER_APP_NAME
self._app = self.parser.app_name(name)

def set_config_from_object(self, app: Celery | None = None) -> None:
self._config = self.parser.config(app)

def generate(self) -> dict:
if not self._app_module_src:
raise ValueError("Please set_app_module() before calling generate()")

initial_content = self._initial_content.copy()

if not initial_content["imports"]:
raise ValueError("Please add_modules() before calling generate()")

_imports: dict | Any = initial_content.pop("imports")
imports = "{%s}" % "}{".join(_imports.keys())
imports = imports.format(**_imports)

app, config = self._app, self._config

replacement_args = {
"imports": "imports = None",
"app": f'app = Celery("{DEFAULT_WORKER_APP_NAME}")',
"config": "config = None",
}
self._app_module_src = self._app_module_src.replace(replacement_args["imports"], imports)
self._app_module_src = self._app_module_src.replace(replacement_args["app"], app)
self._app_module_src = self._app_module_src.replace(replacement_args["config"], config)

initial_content["app.py"] = self._app_module_src.encode()

return initial_content
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


@pytest.fixture
def default_worker_tasks() -> set:
def default_worker_tasks(default_worker_tasks: set) -> set:
from tests import tasks

yield {tasks}
default_worker_tasks.add(tasks)
yield default_worker_tasks
15 changes: 15 additions & 0 deletions tests/integration/vendors/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from types import ModuleType

import pytest
from pytest_lazyfixture import lazy_fixture

Expand Down Expand Up @@ -32,6 +34,19 @@ def test_disabling_backend_cluster(self, container: CeleryWorkerContainer):
results = DEFAULT_WORKER_ENV["CELERY_BROKER_URL"]
assert container.logs().count(f"transport: {results}")

class test_replacing_app_module:
@pytest.fixture(params=["Default", "Custom"])
def default_worker_app_module(self, request: pytest.FixtureRequest) -> ModuleType:
if request.param == "Default":
yield request.getfixturevalue("default_worker_app_module")
else:
from pytest_celery.vendors.worker import app

yield app

def test_replacing_app_module(self, container: CeleryWorkerContainer, default_worker_app_module: ModuleType):
assert container.app_module() == default_worker_app_module


@pytest.mark.parametrize("worker", [lazy_fixture(CELERY_SETUP_WORKER)])
class test_base_test_worker:
Expand Down
Empty file.
Loading

0 comments on commit f3e9bb3

Please sign in to comment.