Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move PG version check to awx-manage check_db & migrate commands #15463

Open
wants to merge 3 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions awx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ def version_file():
MODE = 'production'


try:
import django # noqa: F401
except ImportError:
pass
else:
from django.db import connection


def oauth2_getattribute(self, attr):
# Custom method to override
# oauth2_provider.settings.OAuth2ProviderSettings.__getattribute__
Expand Down Expand Up @@ -104,14 +96,6 @@ def manage():
from django.conf import settings
from django.core.management import execute_from_command_line

# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
# The return of connection.pg_version is something like 12013
if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development':
if (connection.pg_version // 10000) < 12:
sys.stderr.write("At a minimum, postgres version 12 is required\n")
sys.exit(1)

if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover
sys.stdout.write('%s\n' % __version__)
# If running as a user without permission to read settings, display an
Expand Down
10 changes: 10 additions & 0 deletions awx/main/apps.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
from django.core.management.base import CommandError
from django.db.models.signals import pre_migrate

from awx.main.utils.named_url_graph import _customize_graph, generate_graph
from awx.main.utils.db import db_requirement_violations
from awx.conf import register, fields


class MainConfig(AppConfig):
name = 'awx.main'
verbose_name = _('Main')

def check_db_requirement(self, *args, **kwargs):
violations = db_requirement_violations()
if violations:
raise CommandError(violations)

def load_named_url_feature(self):
models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')]
generate_graph(models)
Expand Down Expand Up @@ -38,3 +47,4 @@ def ready(self):
super().ready()

self.load_named_url_feature()
pre_migrate.connect(self.check_db_requirement, sender=self)
8 changes: 7 additions & 1 deletion awx/main/management/commands/check_db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved

from django.core.management.base import BaseCommand
from django.core.management.base import BaseCommand, CommandError
from django.db import connection

from awx.main.utils.db import db_requirement_violations


class Command(BaseCommand):
"""Checks connection to the database, and prints out connection info if not connected"""
Expand All @@ -13,4 +15,8 @@ def handle(self, *args, **options):
cursor.execute("SELECT version()")
version = str(cursor.fetchone()[0])

violations = db_requirement_violations()
if violations:
raise CommandError(violations)

return "Database Version: {}".format(version)
24 changes: 24 additions & 0 deletions awx/main/utils/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
# Copyright (c) 2017 Ansible by Red Hat
# All Rights Reserved.

from typing import Optional

from awx.settings.application_name import set_application_name
from awx import MODE

from django.conf import settings
from django.db import connection


def set_connection_name(function):
set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function)


MIN_PG_VERSION = 12


def db_requirement_violations() -> Optional[str]:
if connection.vendor == 'postgresql':

# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
# The return of connection.pg_version is something like 12013
major_version = connection.pg_version // 10000
if major_version < MIN_PG_VERSION:
return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n"

return None
else:
if MODE == 'production':
return f"Running server with '{connection.vendor}' type database is not supported\n"
return None
Loading