diff --git a/docs/jwk2pem.py b/docs/jwk2pem.py index 4319a64..358cd62 100644 --- a/docs/jwk2pem.py +++ b/docs/jwk2pem.py @@ -6,14 +6,29 @@ # details. # """A helper code to convert JWK (retrieved from oidc) to PEM format.""" -import jwcrypto.jwk -key_dict = { - "e": "AQAB", - "kty": "RSA", - "n": "mho5h_lz6USUUazQaVT3PHloIk_Ljs2vZl_RAaitkXDx6aqpl1kGpS44eYJOaer4oWc6_QNaMtynvlSlnkuWrG765adNKT9sgAWSrPb81xkojsQabrSNv4nIOWUQi0Tjh0WxXQmbV-bMxkVaElhdHNFzUfHv-XqI8Hkc82mIGtyeMQn-VAuZbYkVXnjyCwwa9RmPOSH-O4N4epDXKk1VK9dUxf_rEYbjMNZGDva30do0mrBkU8W3O1mDVJSSgHn4ejKdGNYMm0JKPAgCWyPWJDoL092ctPCFlUMBBZ_OP3omvgnw0GaWZXxqSqaSvxFJkqCHqLMwpxmWTTAgEvAbnw", -} +import sys -key = jwcrypto.jwk.JWK(**key_dict) -pem = key.export_to_pem(False, False) -print(pem) + +def export_key() -> None: + """Export the key from JWK to PEM format.""" + try: + import jwcrypto.jwk # noqa + except ImportError: + print("Please install jwcrypto: pip install jwcrypto") + sys.exit(1) + + # this key was downloaded from perun + key_dict = { + "e": "AQAB", + "kty": "RSA", + "n": "mho5h_lz6USUUazQaVT3PHloIk_Ljs2vZl_RAaitkXDx6aqpl1kGpS44eYJOaer4oWc6_QNaMtynvlSlnkuWrG765adNKT9sgAWSrPb81xkojsQabrSNv4nIOWUQi0Tjh0WxXQmbV-bMxkVaElhdHNFzUfHv-XqI8Hkc82mIGtyeMQn-VAuZbYkVXnjyCwwa9RmPOSH-O4N4epDXKk1VK9dUxf_rEYbjMNZGDva30do0mrBkU8W3O1mDVJSSgHn4ejKdGNYMm0JKPAgCWyPWJDoL092ctPCFlUMBBZ_OP3omvgnw0GaWZXxqSqaSvxFJkqCHqLMwpxmWTTAgEvAbnw", + } + + key = jwcrypto.jwk.JWK(**key_dict) + pem = key.export_to_pem(False, False) + print(pem) + + +if __name__ == "__main__": + export_key() diff --git a/format.sh b/format.sh index f2c014f..9b591f7 100755 --- a/format.sh +++ b/format.sh @@ -1,4 +1,19 @@ #!/bin/bash -"$(dirname $0)/python_format.sh" $(( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u ) -`dirname $0`/python-packages/bin/python -m licenseheaders -t .copyright.tmpl -cy -f $(( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u ) +source .venv/bin/activate + +python_files=$( + ( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u +) + +python_files_without_tests=$( + ( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | egrep -v "^tests/" | sort -u +) +top_level_package=$(echo $python_files_without_tests | tr ' ' '\n' | grep '/' | cut -d/ -f1 | sort -u) + +# python must not be in directories containing ' ', so no quotes here or inside the variable +ruff format -- $python_files +ruff check --fix $python_files_without_tests +python -m licenseheaders -t .copyright.tmpl -cy -f $python_files# + +mypy --enable-incomplete-feature=NewGenericSyntax $top_level_package diff --git a/oarepo_oidc_einfra/__init__.py b/oarepo_oidc_einfra/__init__.py index 6165cab..fe466d1 100644 --- a/oarepo_oidc_einfra/__init__.py +++ b/oarepo_oidc_einfra/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (C) 2024 CESNET z.s.p.o. # @@ -7,7 +6,7 @@ # details. # -"""E-INFRA OIDC Auth backend for OARepo""" +"""E-INFRA OIDC Auth backend for OARepo.""" from .remote import EINFRA_LOGIN_APP diff --git a/oarepo_oidc_einfra/cli.py b/oarepo_oidc_einfra/cli.py index f002d62..8aa8aac 100644 --- a/oarepo_oidc_einfra/cli.py +++ b/oarepo_oidc_einfra/cli.py @@ -5,9 +5,12 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""EInfra terminal commands.""" + import json from datetime import UTC, datetime from io import BytesIO +from typing import TYPE_CHECKING import boto3 import click @@ -21,20 +24,22 @@ from oarepo_oidc_einfra.perun.dump import import_dump_file from oarepo_oidc_einfra.tasks import update_from_perun_dump +if TYPE_CHECKING: + from flask_security.datastore import UserDatastore + @click.group() -def einfra(): +def einfra() -> None: """EInfra commands.""" @einfra.command("import_dump") @click.argument("dump_file") @with_appcontext -def import_dump(dump_file): - """ - Import a dump file. +def import_dump(dump_file: str) -> None: + """Import a dump file. - :param dump_file: Path to the dump file to import. + :param dump_file: Path to the dump file on the local filesystem to import. """ click.echo(f"Importing dump file {dump_file}") @@ -49,8 +54,15 @@ def import_dump(dump_file): @click.option("--on-background/--on-foreground", default=False) @click.option("--fix-communities-in-perun/--no-fix-communities-in-perun", default=True) @with_appcontext -def update_from_dump(dump_name, on_background, fix_communities_in_perun): - """Update the data from the last imported dump.""" +def update_from_dump( + dump_name: str, on_background: bool, fix_communities_in_perun: bool +) -> None: + """Update the data from the last imported dump. + + :param dump_name: Name of the dump to update from. + :param on_background: Whether to run the task in the background. + :param fix_communities_in_perun: Whether to fix communities in Perun. + """ if on_background: update_from_perun_dump.delay( dump_name, fix_communities_in_perun=fix_communities_in_perun @@ -65,18 +77,23 @@ def update_from_dump(dump_name, on_background, fix_communities_in_perun): @click.argument("email") @click.argument("einfra_id") @with_appcontext -def add_einfra_user(email, einfra_id): +def add_einfra_user(email: str, einfra_id: str) -> None: + """Add a user to the system if it does not exist and link it with the EInfra identity.""" _add_einfra_user(email, einfra_id) @einfra.command("clear_import_mutex") @with_appcontext -def clear_import_mutex(): +def clear_import_mutex() -> None: + """Clear the import mutex - should be used only as a last resort.""" CacheMutex("EINFRA_SYNC_MUTEX").force_clear() -def _add_einfra_user(email, einfra_id): - _datastore = LocalProxy(lambda: current_app.extensions["security"].datastore) +def _add_einfra_user(email: str, einfra_id: str) -> None: + """Add a user to the system if it does not exist and link it with the EInfra identity.""" + _datastore: UserDatastore = LocalProxy( + lambda: current_app.extensions["security"].datastore + ) # noqa email = email.lower() user = User.query.filter_by(email=email).first() @@ -87,7 +104,7 @@ def _add_einfra_user(email, einfra_id): "active": True, "confirmed_at": datetime.now(UTC), } - created = _datastore.create_user(**kwargs) + _datastore.create_user(**kwargs) db.session.commit() user = User.query.filter_by(email=email).first() @@ -108,7 +125,11 @@ def _add_einfra_user(email, einfra_id): @einfra.command("import_dump_users") @click.argument("dump_path") @with_appcontext -def import_dump_users(dump_path): +def import_dump_users(dump_path: str) -> None: + """Import users from a dump file. + + :param dump_path: Path to the dump file in the S3 bucket. + """ client = boto3.client( "s3", aws_access_key_id=current_app.config["EINFRA_USER_DUMP_S3_ACCESS_KEY"], diff --git a/oarepo_oidc_einfra/communities.py b/oarepo_oidc_einfra/communities.py index 0d69195..7dfa5df 100644 --- a/oarepo_oidc_einfra/communities.py +++ b/oarepo_oidc_einfra/communities.py @@ -6,14 +6,16 @@ # details. # """Helper functions for working with communities.""" + +from __future__ import annotations + +import dataclasses import logging -from collections import namedtuple from functools import cached_property -from typing import Dict, Set +from typing import TYPE_CHECKING, Iterable from flask import current_app from invenio_access.permissions import system_identity -from invenio_accounts.models import User from invenio_communities.communities.records.api import Community from invenio_communities.members.errors import AlreadyMemberError from invenio_communities.members.records.models import MemberModel @@ -21,36 +23,40 @@ from invenio_db import db from marshmallow import ValidationError from sqlalchemy import select +from sqlalchemy.sql.expression import true + +if TYPE_CHECKING: + from uuid import UUID + + from invenio_accounts.models import User log = logging.getLogger(__name__) -CommunityRole = namedtuple("CommunityRole", ["community_id", "role"]) -"""A named tuple representing a community and a role.""" +@dataclasses.dataclass(frozen=True) +class CommunityRole: + """A class representing a community and a role.""" + + community_id: UUID + role: str class CommunitySupport: """A support class for working with communities and their members.""" - def __init__(self): - pass - @cached_property - def slug_to_id(self) -> Dict[str, str]: - """ - Returns a mapping of community slugs to their ids. - """ + def slug_to_id(self) -> dict[str, UUID]: + """Returns a mapping of community slugs to their ids.""" return { - row[1]: str(row[0]) + row[1]: row[0] for row in db.session.execute( select(Community.model_cls.id, Community.model_cls.slug) ) } @cached_property - def all_community_roles(self) -> Set[CommunityRole]: - """ - Returns a set of all community roles (pair of community id, role name) known to the repository. + def all_community_roles(self) -> set[CommunityRole]: + """Return a set of all community roles (pair of community id, role name) known to the repository. :return: a set of all community roles known to the repository """ @@ -63,16 +69,15 @@ def all_community_roles(self) -> Set[CommunityRole]: return repository_comunity_roles @cached_property - def role_names(self) -> Set[str]: - """ - Returns a set of all known community role names, as configured inside the invenio.cfg + def role_names(self) -> set[str]: + """Return a set of all known community role names, as configured inside the invenio.cfg. + :return: a set of all known community role names """ return {role["name"] for role in current_app.config["COMMUNITIES_ROLES"]} def role_priority(self, role_name: str) -> int: - """ - Returns a priority of a given role name. + """Return a priority of a given role name. :param role_name: role name :return: role priority (0 is lowest (member), higher number is higher priority (up to owner)) @@ -80,9 +85,8 @@ def role_priority(self, role_name: str) -> int: return self.role_priorities[role_name] @cached_property - def role_priorities(self) -> Dict[str, int]: - """ - Returns a mapping of role names to their priorities. + def role_priorities(self) -> dict[str, int]: + """Returns a mapping of role names to their priorities. :return: a mapping of role names to their priorities, 0 is lowest priority """ @@ -95,8 +99,8 @@ def role_priorities(self) -> Dict[str, int]: def set_user_community_membership( cls, user: User, - new_community_roles: Set[CommunityRole], - current_community_roles: Set[CommunityRole] = None, + new_community_roles: set[CommunityRole], + current_community_roles: set[CommunityRole] | None = None, ) -> None: """Set user membership based on the new community roles. @@ -109,8 +113,8 @@ def set_user_community_membership( if not current_community_roles: current_community_roles = cls.get_user_community_membership(user) - for community_id, role in new_community_roles - current_community_roles: - cls._add_user_community_membership(community_id, role, user) + for community_role in new_community_roles - current_community_roles: + cls._add_user_community_membership(community_role, user) for v in new_community_roles: assert isinstance(v, CommunityRole) @@ -131,7 +135,7 @@ def set_user_community_membership( ) @classmethod - def get_user_community_membership(cls, user) -> Set[CommunityRole]: + def get_user_community_membership(cls, user: User) -> set[CommunityRole]: """Get user's actual community roles. :param user: User object @@ -139,7 +143,7 @@ def get_user_community_membership(cls, user) -> Set[CommunityRole]: ret = set() for row in db.session.execute( select([MemberModel.community_id, MemberModel.role]).where( - MemberModel.user_id == user.id, MemberModel.active == True + MemberModel.user_id == user.id, MemberModel.active == true() ) ): ret.add(CommunityRole(row.community_id, row.role)) @@ -148,17 +152,17 @@ def get_user_community_membership(cls, user) -> Set[CommunityRole]: @classmethod def get_user_list_community_membership( - cls, user_ids: list[int] - ) -> Dict[int, Set[CommunityRole]]: + cls, user_ids: Iterable[int] + ) -> dict[int, set[CommunityRole]]: """Get community roles of a list of users. :param user_ids: List of user ids """ - ret = {} + ret: dict[int, set[CommunityRole]] = {} for row in db.session.execute( select( [MemberModel.community_id, MemberModel.user_id, MemberModel.role] - ).where(MemberModel.user_id.in_(user_ids), MemberModel.active == True) + ).where(MemberModel.user_id.in_(user_ids), MemberModel.active == true()) ): if row.user_id not in ret: ret[row.user_id] = set() @@ -168,25 +172,23 @@ def get_user_list_community_membership( @classmethod def _add_user_community_membership( - cls, community_id: str, community_role: str, user: User + cls, community_role: CommunityRole, user: User ) -> None: - """ - Add user to a community with a given role. + """Add user to a community with a given role. - :param community_id: id of the community :param community_role: community role :param user: user object :return: A membership result item from service """ data = { - "role": community_role, + "role": community_role.role, "members": [{"type": "user", "id": str(user.id)}], } try: return current_communities.service.members.add( - system_identity, community_id, data + system_identity, community_role.community_id, data ) - except AlreadyMemberError as e: + except AlreadyMemberError: # We are here because # # * active memberships have not returned this (community, role) for user @@ -198,7 +200,9 @@ def _add_user_community_membership( # We need to get the associated invitation request and accept it here, # thus the membership will become active. results = current_communities.service.members.search_invitations( - system_identity, community_id, params={"user.id": str(user.id)} + system_identity, + community_role.community_id, + params={"user.id": str(user.id)}, ) hits = list(results.hits) if len(hits) == 1: @@ -207,9 +211,8 @@ def _add_user_community_membership( ) @classmethod - def _remove_user_community_membership(cls, community_id, user) -> None: - """ - Remove user from a community with a given role. + def _remove_user_community_membership(cls, community_id: UUID, user: User) -> None: + """Remove user from a community with a given role. :param community_id: id of the community :param user: user object diff --git a/oarepo_oidc_einfra/config.py b/oarepo_oidc_einfra/config.py index 523c9cc..003da6b 100644 --- a/oarepo_oidc_einfra/config.py +++ b/oarepo_oidc_einfra/config.py @@ -5,6 +5,7 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""Configuration for the E-INFRA OIDC authentication, can be overwritten in invenio.cfg .""" EINFRA_COMMUNITY_SYNCHRONIZATION = True """Synchronize community to E-Infra Perun when community is created.""" diff --git a/oarepo_oidc_einfra/encryption.py b/oarepo_oidc_einfra/encryption.py index a8081b6..e7b31e9 100644 --- a/oarepo_oidc_einfra/encryption.py +++ b/oarepo_oidc_einfra/encryption.py @@ -5,6 +5,8 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""Encryption and decryption of request id using FernetEngine encryption.""" + from uuid import UUID from flask import current_app diff --git a/oarepo_oidc_einfra/ext.py b/oarepo_oidc_einfra/ext.py index 69edaff..f761e4f 100644 --- a/oarepo_oidc_einfra/ext.py +++ b/oarepo_oidc_einfra/ext.py @@ -7,35 +7,39 @@ # """A flask extension for E-INFRA OIDC authentication.""" -from flask import current_app -from invenio_communities.communities.services.components import \ - DefaultCommunityComponents -from invenio_communities.members.services.components import \ - DefaultCommunityMemberComponents +from flask import Flask, current_app +from invenio_communities.communities.services.components import ( + DefaultCommunityComponents, +) +from invenio_communities.members.services.components import ( + DefaultCommunityMemberComponents, +) from oarepo_oidc_einfra.perun import PerunLowLevelAPI from oarepo_oidc_einfra.services.components.aai_communities import CommunityAAIComponent -from oarepo_oidc_einfra.services.components.aai_invitations import \ - AAIInvitationComponent +from oarepo_oidc_einfra.services.components.aai_invitations import ( + AAIInvitationComponent, +) from .cli import einfra as einfra_cmd class EInfraOIDCApp: - def __init__(self, app=None): - """Creates the extension.""" + """EInfra OIDC extension.""" + + def __init__(self, app: Flask | None = None): + """Create the extension.""" if app: self.init_app(app) - def init_app(self, app): - """Adds the extension to the app and loads initial configuration.""" + def init_app(self, app: Flask) -> None: + """Add the extension to the app and loads initial configuration.""" app.extensions["einfra-oidc"] = self self.init_config(app) app.cli.add_command(einfra_cmd) - def init_config(self, app): - """Loads the default configuration.""" - + def init_config(self, app: Flask) -> None: + """Load the default configuration.""" self.register_sync_component_to_community_service(app) # sets the default configuration values @@ -45,9 +49,8 @@ def init_config(self, app): if k.startswith("EINFRA_"): app.config.setdefault(k, getattr(config, k)) - def register_sync_component_to_community_service(self, app): - """Registers components to the community service.""" - + def register_sync_component_to_community_service(self, app: Flask) -> None: + """Register components to the community service.""" # Community -> AAI synchronization service component communities_components = app.config.get("COMMUNITIES_SERVICE_COMPONENTS", None) if isinstance(communities_components, list): @@ -70,8 +73,8 @@ def register_sync_component_to_community_service(self, app): *DefaultCommunityMemberComponents, ] - def perun_api(self): - + def perun_api(self) -> PerunLowLevelAPI: + """Create a new Perun API instance.""" return PerunLowLevelAPI( base_url=current_app.config["EINFRA_API_URL"], service_id=current_app.config["EINFRA_SERVICE_ID"], diff --git a/oarepo_oidc_einfra/mutex.py b/oarepo_oidc_einfra/mutex.py index fa6f41a..d32e96b 100644 --- a/oarepo_oidc_einfra/mutex.py +++ b/oarepo_oidc_einfra/mutex.py @@ -5,23 +5,25 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # -""" -An implementation of a mutex over the cache. Note: if you have multiple redis caches, +"""An implementation of a mutex over the cache. + +Note: if you have multiple redis caches, you need to implement a distributed lock instead of this simple implementation to have a mutex that works in all cases of disaster scenarios. """ + import functools import secrets import threading import time from random import random +from typing import Callable from invenio_cache import current_cache class CacheMutex: - """ - A simple mutex implementation using the cache. + """A simple mutex implementation using the cache. Because propagation from master cache server to slaves might be asynchronous and messages might be lost in case when master responds before the slave is updated, @@ -31,8 +33,10 @@ class CacheMutex: slaves and implement a distributed lock such as Redlock algorithm for redis. """ - def __init__(self, key, timeout=3600, tries=10, wait_time=120): - """Creates the mutex. + def __init__( + self, key: str, timeout: float = 3600, tries: int = 10, wait_time: float = 120 + ): + """Create the mutex. :param key: The key inside cache where the mutex data are stored. :param timeout: The mutex will be released automatically after this time. @@ -52,7 +56,7 @@ def __init__(self, key, timeout=3600, tries=10, wait_time=120): def __enter__(self): """Acquires the mutex.""" - for k in range(self.tries): + for _k in range(self.tries): if current_cache.cache.add(self.key, self.value, timeout=self.timeout): # sanity check if current_cache.cache.get(self.key) != self.value: @@ -65,14 +69,14 @@ def __enter__(self): f"waiting {self.wait_time} seconds each time" ) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # noqa """Releases the mutex.""" if current_cache.cache.get(self.key) == self.value: current_cache.cache.delete(self.key) - def force_clear(self): - """ - Forces the mutex to be cleared. + def force_clear(self) -> None: + """Force the mutex to be cleared. + Note: this does not stop any processes that might be using the mutex ! """ current_cache.cache.delete(self.key) @@ -82,9 +86,10 @@ def force_clear(self): """make the mutex below reentrant within the same thread""" -def mutex(key, timeout=3600, tries=10, wait_time=120): - """ - A decorator that creates a mutex for a function. +def mutex( + key: str, timeout: float = 3600, tries: int = 10, wait_time: float = 120 +) -> Callable: + """Create a mutex for a function. :param key: The key inside cache where the mutex data are stored. :param timeout: The mutex will be released automatically after this time. @@ -100,9 +105,9 @@ def my_function(): # do something that needs to be protected by the mutex """ - def decorator(func): + def decorator(func): # noqa @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # noqa if not hasattr(mutex_thread_local, key): setattr(mutex_thread_local, key, True) try: diff --git a/oarepo_oidc_einfra/perun/api.py b/oarepo_oidc_einfra/perun/api.py index ee0c93b..4260f98 100644 --- a/oarepo_oidc_einfra/perun/api.py +++ b/oarepo_oidc_einfra/perun/api.py @@ -5,8 +5,10 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""Low-level API for Perun targeted at the operations needed by E-INFRA OIDC extension.""" + import logging -from typing import Optional +from typing import Optional, Tuple import requests from requests.auth import HTTPBasicAuth @@ -19,8 +21,7 @@ class DoesNotExist(Exception): class PerunLowLevelAPI: - """ - Low-level API for Perun targeted at the operations needed by E-INFRA OIDC extension. + """Low-level API for Perun targeted at the operations needed by E-INFRA OIDC extension. Note: Perun does not follow RESTful principles and the API is thus not resource-oriented, but rather manager-oriented and spills out implementation details. This class provides @@ -29,9 +30,14 @@ class PerunLowLevelAPI: Note: All ids are internal Perun ids, not UUIDs or other external identifiers. """ - def __init__(self, base_url: str, service_id: int, service_username: str, service_password: str): - """ - Initialize the API with the base URL and the service credentials. + def __init__( + self, + base_url: str, + service_id: int, + service_username: str, + service_password: str, + ): + """Initialize the API with the base URL and the service credentials. :param base_url: URL of Perun server :param service_id: the id of the service that manages stuff @@ -42,14 +48,35 @@ def __init__(self, base_url: str, service_id: int, service_username: str, servic self._service_id = service_id self._auth = HTTPBasicAuth(service_username, service_password) - def _perun_call(self, manager: str, method: str, payload: dict) -> dict|list: + def _perun_call_dict(self, manager: str, method: str, payload: dict) -> dict: + """Low-level call to Perun API with error handling, call returns a dict. + + :param manager: the manager to call + :param method: the method to call + :param payload: the json payload to send + """ + ret = self._perun_call(manager, method, payload) + assert isinstance(ret, dict) + return ret + + def _perun_call_list(self, manager: str, method: str, payload: dict) -> list: + """Low-level call to Perun API with error handling, call returns a list of objects. + + :param manager: the manager to call + :param method: the method to call + :param payload: the json payload to send + """ + ret = self._perun_call(manager, method, payload) + assert isinstance(ret, list) + return ret + + def _perun_call(self, manager: str, method: str, payload: dict) -> dict | list: """Low-level call to Perun API with error handling. :param manager: the manager to call :param method: the method to call :param payload: the json payload to send """ - print("PerunCall", manager, method, payload) resp = requests.post( f"{self._base_url}/krb/rpc/json/{manager}/{method}", auth=self._auth, @@ -70,11 +97,15 @@ def _perun_call(self, manager: str, method: str, payload: dict) -> dict|list: return resp.json() def create_group( - self, *, name: str, description: str, parent_group_id: int, - parent_vo: int, check_existing: bool=True - ): - """ - Create a new group in Perun and set the service as its admin + self, + *, + name: str, + description: str, + parent_group_id: int, + parent_vo: int, + check_existing: bool = True, + ) -> tuple[dict, bool, bool]: + """Create a new group in Perun and set the service as its admin. :param name: Name of the group :param description: Description of the group @@ -88,6 +119,7 @@ def create_group( group_created = False admin_created = False + group: dict | None if check_existing: group = self.get_group_by_name(name, parent_group_id) else: @@ -97,7 +129,7 @@ def create_group( log.info("Creating group %s within parent %s", name, parent_group_id) # Create a new group in Perun - group = self._perun_call( + group = self._perun_call_dict( "groupsManager", "createGroup", { @@ -106,6 +138,7 @@ def create_group( "parentGroup": parent_group_id, }, ) + group_created = True log.info( "Group %s within parent %s created, id %s", @@ -144,9 +177,8 @@ def create_group( return (group, group_created, admin_created) - def get_group_by_name(self, name: str, parent_group_id: int) -> Optional[str]: - """ - Get a group by name within a parent group. + def get_group_by_name(self, name: str, parent_group_id: int) -> Optional[dict]: + """Get a group by name within a parent group. :param name: name of the group :param parent_group_id: ID of the parent group @@ -165,16 +197,15 @@ def create_resource_with_group_and_capabilities( *, vo_id: int, facility_id: int, - group_id : int, + group_id: int, name: str, description: str, capability_attr_id: int, capabilities: list[str], perun_sync_service_id: int, - check_existing: bool =True, - ) -> (dict, bool): - """ - Create a new resource in Perun and assign the group to it. + check_existing: bool = True, + ) -> Tuple[dict, bool]: + """Create a new resource in Perun and assign the group to it. :param vo_id: id of the virtual organization in within the resource is created :param facility_id: id of the facility for which the resource is created. The service have facility manager rights @@ -205,10 +236,14 @@ def create_resource_with_group_and_capabilities( return resource, resource_created def create_resource( - self, vo_id: int, facility_id: int, name: str, description: str, check_existing: bool=True - ) -> (dict, bool): - """ - Create a new resource in Perun, optionally checking if a resource with the same name already exists. + self, + vo_id: int, + facility_id: int, + name: str, + description: str, + check_existing: bool = True, + ) -> Tuple[dict, bool]: + """Create a new resource in Perun, optionally checking if a resource with the same name already exists. :param vo_id: id of the virtual organization in within the resource is created :param facility_id: id of the facility for which the resource is created @@ -230,7 +265,7 @@ def create_resource( facility_id, vo_id, ) - resource = self._perun_call( + resource = self._perun_call_dict( "resourcesManager", "createResource", { @@ -251,8 +286,7 @@ def create_resource( return resource, resource_created def assign_group_to_resource(self, resource_id: int, group_id: int) -> None: - """ - Assign a group to a resource. + """Assign a group to a resource. :param resource_id: id of the resource :param group_id: id of the group to be assigned @@ -276,17 +310,17 @@ def assign_group_to_resource(self, resource_id: int, group_id: int) -> None: ) log.info("Group %s assigned to resource %s", group_id, resource_id) - def set_resource_capabilities(self, resource_id: int, capability_attr_id: int, capabilities: list[str]) -> None: - """ - Set capabilities to a resource. + def set_resource_capabilities( + self, resource_id: int, capability_attr_id: int, capabilities: list[str] + ) -> None: + """Set capabilities to a resource. :param resource_id: id of the resource :param capability_attr_id: internal id of the attribute that holds the capabilities :param capabilities: list of capabilities to be set """ - # check if the resource has the capability and if not, add it - attr = self._perun_call( + attr = self._perun_call_dict( "attributesManager", "getAttribute", {"resource": resource_id, "attributeId": capability_attr_id}, @@ -305,8 +339,7 @@ def set_resource_capabilities(self, resource_id: int, capability_attr_id: int, c log.info("Capabilities %s set to resource %s", capabilities, resource_id) def attach_service_to_resource(self, resource_id: int, service_id: int) -> None: - """ - Attach a service to a resource. + """Attach a service to a resource. :param resource_id: id of the resource :param service_id: id of the service to be attached @@ -336,9 +369,10 @@ def attach_service_to_resource(self, resource_id: int, service_id: int) -> None: resource_id, ) - def get_resource_by_name(self, vo_id: int, facility_id: int, name: str) -> Optional[dict]: - """ - Get a resource by name. + def get_resource_by_name( + self, vo_id: int, facility_id: int, name: str + ) -> Optional[dict]: + """Get a resource by name. :param vo_id: id of the virtual organization :param facility_id: id of the facility for which a resource is created @@ -346,7 +380,7 @@ def get_resource_by_name(self, vo_id: int, facility_id: int, name: str) -> Optio :return: resource or None if not found """ try: - return self._perun_call( + return self._perun_call_dict( "resourcesManager", "getResourceByName", {"vo": vo_id, "facility": facility_id, "name": name}, @@ -354,9 +388,10 @@ def get_resource_by_name(self, vo_id: int, facility_id: int, name: str) -> Optio except DoesNotExist: return None - def get_resource_by_capability(self, *, vo_id: int, facility_id: int, capability: str) -> Optional[dict]: - """ - Get a resource by capability. + def get_resource_by_capability( + self, *, vo_id: int, facility_id: int, capability: str + ) -> Optional[dict]: + """Get a resource by capability. :param vo_id: id of the virtual organization :param facility_id: id of the facility where we search for resource @@ -383,8 +418,7 @@ def get_resource_by_capability(self, *, vo_id: int, facility_id: int, capability return matching_resources[0] def get_resource_groups(self, *, resource_id: int) -> list[dict]: - """ - Get groups assigned to a resource. + """Get groups assigned to a resource. :param resource_id: id of the resource :return: list of groups @@ -400,9 +434,10 @@ def get_resource_groups(self, *, resource_id: int) -> list[dict]: ) ] - def get_user_by_attribute(self, *, attribute_name: str, attribute_value: str) -> Optional[dict]: - """ - Get a user by attribute. + def get_user_by_attribute( + self, *, attribute_name: str, attribute_value: str + ) -> Optional[dict]: + """Get a user by attribute. :param attribute_name: name of the attribute :param attribute_value: value of the attribute @@ -421,9 +456,10 @@ def get_user_by_attribute(self, *, attribute_name: str, attribute_value: str) -> return None return users[0] - def remove_user_from_group(self, *, vo_id: int, user_id: int, group_id: int) -> None: - """ - Remove a user from a group. + def remove_user_from_group( + self, *, vo_id: int, user_id: int, group_id: int + ) -> None: + """Remove a user from a group. :param vo_id: id of the virtual organization :param user_id: internal perun id of the user @@ -438,8 +474,7 @@ def remove_user_from_group(self, *, vo_id: int, user_id: int, group_id: int) -> ) def add_user_to_group(self, *, vo_id: int, user_id: int, group_id: int) -> None: - """ - Add a user to a group. + """Add a user to a group. :param vo_id: id of the virtual organization :param user_id: internal perun id of the user @@ -447,7 +482,7 @@ def add_user_to_group(self, *, vo_id: int, user_id: int, group_id: int) -> None: """ member = self._get_or_create_member_in_vo(vo_id, user_id) - self._perun_call( + self._perun_call_dict( "groupsManager", "addMember", {"group": group_id, "member": member["id"]}, @@ -455,7 +490,7 @@ def add_user_to_group(self, *, vo_id: int, user_id: int, group_id: int) -> None: def _get_or_create_member_in_vo(self, vo_id: int, user_id: int) -> dict: # TODO: create part here (but we might not need it if everything goes through invitations) - member = self._perun_call( + member = self._perun_call_dict( "membersManager", "getMemberByUser", {"vo": vo_id, "user": user_id} ) return member @@ -471,8 +506,7 @@ def send_invitation( expiration: str, redirect_url: str, ) -> dict: - """ - Send an invitation to a user to join a group. + """Send an invitation to a user to join a group. :param vo_id: id of the virtual organization :param group_id: id of the group @@ -482,7 +516,7 @@ def send_invitation( :param expiration: expiration date of the invitation, format YYYY-MM-DD :param redirect_url: URL to redirect to after accepting the invitation """ - return self._perun_call( + return self._perun_call_dict( "invitationsManager", "inviteToGroup", { diff --git a/oarepo_oidc_einfra/perun/dump.py b/oarepo_oidc_einfra/perun/dump.py index 46bb683..f60fe86 100644 --- a/oarepo_oidc_einfra/perun/dump.py +++ b/oarepo_oidc_einfra/perun/dump.py @@ -5,12 +5,15 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""Dump data from the PERUN.""" + import dataclasses import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from datetime import UTC, datetime from functools import cached_property -from typing import Any, Dict, Iterable, List, Set +from typing import Dict, Iterable, List, Set +from uuid import UUID import boto3 from flask import current_app @@ -19,9 +22,11 @@ log = logging.getLogger("perun.dump_data") + @dataclasses.dataclass(frozen=True) class AAIUser: """A user with their roles as received from the Perun AAI.""" + einfra_id: str email: str full_name: str @@ -30,18 +35,15 @@ class AAIUser: class PerunDumpData: - """ - Provides access to the data from the PERUN dump. - """ + """Provides access to the data from the PERUN dump.""" def __init__( self, - dump_data: Any, - community_slug_to_id: Dict[str, str], + dump_data: dict, + community_slug_to_id: Dict[str, UUID], community_role_names: Set[str], ): - """ - Creates an instance of the data + """Create an instance of the data. :param dump_data: The data from the PERUN dump (json) :param community_slug_to_id: Mapping of community slugs to their ids (str of uuid) @@ -53,8 +55,8 @@ def __init__( @cached_property def aai_community_roles(self) -> Set[CommunityRole]: - """ - Returns all community roles (pairs of community id, role name) from the dump. + """Return all community roles from the dump. + :return: set of community roles known to perun """ aai_community_roles = set() @@ -64,8 +66,7 @@ def aai_community_roles(self) -> Set[CommunityRole]: @cached_property def resource_to_community_roles(self) -> Dict[str, List[CommunityRole]]: - """ - Returns a mapping of resource id to community roles. + """Returns a mapping of resource id to community roles. :return: for each Perun resource, mapping to associated community roles """ @@ -100,14 +101,15 @@ def resource_to_community_roles(self) -> Dict[str, List[CommunityRole]]: if role not in self.community_role_names: log.error(f"Role from PERUN {role} not found in the repository") continue - community_role = CommunityRole(self.slug_to_id[community_slug], role) + community_role = CommunityRole( + self.slug_to_id[community_slug], role + ) resources[r_id].append(community_role) return resources def users(self) -> Iterable[AAIUser]: - """ - Returns all users from the dump. + """Return all users from the dump. :return: iterable of AAIUser """ @@ -135,8 +137,7 @@ def users(self) -> Iterable[AAIUser]: def _get_roles_for_resources( self, allowed_resources: Iterable[str] ) -> Set[CommunityRole]: - """ - Returns community roles for an iterable of allowed resources. + """Return community roles for an iterable of allowed resources. :param allowed_resources: iterable of resource ids :return: a set of associated community roles @@ -148,8 +149,10 @@ def _get_roles_for_resources( def import_dump_file(data: bytes) -> str: - """ - Imports a dump file from the input stream into S3 and returns file name + """Import a dump file from the input stream into S3 and return file name. + + :param data: data to be imported + :return: path to the object in S3 """ client = boto3.client( "s3", diff --git a/oarepo_oidc_einfra/perun/mapping.py b/oarepo_oidc_einfra/perun/mapping.py index a2dd6a7..a2de53c 100644 --- a/oarepo_oidc_einfra/perun/mapping.py +++ b/oarepo_oidc_einfra/perun/mapping.py @@ -5,6 +5,9 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""Mapping between perun capabilities and Invenio roles.""" + +import dataclasses from typing import Dict, Optional from invenio_accounts.models import UserIdentity @@ -12,9 +15,8 @@ from sqlalchemy import select -def get_perun_capability_from_invenio_role(slug, role): - """ - Get the capability name from the Invenio role. +def get_perun_capability_from_invenio_role(slug: str, role: str) -> str: + """Get the capability name from the Invenio role. :param slug: slug of the community :param role: role in the community @@ -23,30 +25,37 @@ def get_perun_capability_from_invenio_role(slug, role): return f"res:communities:{slug}:role:{role}" -def get_invenio_role_from_capability(capability: str | list): - """ - Get the Invenio role from the capability. +@dataclasses.dataclass +class SlugCommunityRole: + """A class representing a community slug and a role.""" + + slug: str + """Community slug.""" + + role: str + """Role name.""" + + +def get_invenio_role_from_capability(capability: str | list) -> SlugCommunityRole: + """Get the Invenio role from the capability. :param capability: capability name :return: (slug, role) """ - if isinstance(capability, str): - parts = capability.split(":") - else: - parts = capability + parts = capability.split(":") if isinstance(capability, str) else capability + if ( len(parts) == 5 and parts[0] == "res" and parts[1] == "communities" and parts[3] == "role" ): - return parts[2], parts[4] + return SlugCommunityRole(parts[2], parts[4]) raise ValueError(f"Not an invenio role capability: {capability}") def get_user_einfra_id(user_id: int) -> Optional[str]: - """ - Get e-infra identity for user with given id. + """Get e-infra identity for user with given id. :param user_id: user id :return: e-infra identity or None if user has no e-infra identity associated @@ -60,9 +69,9 @@ def get_user_einfra_id(user_id: int) -> Optional[str]: def einfra_to_local_users_map() -> Dict[str, int]: - """ - Returns a mapping of e-infra id to user id for local users, that have e-infra identity - and logged at least once with it. + """Return a mapping of e-infra id to user id for local users. + + Only users that have e-infra identity and logged at least once with it re returned :return: a mapping of e-infra id to user id """ diff --git a/oarepo_oidc_einfra/perun/oidc.py b/oarepo_oidc_einfra/perun/oidc.py index 53dadaf..003281b 100644 --- a/oarepo_oidc_einfra/perun/oidc.py +++ b/oarepo_oidc_einfra/perun/oidc.py @@ -6,6 +6,7 @@ # details. # """OIDC utilities.""" + import logging from typing import Set @@ -13,14 +14,13 @@ from urnparse import URN8141, InvalidURNFormatError from ..communities import CommunityRole, CommunitySupport -from .mapping import get_invenio_role_from_capability +from .mapping import SlugCommunityRole, get_invenio_role_from_capability log = logging.getLogger(__name__) -def get_communities_from_userinfo_token(userinfo_token) -> Set[CommunityRole]: - """ - Extracts communities and roles from userinfo token. +def get_communities_from_userinfo_token(userinfo_token: dict) -> Set[CommunityRole]: + """Extract communities and roles from userinfo token. :param userinfo_token: userinfo token from perun/oidc server :return: a set of community roles associated with the user @@ -50,11 +50,13 @@ def get_communities_from_userinfo_token(userinfo_token) -> Set[CommunityRole]: if not parts or parts[0] != current_app.config["EINFRA_ENTITLEMENT_PREFIX"]: continue try: - community_slug, role = get_invenio_role_from_capability(parts[1:]) - if role not in community_roles: - log.error(f"Role {role} not found in community roles in urn {urn}") + slug_role: SlugCommunityRole = get_invenio_role_from_capability(parts[1:]) + if slug_role.role not in community_roles: + log.error( + f"Role {slug_role.role} not found in community roles in urn {urn}" + ) continue - aai_groups.add((slug_to_id[community_slug], role)) + aai_groups.add(CommunityRole(slug_to_id[slug_role.slug], slug_role.role)) except ValueError: continue diff --git a/oarepo_oidc_einfra/proxies.py b/oarepo_oidc_einfra/proxies.py index 11b8097..9c079d7 100644 --- a/oarepo_oidc_einfra/proxies.py +++ b/oarepo_oidc_einfra/proxies.py @@ -7,8 +7,15 @@ # """Helper proxy to the state object.""" +from typing import TYPE_CHECKING + from flask import current_app from werkzeug.local import LocalProxy -current_einfra_oidc = LocalProxy(lambda: current_app.extensions["einfra-oidc"]) +if TYPE_CHECKING: + from oarepo_oidc_einfra.ext import EInfraOIDCApp + +current_einfra_oidc: "EInfraOIDCApp" = ( + LocalProxy["EInfraOIDCApp"](lambda: current_app.extensions["einfra-oidc"]) # type: ignore +) """Helper proxy to get the current einfra oidc.""" diff --git a/oarepo_oidc_einfra/remote.py b/oarepo_oidc_einfra/remote.py index 3046686..360c34b 100644 --- a/oarepo_oidc_einfra/remote.py +++ b/oarepo_oidc_einfra/remote.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (C) 2024 CESNET z.s.p.o. # @@ -6,15 +5,18 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""E-Infra OIDC Remote Auth backend for NRP.""" import datetime import jwt +from flask_oauthlib.client import OAuthRemoteApp from invenio_accounts.models import User, UserIdentity from invenio_db import db from invenio_oauthclient import current_oauthclient from invenio_oauthclient.contrib.settings import OAuthSettingsHelper from invenio_oauthclient.handlers.token import token_getter +from invenio_oauthclient.models import RemoteToken from invenio_oauthclient.oauth import oauth_get_user from invenio_oauthclient.signals import account_info_received @@ -25,22 +27,22 @@ class EInfraOAuthSettingsHelper(OAuthSettingsHelper): def __init__( self, *, - title="E-Infra AAI", - description="E-Infra authentication and authorization service.", - base_url="https://login.e-infra.cz/oidc/", - app_key="EINFRA", - icon=None, - access_token_url=None, - authorize_url=None, - access_token_method="POST", - request_token_params=None, - request_token_url=None, - precedence_mask=None, - signup_options=None, - logout_url=None, - **kwargs, + title: str = "E-Infra AAI", + description: str = "E-Infra authentication and authorization service.", + base_url: str = "https://login.e-infra.cz/oidc/", + app_key: str = "EINFRA", + icon: str | None = None, + access_token_url: str | None = None, + authorize_url: str | None = None, + access_token_method: str | None = "POST", + request_token_params: dict | None = None, + request_token_url: str | None = None, + precedence_mask: str | None = None, + signup_options: dict | None = None, + logout_url: str | None = None, + **kwargs: dict, ): - + """Initialize the E-Infra OIDC Remote Auth backend for NRP.""" request_token_params = request_token_params or { "scope": " ".join( [ @@ -100,11 +102,11 @@ def __init__( error_redirect_url="/", ) - def get_handlers(self): + def get_handlers(self) -> dict: """Return CESNET auth handlers.""" return self._handlers - def get_rest_handlers(self): + def get_rest_handlers(self) -> dict: """Return CESNET auth REST handlers.""" return self._rest_handlers @@ -117,9 +119,8 @@ def get_rest_handlers(self): EINFRA_LOGIN_APP = _cesnet_app.remote_app -def account_info_serializer(remote, resp): - """ - Serialize the account info response object. +def account_info_serializer(remote: OAuthRemoteApp, resp: dict) -> dict: + """Serialize the account info response object. :param remote: The remote application. :param resp: The response of the `authorized` endpoint. @@ -146,9 +147,8 @@ def account_info_serializer(remote, resp): } -def account_info(remote, resp): - """ - Retrieve remote account information used to find local user. +def account_info(remote: OAuthRemoteApp, resp: dict) -> dict: + """Retrieve remote account information used to find local user. It returns a dictionary with the following structure: { @@ -172,9 +172,8 @@ def account_info(remote, resp): return handler_resp -def account_setup(remote, token, resp): - """ - Perform additional setup after user have been logged in. +def account_setup(remote: OAuthRemoteApp, token: RemoteToken, resp: dict) -> None: + """Perform additional setup after user have been logged in. :param remote: The remote application. :param token: The token value. @@ -207,7 +206,19 @@ def account_setup(remote, token, resp): # During overlay initialization. @account_info_received.connect -def autocreate_user(remote, token=None, response=None, account_info=None): +def autocreate_user( + remote: OAuthRemoteApp, + token: RemoteToken | None = None, + response: dict | None = None, + account_info: dict | None = None, +) -> None: + """Create a user if it does not exist. + + :param remote: The remote application. + :param token: access token + :param response: access response from the remote server + :param account_info: account info from the remote server + """ assert account_info is not None email = account_info["user"]["email"].lower() @@ -252,7 +263,15 @@ def autocreate_user(remote, token=None, response=None, account_info=None): db.session.commit() -def account_info_link_perun_groups(remote, *, account_info, **kwargs): +def account_info_link_perun_groups( + remote: OAuthRemoteApp, *, account_info: dict, **kwargs: dict +) -> None: + """Set local user community membership based on the Perun groups retrieved from the userinfo token. + + :param remote: The remote application. + :param account_info: The account info of the current user + :param kwargs: Additional arguments (not used) + """ # make the import local to avoud circular imports from oarepo_oidc_einfra.communities import CommunitySupport from oarepo_oidc_einfra.perun import get_communities_from_userinfo_token diff --git a/oarepo_oidc_einfra/resources.py b/oarepo_oidc_einfra/resources.py index 9c123f9..4955d03 100644 --- a/oarepo_oidc_einfra/resources.py +++ b/oarepo_oidc_einfra/resources.py @@ -7,10 +7,14 @@ # """REST resources.""" + +from __future__ import annotations + import logging from datetime import UTC, datetime +from typing import TYPE_CHECKING, Optional -from flask import current_app, g, request +from flask import Blueprint, Flask, current_app, g, redirect, request from flask_login import login_required from flask_principal import PermissionDenied from flask_resources import Resource, ResourceConfig, route @@ -26,6 +30,9 @@ from oarepo_oidc_einfra.encryption import decrypt from oarepo_oidc_einfra.tasks import update_from_perun_dump +if TYPE_CHECKING: + from werkzeug import Response + log = logging.getLogger(__name__) @@ -51,13 +58,11 @@ class OIDCEInfraResourceConfig(ResourceConfig): class OIDCEInfraResource(Resource): """REST API for the EInfra OIDC.""" - def __init__(self, config=None): + def __init__(self, config: Optional[OIDCEInfraResourceConfig] = None): """Initialize the resource.""" - super(OIDCEInfraResource, self).__init__( - config=config or OIDCEInfraResourceConfig() - ) + super().__init__(config=config or OIDCEInfraResourceConfig()) - def create_url_rules(self): + def create_url_rules(self) -> list[dict]: """Create URL rules for the resource.""" routes = self.config.routes return [ @@ -65,7 +70,7 @@ def create_url_rules(self): route("GET", routes["accept-invitation"], self.accept_invitation), ] - def upload_dump(self): + def upload_dump(self) -> tuple[dict, int]: """Upload a dump of the EInfra data. The dump will be uploaded to the configured location (EINFRA_DUMP_DATA_URL inside config) @@ -95,8 +100,10 @@ def upload_dump(self): return {"status": "ok"}, 201 @login_required - def accept_invitation(self): - """Accept an invitation to join a community. This is an endpoint to which user is directed + def accept_invitation(self) -> Response: + """Accept an invitation to join a community. + + This is an endpoint to which user is directed after clicking the link in the invitation email, accepting the terms and conditions and accepting the invitation. @@ -104,14 +111,16 @@ def accept_invitation(self): and use it to accept the invitation. Note: - If user accepts the invitation but this endpoint is not called, the invitation will be forever in the submitted state (until expiration). The user will still be able to access the community because the AAI will return the correct capabilities for the user. Currently, the PERUN api does not return the ID of the created invitation, so we cannot store it and check in a background task if the invitation was accepted and then change the state of the request. + """ + assert request.view_args is not None + request_id = decrypt(request.view_args["request_id"]) # get the invitation request and check if it is submitted. @@ -143,8 +152,9 @@ def accept_invitation(self): invitation_request.commit() current_requests_service.execute_action(system_identity, request_id, "accept") + return redirect("/") -def create_rest_blueprint(app): +def create_rest_blueprint(app: Flask) -> Blueprint: """Create a blueprint for the REST API.""" return OIDCEInfraResource().as_blueprint() diff --git a/oarepo_oidc_einfra/services/components/aai_communities.py b/oarepo_oidc_einfra/services/components/aai_communities.py index d9661ef..727a932 100644 --- a/oarepo_oidc_einfra/services/components/aai_communities.py +++ b/oarepo_oidc_einfra/services/components/aai_communities.py @@ -5,22 +5,30 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # -"""AAI (perun) communities mapping""" +"""AAI (perun) communities mapping.""" import re +from typing import Optional from flask import current_app +from invenio_access.permissions import Identity +from invenio_communities.communities.records.api import Community from invenio_records_resources.services.records.components.base import ServiceComponent -from invenio_records_resources.services.uow import Operation +from invenio_records_resources.services.uow import Operation, UnitOfWork class PropagateToAAIOp(Operation): """Operation to propagate community to AAI in a background process.""" - def __init__(self, community): + def __init__(self, community: Community): + """Create a new operation.""" self.community = community - def on_post_commit(self, uow): + def on_post_commit(self, uow: UnitOfWork) -> None: + """Propagate the community to AAI. + + :param uow: unit of work + """ from oarepo_oidc_einfra.tasks import synchronize_community_to_perun synchronize_community_to_perun.delay(self.community.id) @@ -29,9 +37,27 @@ def on_post_commit(self, uow): class CommunityAAIComponent(ServiceComponent): """Community AAI component that propagates the community to Perun.""" - def create(self, identity, record=None, data=None, **kwargs): - """Create handler.""" + def create( + self, + identity: Identity, + *, + record: Optional[Community] = None, + data: Optional[dict] = None, + **kwargs: dict, + ) -> None: + """Create handler. + + This handler schedules the community to be propagated to AAI if the configuration + (EINFRA_COMMUNITY_SYNCHRONIZATION) allows it. + + :param identity: identity of the user + :param record: community record to be created + :param data: data to be created + :param kwargs: additional arguments + """ # propagate the community to AAI + assert data is not None + if "slug" not in data: raise ValueError("Missing slug in community data") if not re.match("^[a-z0-9-]+$", data["slug"]): @@ -42,13 +68,39 @@ def create(self, identity, record=None, data=None, **kwargs): if current_app.config["EINFRA_COMMUNITY_SYNCHRONIZATION"]: self.uow.register(PropagateToAAIOp(record)) - def update(self, identity, record=None, data=None, **kwargs): - """Update handler.""" + def update( + self, + identity: Identity, + record: Optional[Community] = None, + data: Optional[dict] = None, + **kwargs: dict, + ) -> None: + """Update handler. + + This handler prevents changing community slug as it is used as a key in AAI capabilities. + + :param identity: identity of the user + :param record: community record to be updated + :param data: data to be updated + :param kwargs: additional arguments + """ + assert data is not None + assert record is not None + if record.slug != data["slug"]: raise ValueError( "Cannot change the slug of the community as it is used in AAI" ) - def delete(self, identity, record=None, **kwargs): - """Delete handler.""" + def delete( + self, identity: Identity, record: Optional[Community] = None, **kwargs: dict + ) -> None: + """Delete handler. + + At this time, we do not want to delete communities in AAI, so we raise an error. + + :param identity: identity of the user + :param record: community record to be deleted + :param kwargs: additional arguments + """ raise NotImplementedError("Delete is not supported at the time being") diff --git a/oarepo_oidc_einfra/services/components/aai_invitations.py b/oarepo_oidc_einfra/services/components/aai_invitations.py index 3b6b26c..d1cc707 100644 --- a/oarepo_oidc_einfra/services/components/aai_invitations.py +++ b/oarepo_oidc_einfra/services/components/aai_invitations.py @@ -5,18 +5,19 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # -"""AAI (perun) membership handling""" +"""AAI (perun) membership handling.""" from flask import current_app -from invenio_access.permissions import system_identity +from invenio_access.permissions import Identity, system_identity from invenio_accounts.models import User from invenio_communities.communities.records.api import Community from invenio_communities.members.records.api import Member from invenio_communities.members.services.service import invite_expires_at from invenio_records_resources.services.records.components.base import ServiceComponent -from invenio_records_resources.services.uow import Operation +from invenio_records_resources.services.uow import Operation, UnitOfWork from invenio_requests.customizations.event_types import CommentEventType from invenio_requests.proxies import current_events_service, current_requests_service +from invenio_requests.services.requests.results import RequestItem from invenio_users_resources.proxies import current_users_service from oarepo_runtime.i18n import lazy_gettext as _ @@ -26,10 +27,15 @@ class CreateAAIInvitationOp(Operation): """Operation to create an invitation within AAI in a background process.""" - def __init__(self, membership_request_id): + def __init__(self, membership_request_id: str): + """Create a new operation. + + :param membership_request_id: id of the membership request + """ self.membership_request_id = membership_request_id - def on_post_commit(self, uow): + def on_post_commit(self, uow: UnitOfWork) -> None: + """Create an invitation in AAI.""" from oarepo_oidc_einfra.tasks import create_aai_invitation if current_app.config["EINFRA_COMMUNITY_INVITATION_SYNCHRONIZATION"]: @@ -40,10 +46,30 @@ class AAIInvitationComponent(ServiceComponent): """Community AAI component that creates invitations within Perun AAI.""" def members_invite( - self, identity, *, record, community, errors, role, visible, message, **kwargs - ): - """Handler for member invitation.""" - + self, + identity: Identity, + *, + record: Member, + community: Community, + errors: dict, + role: str, + visible: bool, + message: str, + **kwargs: dict, + ) -> None: + """Invite a new member to a community. + + Will create an invitation in AAI as well. + + :param identity: identity of the user performing the operation + :param record: member record + :param community: community record in which the member is being invited + :param errors: errors that occurred during the pre-invitation operation + :param role: role of the member in the community + :param visible: visibility of the member in the community + :param message: message to be sent to the member + :param kwargs: additional arguments (not used) + """ member = record member_email = member.get("email") @@ -82,8 +108,22 @@ def members_invite( self.uow.register(CreateAAIInvitationOp(request_item["id"])) def members_update( - self, identity, *, record: Member, community: Community, **kwargs - ): + self, + identity: Identity, + *, + record: Member, + community: Community, + **kwargs: dict, + ) -> None: + """Update a member in AAI. + + This callback will, if enabled in the configuration, update the member in the AAI. + + :param identity: identity of the user performing the operation + :param record: member record + :param community: community record in which the member is being updated + :param kwargs: additional arguments (not used) + """ from oarepo_oidc_einfra.tasks import change_aai_role if not record.user_id: @@ -99,8 +139,22 @@ def members_update( change_aai_role(community.slug, record.user_id, record.role) def members_delete( - self, identity, *, record: Member, community: Community, **kwargs - ): + self, + identity: Identity, + *, + record: Member, + community: Community, + **kwargs: dict, + ) -> None: + """Remove a member from AAI. + + This callback will, if enabled in the configuration, remove the member from the AAI. + + :param identity: identity of the user performing the operation + :param record: member record + :param community: community record from which the member is being removed + :param kwargs: additional arguments (not used) + """ from oarepo_oidc_einfra.tasks import remove_aai_user_from_community if not record.user_id: @@ -115,7 +169,15 @@ def members_delete( # would be reverted. remove_aai_user_from_community(community.slug, record.user_id) - def _add_invitation_message_to_request(self, identity, request_item, message): + def _add_invitation_message_to_request( + self, identity: Identity, request_item: RequestItem, message: str + ) -> None: + """Add a message to the invitation request. + + :param identity: identity of the user adding message to the request + :param request_item: request item, result of the _create_invitation_request + :param message: message to be added to the request + """ data = {"payload": {"content": message}} current_events_service.create( identity, @@ -126,7 +188,16 @@ def _add_invitation_message_to_request(self, identity, request_item, message): notify=False, ) - def _create_invitation_request(self, identity, community, user_id, role): + def _create_invitation_request( + self, identity: Identity, community: Community, user_id: int, role: str + ) -> RequestItem: + """Create an invitation request in the repository. + + :param identity: identity of the user creating the request + :param community: community record + :param user_id: user id + :param role: role of the user in the community + """ title = _('Invitation to join "{community}"').format( community=community.metadata["title"], ) @@ -143,7 +214,17 @@ def _create_invitation_request(self, identity, community, user_id, role): ) return request_item - def _get_invitation_user(self, member_email, member_first_name, member_last_name): + def _get_invitation_user( + self, member_email: str, member_first_name: str, member_last_name: str + ) -> User: + """Get user id for the invitation. + + If the user with the email already exists, return its id. If not, create a new user. + + :param member_email: email of the member + :param member_first_name: first name of the member + :param member_last_name: last name of the member + """ u = User.query.filter_by(email=member_email.lower()).one_or_none() if u: return u.id diff --git a/oarepo_oidc_einfra/services/requests/__init__.py b/oarepo_oidc_einfra/services/requests/__init__.py index 4f563a3..6affaf3 100644 --- a/oarepo_oidc_einfra/services/requests/__init__.py +++ b/oarepo_oidc_einfra/services/requests/__init__.py @@ -5,3 +5,4 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""AAI backed requests.""" diff --git a/oarepo_oidc_einfra/services/requests/invitation.py b/oarepo_oidc_einfra/services/requests/invitation.py index 0b1db57..8a7ee3f 100644 --- a/oarepo_oidc_einfra/services/requests/invitation.py +++ b/oarepo_oidc_einfra/services/requests/invitation.py @@ -5,14 +5,18 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""AAI backed invitation request.""" + from invenio_communities.members.services.request import CommunityInvitation from oarepo_runtime.i18n import lazy_gettext as _ class AAICommunityInvitation(CommunityInvitation): + """AAI backed invitation request.""" + type_id = "aai-community-invitation" name = _("AAI Community invitation") # there is no invenio receiver for this type as it is handled by the AAI receiver_can_be_none = True - allowed_receiver_ref_types = [] + allowed_receiver_ref_types: list[str] = [] diff --git a/oarepo_oidc_einfra/tasks.py b/oarepo_oidc_einfra/tasks.py index 67d2b2d..c7f834a 100644 --- a/oarepo_oidc_einfra/tasks.py +++ b/oarepo_oidc_einfra/tasks.py @@ -5,16 +5,16 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # -""" -Background tasks. -""" +"""Background tasks.""" + +from __future__ import annotations + import json import logging from io import BytesIO from itertools import chain, islice -from typing import List, Tuple +from typing import TYPE_CHECKING, Iterable, Literal from urllib.parse import urljoin -from uuid import UUID import boto3 from celery import shared_task @@ -29,18 +29,26 @@ from oarepo_oidc_einfra.encryption import encrypt from oarepo_oidc_einfra.mutex import mutex from oarepo_oidc_einfra.perun.dump import PerunDumpData -from oarepo_oidc_einfra.perun.mapping import einfra_to_local_users_map, \ - get_perun_capability_from_invenio_role, get_user_einfra_id +from oarepo_oidc_einfra.perun.mapping import ( + einfra_to_local_users_map, + get_perun_capability_from_invenio_role, + get_user_einfra_id, +) from oarepo_oidc_einfra.proxies import current_einfra_oidc +if TYPE_CHECKING: + from uuid import UUID + + from oarepo_oidc_einfra.perun import PerunLowLevelAPI + log = logging.getLogger("PerunSynchronizationTask") @shared_task @mutex("EINFRA_SYNC_MUTEX") -def synchronize_community_to_perun(community_id) -> None: - """ - Synchronizes community into Perun groups and resources. +def synchronize_community_to_perun(community_id: str) -> None: + """Synchronize community into Perun groups and resources. + The call is idempotent, if the perun mapping already exists, it is left untouched. @@ -95,19 +103,19 @@ def synchronize_community_to_perun(community_id) -> None: def map_community_or_role( - api, + api: PerunLowLevelAPI, *, - parent_id, - parent_vo, - name, - description, - resource_name, - resource_description, - resource_capabilities, -): - """ - Map a single community or community role, adds synchronization service so that we get - the resource in the dump from perun. + parent_id: int, + parent_vo: int, + name: str, + description: str, + resource_name: str, + resource_description: str, + resource_capabilities: list[str], +) -> tuple[dict, dict]: + """Map a single community or community role to perun's groups and resources. + + The call adds synchronization service so that we get the resource in the dump from perun. :param api: perun api :param parent_id: parent group @@ -141,20 +149,24 @@ def map_community_or_role( @shared_task -def synchronize_all_communities_to_perun(): - """ - Checks and repairs community mapping within perun - """ +def synchronize_all_communities_to_perun() -> None: + """Check and repair community mapping within perun.""" for community_model in Community.model_cls.query.all(): synchronize_community_to_perun(str(community_model.id)) @shared_task @mutex("EINFRA_SYNC_MUTEX") -def update_from_perun_dump(dump_path, fix_communities_in_perun=True): - """ - Updates user communities from perun dump and checks for local communities - not propagated to perun yet (and propagates them) +def update_from_perun_dump( + dump_path: str, fix_communities_in_perun: bool = True +) -> None: + """Update user communities from perun dump and propagate local communities that are not in perun yet. + + The dump with perun data is downloaded from the S3 storage and the users are synchronized + with the database. + + Note: we suppose that the dump is small enough to be processed in a single task and the processing + will take less than 1 hour (the default task timeout inside the mutex). :param dump_path: url with the dump :param fix_communities_in_perun if some local communities were not propagated to perun, propagate them @@ -188,38 +200,54 @@ def update_from_perun_dump(dump_path, fix_communities_in_perun=True): synchronize_users_from_perun(dump, community_support) -def synchronize_communities_to_perun(repository_community_roles, aai_community_roles): - resource_community_roles: List[Tuple[str, str]] +def synchronize_communities_to_perun( + repository_community_roles: set[CommunityRole], + aai_community_roles: set[CommunityRole], +) -> None: + """Synchronize communities to perun if they do not exist in perun yet. + :param repository_community_roles: set of community roles from the repository + :param aai_community_roles: set of community roles from the perun dump + """ if repository_community_roles - aai_community_roles: log.info( "Some community roles are not mapped " f"to any resource: {repository_community_roles - aai_community_roles}" ) - unsynchronized_communities = { + communities_not_in_perun = { str(cr.community_id) for cr in repository_community_roles - aai_community_roles } - for community_id in unsynchronized_communities: + for community_id in communities_not_in_perun: synchronize_community_to_perun(community_id) +def chunks[T](iterable: Iterable[T], size: int = 10) -> Iterable[chain[T]]: + """Split the iterable into chunks of the given size. - -def chunks(iterable, size=10): + :param iterable: an iterable that will be split to chunks + :param size: size of the chunk + """ iterator = iter(iterable) for first in iterator: yield chain([first], islice(iterator, size - 1)) -def synchronize_users_from_perun(dump, community_support): +def synchronize_users_from_perun( + dump: PerunDumpData, community_support: CommunitySupport +) -> None: + """Synchronize users from perun dump to the database. + + :param dump: perun dump data + :param community_support: community support object + """ local_users_by_einfra = einfra_to_local_users_map() print([aai_user.email for aai_user in dump.users()]) for aai_user_chunk in chunks(dump.users(), 100): aai_user_chunk_by_einfra_id = {u.einfra_id: u for u in aai_user_chunk} local_user_id_to_einfra_id = {} - for einfra_id in aai_user_chunk_by_einfra_id.keys(): + for einfra_id in aai_user_chunk_by_einfra_id: local_user_id = local_users_by_einfra.pop(einfra_id, None) if local_user_id: local_user_id_to_einfra_id[local_user_id] = einfra_id @@ -242,12 +270,15 @@ def synchronize_users_from_perun(dump, community_support): for user in local_users: aai_user = aai_user_chunk_by_einfra_id[local_user_id_to_einfra_id[user.id]] + log.info("Setting user %s with roles %s", user, aai_user.roles) print("Setting user", user, aai_user.roles) update_user_metadata( user, aai_user.full_name, aai_user.email, aai_user.organization ) - new_community_roles = filter_community_roles(community_support, aai_user.roles) + new_community_roles = filter_community_roles( + community_support, aai_user.roles + ) community_support.set_user_community_membership( user, @@ -260,26 +291,43 @@ def synchronize_users_from_perun(dump, community_support): # for users that are not in the dump anymore, remove all communities for local_user_id in local_users_by_einfra.values(): user = User.query.filter_by(id=local_user_id).one() - print("Removing obsolete user", user) + log.info("Removing obsolete user %s", user) community_support.set_user_community_membership(user, set()) -def filter_community_roles(community_support, aai_roles): - new_community_roles = {} - for community_id, role in aai_roles: - community_id = UUID(community_id) - if community_id not in new_community_roles or ( - community_support.role_priority(role) +def filter_community_roles( + community_support: CommunitySupport, aai_roles: Iterable[CommunityRole] +) -> set[CommunityRole]: + """Filter community roles to keep only the most important role for each community. + + :param community_support: community support object + :param aai_roles: an iterable community roles + """ + new_community_roles: dict[UUID, CommunityRole] = {} + + for community_role in aai_roles: + if community_role.community_id not in new_community_roles or ( + community_support.role_priority(community_role.role) > community_support.role_priority( - new_community_roles[community_id].role - ) - ): - new_community_roles[community_id] = CommunityRole( - community_id, role + new_community_roles[community_role.community_id].role ) + ): + new_community_roles[community_role.community_id] = community_role return set(new_community_roles.values()) -def update_user_metadata(user, full_name, email, organization): + +def update_user_metadata( + user: User, full_name: str, email: str, organization: str +) -> None: + """Update user metadata in the database. + + If the data is the same, nothing is updated. + + :param user: user object + :param full_name: full name + :param email: email + :param organization: organization + """ save = False user_profile = user.user_profile if full_name != user.user_profile.get("full_name"): @@ -299,7 +347,12 @@ def update_user_metadata(user, full_name, email, organization): @shared_task -def create_aai_invitation(request_id): +def create_aai_invitation(request_id: str) -> dict | None: + """Create an invitation in AAI for an invenio invitation request. + + :param request_id: id of the invenio invitation request + :return: invitation data as returned from perun + """ perun_api = current_einfra_oidc.perun_api() request = Request.get_record(request_id) @@ -308,13 +361,18 @@ def create_aai_invitation(request_id): capability = get_perun_capability_from_invenio_role( request.topic.slug, invitation.role ) - group = perun_api.get_resource_by_capability(capability) + group = perun_api.get_resource_by_capability( + vo_id=current_app.config["EINFRA_REPOSITORY_VO_ID"], + facility_id=current_app.config["EINFRA_REPOSITORY_FACILITY_ID"], + capability=capability, + ) + if not group: log.error( f"Resource for capability {capability} not found inside Perun, " f"so can not send invitation to its associated group." ) - return + return None encrypted_request_id = encrypt(request_id) @@ -326,7 +384,7 @@ def create_aai_invitation(request_id): ) email = invitation.user.email - perun_api.send_invitation( + return perun_api.send_invitation( vo_id=current_app.config["EINFRA_REPOSITORY_VO_ID"], group_id=group["id"], email=email, @@ -338,25 +396,51 @@ def create_aai_invitation(request_id): @shared_task -def change_aai_role(community_slug, user_id, new_role): +def change_aai_role(community_slug: str, user_id: int, new_role: str) -> None: + """Propagate changed community role to AAI. + + :param community_slug: community slug + :param user_id: user id (internal) + :param new_role: new role name + """ remove_aai_user_from_community(community_slug, user_id) add_aai_role(community_slug, user_id, new_role) @shared_task -def remove_aai_user_from_community(community_slug, user_id): +def remove_aai_user_from_community(community_slug: str, user_id: int) -> None: + """Remove user from perun group representing a community. + + :param community_slug: community slug + :param user_id: user id + """ for role in CommunitySupport().role_names: aai_group_op("remove_user_from_group", community_slug, user_id, role) @shared_task -def add_aai_role(community_slug, user_id, role): +def add_aai_role(community_slug: str, user_id: int, role: str) -> None: + """Add user to perun group representing a community and a role. + + :param community_slug: community slug + :param user_id: user id + :param role: role name + """ aai_group_op("add_user_to_group", community_slug, user_id, role) -def aai_group_op(op, community_slug, user_id, role): - """ - Universal function for adding/removing user from group in AAI +def aai_group_op( + op: Literal["add_user_to_group", "remove_user_from_group"], + community_slug: str, + user_id: int, + role: str, +) -> None: + """Universal function for adding/removing user from group in AAI. + + :param op: operation to perform (add_user_to_group, remove_user_from_group) + :param community_slug: community slug + :param user_id: user id + :param role: role name """ perun_api = current_einfra_oidc.perun_api() @@ -379,7 +463,7 @@ def aai_group_op(op, community_slug, user_id, role): return user = perun_api.get_user_by_attribute( - attribute_name=current_app.config("EINFRA_USER_EINFRAID_ATTRIBUTE"), + attribute_name=current_app.config["EINFRA_USER_EINFRAID_ATTRIBUTE"], attribute_value=einfra_id, ) if user is None: @@ -390,5 +474,5 @@ def aai_group_op(op, community_slug, user_id, role): return # 2. for each group, perform the operation on it - for group in perun_api.get_resource_groups(resource["id"]): - getattr(perun_api, "op")(user["id"], group["id"]) + for group in perun_api.get_resource_groups(resource_id=resource["id"]): + getattr(perun_api, op)(user["id"], group["id"]) diff --git a/pyproject.toml b/pyproject.toml index 8289f07..3391905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,32 @@ [build_system] requires = ["setuptools", "wheel", "babel>2.8"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[tool.ruff.lint] +extend-select = [ + "UP", # pyupgrade + "D", # pydocstyle + "B", # flake8-bugbear + "SIM", # flake8-simplify + "I", # isort + "TCH", # type checking + "ANN", # annotations + "DOC", # docstrings +] + +ignore = [ + "ANN101", # Missing type annotation for self in method + "ANN102", # Missing type annotation for cls in classmethod + "ANN204", # Missing return type annotation in __init__ method + "UP007", # Imho a: Optional[int] = None is more readable than a: (int | None) = None for kwargs + + "D203", # 1 blank line required before class docstring (we use D211) + "D213", # Multi-line docstring summary should start at the second line - we use D212 (starting on the same line) + +] + +[tool.ruff.lint.flake8-annotations] +mypy-init-return = true + +[tool.mypy] +disable_error_code = ["import-untyped", "import-not-found"] diff --git a/setup.cfg b/setup.cfg index cf051c7..01e6698 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ dev = isort autoflake licenseheaders + ruff tests = pytest-invenio responses diff --git a/setup.py b/setup.py index 99f81d3..d43ece9 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,8 @@ # modify it under the terms of the MIT License; see LICENSE file for more # details. # +"""E-INFRA OIDC Auth backend for OARepo.""" + from setuptools import setup setup() diff --git a/tests/test_login.py b/tests/test_login.py index 42985fd..e8b0b12 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -14,8 +14,7 @@ @pytest.mark.skip(reason="This test is intended to be run manually") def test_login(app, db, location, search_clear, client, test_ui_pages): - """ - This test shows how to log in a user using the E-Infra OIDC provider. + """This test shows how to log in a user using the E-Infra OIDC provider. As log-in is a process based on a web browser, the test must be run manually at the moment @@ -40,7 +39,7 @@ def test_login(app, db, location, search_clear, client, test_ui_pages): db.session.add(user) db.session.commit() - identity = UserIdentity.create( + UserIdentity.create( user=user, method="e-infra", external_id="user1@einfra.cesnet.cz", diff --git a/tests/test_low_level_perun_api.py b/tests/test_low_level_perun_api.py index c7632f4..e5512b0 100644 --- a/tests/test_low_level_perun_api.py +++ b/tests/test_low_level_perun_api.py @@ -12,7 +12,6 @@ def test_create_non_existing_group( smart_record, low_level_perun_api, test_repo_communities_id, test_vo_id ): - with smart_record("test_create_group.yaml") as recorded: group, group_created, admin_created = low_level_perun_api.create_group( name="AAA", @@ -25,14 +24,13 @@ def test_create_non_existing_group( else: print(f"Add the >>> assert group['id'] == {group['id']} here <<<") - assert group_created == True - assert admin_created == True + assert group_created is True + assert admin_created is True def test_create_existing_group( smart_record, low_level_perun_api, test_repo_communities_id, test_vo_id ): - with smart_record("test_create_group_existing.yaml"): group, group_created, admin_created = low_level_perun_api.create_group( name="AAA", @@ -54,7 +52,6 @@ def test_create_resource_for_group( test_capabilities_attribute_id, perun_sync_service_id, ): - with smart_record("test_create_resource_for_group.yaml") as recorded: resource, resource_created = ( low_level_perun_api.create_resource_with_group_and_capabilities( @@ -72,7 +69,7 @@ def test_create_resource_for_group( assert resource["id"] == 14408 else: print(f"Add the >>> assert resource['id'] == {resource['id']} here <<<") - assert resource_created == True + assert resource_created is True def test_create_resource_for_group_existing( @@ -85,7 +82,6 @@ def test_create_resource_for_group_existing( test_capabilities_attribute_id, perun_sync_service_id, ): - with smart_record("test_create_resource_for_group_existing.yaml") as recorded: resource, resource_created = ( low_level_perun_api.create_resource_with_group_and_capabilities( @@ -104,14 +100,13 @@ def test_create_resource_for_group_existing( else: print(f"Add the >>> assert resource['id'] == {resource['id']} here <<<") - assert resource_created == False + assert resource_created is False def test_add_user_to_group( app, smart_record, low_level_perun_api, test_repo_communities_id, test_vo_id ): - - with smart_record("test_add_user_to_group.yaml") as recorded: + with smart_record("test_add_user_to_group.yaml"): group, group_created, admin_created = low_level_perun_api.create_group( name="AAA", description="Community AAA", @@ -136,7 +131,7 @@ def test_add_user_to_group( def test_send_invitation( app, smart_record, low_level_perun_api, test_repo_communities_id, test_vo_id ): - with smart_record("test_invite_user_to_group.yaml") as recorded: + with smart_record("test_invite_user_to_group.yaml"): group, group_created, admin_created = low_level_perun_api.create_group( name="AAA", description="Community AAA", diff --git a/tests/test_perun_sync_task.py b/tests/test_perun_sync_task.py index 206d212..4197a85 100644 --- a/tests/test_perun_sync_task.py +++ b/tests/test_perun_sync_task.py @@ -25,5 +25,5 @@ def test_sync_community(app, db, location, smart_record, search_clear): ) current_communities.service.indexer.refresh() - with smart_record("test_initial_sync_community") as recorded: + with smart_record("test_initial_sync_community"): synchronize_community_to_perun(community.id) diff --git a/tests/test_store_dump.py b/tests/test_store_dump.py index 3e79517..e180122 100644 --- a/tests/test_store_dump.py +++ b/tests/test_store_dump.py @@ -15,7 +15,6 @@ def test_store_dump(app, db, client, test_ui_pages): - user = User(email="test@test.com", active=True) db.session.add(user) db.session.commit() diff --git a/tests/test_update_from_perun_dump.py b/tests/test_update_from_perun_dump.py index c261f69..850abb3 100644 --- a/tests/test_update_from_perun_dump.py +++ b/tests/test_update_from_perun_dump.py @@ -29,9 +29,7 @@ def test_no_communities(app, db, location, search_clear): def test_no_communities_user_exists_but_not_linked( app, db, location, search_clear, smart_record ): - with smart_record( - "test_no_communities_user_exists_but_not_linked.yaml" - ) as recorded: + with smart_record("test_no_communities_user_exists_but_not_linked.yaml"): my_original_email = "ms@cesnet.cz" user = User( username="asdasdasd", @@ -53,7 +51,7 @@ def test_no_communities_user_exists_but_not_linked( def test_no_communities_user_linked(app, db, location, search_clear, smart_record): - with smart_record("test_no_communities_user_linked.yaml") as recorded: + with smart_record("test_no_communities_user_linked.yaml"): my_original_email = "ms@cesnet.cz" user = User( username="asdasdasd", @@ -65,7 +63,7 @@ def test_no_communities_user_linked(app, db, location, search_clear, smart_recor db.session.add(user) db.session.commit() - identity = UserIdentity.create( + UserIdentity.create( user=user, method="e-infra", external_id="user1@einfra.cesnet.cz", @@ -83,7 +81,7 @@ def test_no_communities_user_linked(app, db, location, search_clear, smart_recor def test_with_communities(app, db, location, search_clear, smart_record): - with smart_record("test_with_communities.yaml") as recorded: + with smart_record("test_with_communities.yaml"): my_original_email = "ms@cesnet.cz" user = User( username="asdasdasd", @@ -95,7 +93,7 @@ def test_with_communities(app, db, location, search_clear, smart_record): db.session.add(user) db.session.commit() - identity = UserIdentity.create( + UserIdentity.create( user=user, method="e-infra", external_id="user1@einfra.cesnet.cz", @@ -141,7 +139,7 @@ def test_with_communities(app, db, location, search_clear, smart_record): def test_user_not_found_anymore(app, db, location, search_clear, smart_record): - with smart_record("test_suspend_user.yaml") as recorded: + with smart_record("test_suspend_user.yaml"): user = User( username="asdasdasd", email="ms@cesnet.cz", @@ -152,7 +150,7 @@ def test_user_not_found_anymore(app, db, location, search_clear, smart_record): db.session.add(user) db.session.commit() - identity = UserIdentity.create( + UserIdentity.create( user=user, method="e-infra", external_id="user1@einfra.cesnet.cz",