Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mesemus committed Oct 11, 2024
1 parent 16749b4 commit fbafe1d
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 56 deletions.
9 changes: 2 additions & 7 deletions format.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
#!/bin/bash

files="$( (git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u | tr '\n' ' ')"

black --target-version py310 $files
autoflake -r --in-place --remove-all-unused-imports $files
isort --profile black $files

python -m licenseheaders -t .copyright.tmpl -cy -f $files
"$(dirname $0)/python_format.sh" $(( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u )
`dirname $0`/python-packages/bin/python -m licenseheaders -t .copyright.tmpl -cy -f $(( git status --short| grep '^?' | cut -d\ -f2- && git ls-files ) | egrep ".*[.]py" | sort -u )
7 changes: 7 additions & 0 deletions oarepo_oidc_einfra/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from invenio_db import db
from werkzeug.local import LocalProxy

from oarepo_oidc_einfra.mutex import CacheMutex
from oarepo_oidc_einfra.perun.dump import import_dump_file
from oarepo_oidc_einfra.tasks import update_from_perun_dump

Expand Down Expand Up @@ -68,6 +69,12 @@ def add_einfra_user(email, einfra_id):
_add_einfra_user(email, einfra_id)


@einfra.command("clear_import_mutex")
@with_appcontext
def clear_import_mutex():
CacheMutex("EINFRA_SYNC_MUTEX").force_clear()


def _add_einfra_user(email, einfra_id):
_datastore = LocalProxy(lambda: current_app.extensions["security"].datastore)

Expand Down
53 changes: 51 additions & 2 deletions oarepo_oidc_einfra/communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,44 @@ def role_names(self) -> Set[str]:
"""
return {role["name"] for role in current_app.config["COMMUNITIES_ROLES"]}

def role_priority(self, role_name: str) -> int:
"""
Returns a priority of a given role name.
:param role_name: role name
:return: role priority (0 is lowest (member), higher number is higher priority (up to owner))
"""
return self.role_priorities[role_name]

@cached_property
def role_priorities(self) -> Dict[str, int]:
"""
Returns a mapping of role names to their priorities.
:return: a mapping of role names to their priorities, 0 is lowest priority
"""
return {
role["name"]: len(current_app.config["COMMUNITIES_ROLES"]) - role_idx
for role_idx, role in enumerate(current_app.config["COMMUNITIES_ROLES"])
}

@classmethod
def set_user_community_membership(
cls, user: User, new_community_roles: Set[CommunityRole]
cls,
user: User,
new_community_roles: Set[CommunityRole],
current_community_roles: Set[CommunityRole] = None,
) -> None:
"""Set user membership based on the new community roles.
The previous community roles, not present in new_community_roles, are removed.
:param user: User object for which communities will be set
:param new_community_roles: Set of new community roles
:param current_community_roles: Set of current community roles. If not passed, it is fetched from the database.
"""
current_community_roles = cls.get_user_community_membership(user)
if not current_community_roles:
current_community_roles = cls.get_user_community_membership(user)

for community_id, role in new_community_roles - current_community_roles:
cls._add_user_community_membership(community_id, role, user)
Expand Down Expand Up @@ -117,6 +146,26 @@ def get_user_community_membership(cls, user) -> Set[CommunityRole]:

return ret

@classmethod
def get_user_list_community_membership(
cls, user_ids: list[int]
) -> Dict[int, Set[CommunityRole]]:
"""Get community roles of a list of users.
:param user_ids: List of user ids
"""
ret = {}
for row in db.session.execute(
select(
[MemberModel.community_id, MemberModel.user_id, MemberModel.role]
).where(MemberModel.user_id.in_(user_ids), MemberModel.active == True)
):
if row.user_id not in ret:
ret[row.user_id] = set()
ret[row.user_id].add(CommunityRole(row.community_id, row.role))

return ret

@classmethod
def _add_user_community_membership(
cls, community_id: str, community_role: str, user: User
Expand Down
7 changes: 7 additions & 0 deletions oarepo_oidc_einfra/mutex.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def __exit__(self, exc_type, exc_value, traceback):
if current_cache.cache.get(self.key) == self.value:
current_cache.cache.delete(self.key)

def force_clear(self):
"""
Forces the mutex to be cleared.
Note: this does not stop any processes that might be using the mutex !
"""
current_cache.cache.delete(self.key)


mutex_thread_local = threading.local()
"""make the mutex below reentrant within the same thread"""
Expand Down
59 changes: 31 additions & 28 deletions oarepo_oidc_einfra/perun/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# details.
#
import logging
from typing import Optional

import requests
from requests.auth import HTTPBasicAuth
Expand All @@ -28,7 +29,7 @@ class PerunLowLevelAPI:
Note: All ids are internal Perun ids, not UUIDs or other external identifiers.
"""

def __init__(self, base_url, service_id, service_username, service_password):
def __init__(self, base_url: str, service_id: int, service_username: str, service_password: str):
"""
Initialize the API with the base URL and the service credentials.
Expand All @@ -41,7 +42,7 @@ def __init__(self, base_url, service_id, service_username, service_password):
self._service_id = service_id
self._auth = HTTPBasicAuth(service_username, service_password)

def _perun_call(self, manager, method, payload):
def _perun_call(self, manager: str, method: str, payload: dict) -> dict|list:
"""Low-level call to Perun API with error handling.
:param manager: the manager to call
Expand Down Expand Up @@ -69,14 +70,16 @@ def _perun_call(self, manager, method, payload):
return resp.json()

def create_group(
self, *, name, description, parent_group_id, parent_vo, check_existing=True
self, *, name: str, description: str, parent_group_id: int,
parent_vo: int, check_existing: bool=True
):
"""
Create a new group in Perun and set the service as its admin
:param name: Name of the group
:param description: Description of the group
:param parent_group_id: ID of the parent group
:param parent_vo: ID of the VO the parent group belongs to
:param check_existing: If True, check if the group already exists and do not create it
:return: (group: json, group_created: bool, admin_created: bool)
"""
Expand Down Expand Up @@ -141,7 +144,7 @@ def create_group(

return (group, group_created, admin_created)

def get_group_by_name(self, name, parent_group_id):
def get_group_by_name(self, name: str, parent_group_id: int) -> Optional[str]:
"""
Get a group by name within a parent group.
Expand All @@ -160,16 +163,16 @@ def get_group_by_name(self, name, parent_group_id):
def create_resource_with_group_and_capabilities(
self,
*,
vo_id,
facility_id,
group_id,
name,
description,
capability_attr_id,
capabilities,
perun_sync_service_id,
check_existing=True,
):
vo_id: int,
facility_id: int,
group_id : int,
name: str,
description: str,
capability_attr_id: int,
capabilities: list[str],
perun_sync_service_id: int,
check_existing: bool =True,
) -> (dict, bool):
"""
Create a new resource in Perun and assign the group to it.
Expand Down Expand Up @@ -202,8 +205,8 @@ def create_resource_with_group_and_capabilities(
return resource, resource_created

def create_resource(
self, vo_id, facility_id, name, description, check_existing=True
):
self, vo_id: int, facility_id: int, name: str, description: str, check_existing: bool=True
) -> (dict, bool):
"""
Create a new resource in Perun, optionally checking if a resource with the same name already exists.
Expand Down Expand Up @@ -247,7 +250,7 @@ def create_resource(
)
return resource, resource_created

def assign_group_to_resource(self, resource_id, group_id):
def assign_group_to_resource(self, resource_id: int, group_id: int) -> None:
"""
Assign a group to a resource.
Expand All @@ -273,7 +276,7 @@ def assign_group_to_resource(self, resource_id, group_id):
)
log.info("Group %s assigned to resource %s", group_id, resource_id)

def set_resource_capabilities(self, resource_id, capability_attr_id, capabilities):
def set_resource_capabilities(self, resource_id: int, capability_attr_id: int, capabilities: list[str]) -> None:
"""
Set capabilities to a resource.
Expand Down Expand Up @@ -301,7 +304,7 @@ def set_resource_capabilities(self, resource_id, capability_attr_id, capabilitie
)
log.info("Capabilities %s set to resource %s", capabilities, resource_id)

def attach_service_to_resource(self, resource_id, service_id):
def attach_service_to_resource(self, resource_id: int, service_id: int) -> None:
"""
Attach a service to a resource.
Expand Down Expand Up @@ -333,7 +336,7 @@ def attach_service_to_resource(self, resource_id, service_id):
resource_id,
)

def get_resource_by_name(self, vo_id, facility_id, name):
def get_resource_by_name(self, vo_id: int, facility_id: int, name: str) -> Optional[dict]:
"""
Get a resource by name.
Expand All @@ -351,7 +354,7 @@ def get_resource_by_name(self, vo_id, facility_id, name):
except DoesNotExist:
return None

def get_resource_by_capability(self, *, vo_id, facility_id, capability):
def get_resource_by_capability(self, *, vo_id: int, facility_id: int, capability: str) -> Optional[dict]:
"""
Get a resource by capability.
Expand Down Expand Up @@ -379,7 +382,7 @@ def get_resource_by_capability(self, *, vo_id, facility_id, capability):
)
return matching_resources[0]

def get_resource_groups(self, *, resource_id):
def get_resource_groups(self, *, resource_id: int) -> list[dict]:
"""
Get groups assigned to a resource.
Expand All @@ -397,7 +400,7 @@ def get_resource_groups(self, *, resource_id):
)
]

def get_user_by_attribute(self, *, attribute_name, attribute_value):
def get_user_by_attribute(self, *, attribute_name: str, attribute_value: str) -> Optional[dict]:
"""
Get a user by attribute.
Expand All @@ -418,7 +421,7 @@ def get_user_by_attribute(self, *, attribute_name, attribute_value):
return None
return users[0]

def remove_user_from_group(self, *, vo_id, user_id, group_id):
def remove_user_from_group(self, *, vo_id: int, user_id: int, group_id: int) -> None:
"""
Remove a user from a group.
Expand All @@ -434,7 +437,7 @@ def remove_user_from_group(self, *, vo_id, user_id, group_id):
{"group": group_id, "member": member["id"]},
)

def add_user_to_group(self, *, vo_id, user_id, group_id):
def add_user_to_group(self, *, vo_id: int, user_id: int, group_id: int) -> None:
"""
Add a user to a group.
Expand All @@ -450,7 +453,7 @@ def add_user_to_group(self, *, vo_id, user_id, group_id):
{"group": group_id, "member": member["id"]},
)

def _get_or_create_member_in_vo(self, vo_id, user_id):
def _get_or_create_member_in_vo(self, vo_id: int, user_id: int) -> dict:
# TODO: create part here (but we might not need it if everything goes through invitations)
member = self._perun_call(
"membersManager", "getMemberByUser", {"vo": vo_id, "user": user_id}
Expand All @@ -467,7 +470,7 @@ def send_invitation(
language: str,
expiration: str,
redirect_url: str,
):
) -> dict:
"""
Send an invitation to a user to join a group.
Expand All @@ -479,7 +482,7 @@ def send_invitation(
:param expiration: expiration date of the invitation, format YYYY-MM-DD
:param redirect_url: URL to redirect to after accepting the invitation
"""
self._perun_call(
return self._perun_call(
"invitationsManager",
"inviteToGroup",
{
Expand Down
17 changes: 11 additions & 6 deletions oarepo_oidc_einfra/perun/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# modify it under the terms of the MIT License; see LICENSE file for more
# details.
#
import dataclasses
import logging
from collections import defaultdict, namedtuple
from datetime import UTC, datetime
Expand All @@ -18,10 +19,14 @@

log = logging.getLogger("perun.dump_data")


AAIUser = namedtuple(
"AAIUser", ["einfra_id", "email", "full_name", "organization", "roles"]
)
@dataclasses.dataclass(frozen=True)
class AAIUser:
"""A user with their roles as received from the Perun AAI."""
einfra_id: str
email: str
full_name: str
organization: str
roles: Set[CommunityRole]


class PerunDumpData:
Expand Down Expand Up @@ -95,7 +100,7 @@ def resource_to_community_roles(self) -> Dict[str, List[CommunityRole]]:
if role not in self.community_role_names:
log.error(f"Role from PERUN {role} not found in the repository")
continue
community_role = (self.slug_to_id[community_slug], role)
community_role = CommunityRole(self.slug_to_id[community_slug], role)
resources[r_id].append(community_role)

return resources
Expand Down Expand Up @@ -142,7 +147,7 @@ def _get_roles_for_resources(
return aai_communities


def import_dump_file(data: bytes):
def import_dump_file(data: bytes) -> str:
"""
Imports a dump file from the input stream into S3 and returns file name
"""
Expand Down
Loading

0 comments on commit fbafe1d

Please sign in to comment.