diff --git a/dataworkspace/dataworkspace/apps/accounts/utils.py b/dataworkspace/dataworkspace/apps/accounts/utils.py index 7f51ae0ca0..1f17a1f055 100644 --- a/dataworkspace/dataworkspace/apps/accounts/utils.py +++ b/dataworkspace/dataworkspace/apps/accounts/utils.py @@ -6,6 +6,7 @@ BACKEND_SESSION_KEY, HASH_SESSION_KEY, authenticate, + get_user_model, ) from django.conf import settings from django.contrib.sessions.backends.base import CreateError @@ -92,3 +93,13 @@ def _process_user_access_profile(user, access_profile_name, func): except Exception as e: logger.exception(e) raise SSOApiException from None + + +def get_user_by_sso_id(sso_id): + user_model = get_user_model() + # Attempt to find a user with the given SSO ID as username + try: + return user_model.objects.get(username=sso_id) + except user_model.DoesNotExist: + # If username doesn't exist fall back to profile sso id. + return user_model.objects.get(profile__sso_id=sso_id) diff --git a/dataworkspace/dataworkspace/apps/api_v1/core/views.py b/dataworkspace/dataworkspace/apps/api_v1/core/views.py index ac51a8fdaf..9df2e83de9 100644 --- a/dataworkspace/dataworkspace/apps/api_v1/core/views.py +++ b/dataworkspace/dataworkspace/apps/api_v1/core/views.py @@ -2,13 +2,13 @@ from urllib.parse import urlparse from django.conf import settings -from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.core.cache import cache from django.http import HttpResponse, JsonResponse from rest_framework import viewsets from rest_framework.pagination import PageNumberPagination +from dataworkspace.apps.accounts.utils import get_user_by_sso_id from dataworkspace.apps.api_v1.core.serializers import ( UserSatisfactionSurveySerializer, NewsletterSubscriptionSerializer, @@ -62,9 +62,7 @@ def get_superset_credentials(request): ) response = cache.get(cache_key, None) if not response: - dw_user = get_user_model().objects.get( - profile__sso_id=request.headers["sso-profile-user-id"] - ) + dw_user = get_user_by_sso_id(request.headers["sso-profile-user-id"]) if not dw_user.user_permissions.filter( codename="start_all_applications", content_type=ContentType.objects.get_for_model(ApplicationInstance), @@ -131,7 +129,7 @@ def invalidate_superset_user_cached_credentials(): def generate_mlflow_jwt(request): - user = get_user_model().objects.get(profile__sso_id=request.headers["sso-profile-user-id"]) + user = get_user_by_sso_id(request.headers["sso-profile-user-id"]) authorised_hosts = list( user.authorised_mlflow_instances.all().values_list("instance__hostname", flat=True) ) diff --git a/dataworkspace/dataworkspace/apps/applications/utils.py b/dataworkspace/dataworkspace/apps/applications/utils.py index 1dc6d0b26a..5421674b48 100644 --- a/dataworkspace/dataworkspace/apps/applications/utils.py +++ b/dataworkspace/dataworkspace/apps/applications/utils.py @@ -27,6 +27,7 @@ import redis from dataworkspace.apps.accounts.models import Profile +from dataworkspace.apps.accounts.utils import get_user_by_sso_id from dataworkspace.apps.applications.spawner import ( get_spawner, stop, @@ -864,7 +865,7 @@ def sync_quicksight_users(data_client, user_client, account_id, quicksight_user_ raise e - dw_user = get_user_model().objects.get(profile__sso_id=sso_id) + dw_user = get_user_by_sso_id(sso_id) if not dw_user: logger.error( "Skipping %s - cannot match with Data Workspace user.", @@ -1004,8 +1005,7 @@ def create_user_from_sso( ): user_model = get_user_model() try: - # Attempt to find a user with the given SSO ID - user = user_model.objects.get(Q(username=sso_id) | Q(profile__sso_id=sso_id)) + user = get_user_by_sso_id(sso_id) except user_model.DoesNotExist: # If the user doesn't exist we will have to create it user = user_model.objects.create( @@ -1018,7 +1018,7 @@ def create_user_from_sso( user.save() except IntegrityError: # A concurrent request may have overtaken this one and created a user - user = user_model.objects.get(Q(username=sso_id) | Q(profile__sso_id=sso_id)) + user = get_user_by_sso_id(sso_id) _check_tools_access(user) else: @@ -1111,9 +1111,7 @@ def _do_create_tools_access_iam_role(user_id): @close_all_connections_if_not_in_atomic_block def sync_activity_stream_sso_users(): try: - with cache.lock( - "activity_stream_sync_last_published_lock", blocking_timeout=0, timeout=1800 - ): + with cache.lock("sso_sync_last_published_lock", blocking_timeout=0, timeout=1800): _do_sync_activity_stream_sso_users() except redis.exceptions.LockError: logger.info("Unable to acquire lock to sync activity stream sso users") @@ -1736,7 +1734,7 @@ def duplicate_tools_monitor(): @celery_app.task() @close_all_connections_if_not_in_atomic_block def sync_all_sso_users(): - with cache.lock("activity_stream_sync_last_published_lock", blocking_timeout=0, timeout=3600): + with cache.lock("sso_sync_last_published_lock", blocking_timeout=0, timeout=3600): user_model = get_user_model() all_users = user_model.objects.all() seen_user_ids = []