diff --git a/iam_groups_authn/sync.py b/iam_groups_authn/sync.py index 77c697d..43b6b5b 100644 --- a/iam_groups_authn/sync.py +++ b/iam_groups_authn/sync.py @@ -28,7 +28,7 @@ InstanceConnectionName, ) from iam_groups_authn.iam_admin import get_iam_users -from iam_groups_authn.utils import DatabaseVersion +from iam_groups_authn.utils import DatabaseVersion, strip_minor_version from iam_groups_authn.mysql import ( init_mysql_connection_engine, MysqlRoleService, @@ -74,7 +74,6 @@ async def groups_sync( async with ClientSession( headers={"Content-Type": "application/json"} ) as client_session: - # create UserService object for API calls user_service = UserService(client_session, credentials) @@ -369,6 +368,8 @@ async def get_database_version(self, instance_connection_name): logging.debug( f"[{project}:{region}:{instance}] Database version found: {database_version}" ) + # if major version is supported, we support minor version + database_version = strip_minor_version(database_version) return DatabaseVersion(database_version) except ValueError as e: raise ValueError( diff --git a/iam_groups_authn/utils.py b/iam_groups_authn/utils.py index e9ba606..1478002 100644 --- a/iam_groups_authn/utils.py +++ b/iam_groups_authn/utils.py @@ -87,3 +87,13 @@ def grant_group_role(self, role, users): @abstractmethod def revoke_group_role(self, role, users): pass + + +def strip_minor_version(database_version: str) -> str: + """ + Helper method for stripping minor version suffix from database version. + """ + for version in DatabaseVersion.__members__.keys(): + if database_version.startswith(version): + database_version = version + return database_version diff --git a/tests/unit/test_database_versions.py b/tests/unit/test_database_versions.py new file mode 100644 index 0000000..02c7403 --- /dev/null +++ b/tests/unit/test_database_versions.py @@ -0,0 +1,38 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from iam_groups_authn.utils import strip_minor_version + +test_data = [ + ("MYSQL_8_0", "MYSQL_8_0"), + ("MYSQL_8_0_26", "MYSQL_8_0"), + ("MYSQL_8_0_35", "MYSQL_8_0"), + ("POSTGRES_15", "POSTGRES_15"), + ("POSTGRES_14", "POSTGRES_14"), + ("POSTGRES_13", "POSTGRES_13"), + ("POSTGRES_12", "POSTGRES_12"), + ("POSTGRES_11", "POSTGRES_11"), + ("POSTGRES_10", "POSTGRES_10"), + ("POSTGRES_9_6", "POSTGRES_9_6"), +] + +@pytest.mark.parametrize("database_version,expected", test_data) +def test_strip_minor_version(database_version, expected): + """ + Test that strip_minor_version() works correctly. + """ + database_version = strip_minor_version(database_version) + assert database_version == expected