Skip to content

Commit

Permalink
Refactor get_application to return None instead of 400
Browse files Browse the repository at this point in the history
The get_application method has been refactored to return None when the application doesn't exist or the client_id is missing. The check to determine if the application is not found has been moved to ConvertTokenSerializer and RevokeTokenSerializer methods. Corresponding test cases have also been updated to reflect the changes.
  • Loading branch information
wagnerdelima committed Jul 11, 2024
1 parent 12ef285 commit e0b8352
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
28 changes: 16 additions & 12 deletions drf_social_oauth2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
logger = logging.getLogger(__package__)


def get_application_or_400(validated_data: dict) -> Application:
def get_application(validated_data: dict) -> Application:
"""
:param validated_data: A dictionary containing the request validated data.
:return: An Application object.
Expand All @@ -54,20 +54,13 @@ def get_application_or_400(validated_data: dict) -> Application:
"""
client_id = validated_data.get('client_id')

# Check if a client_id was provided
if not client_id:
return Response(
data={'invalid_client': 'Missing client_id.'},
status=HTTP_400_BAD_REQUEST,
)
return None

try:
application = Application.objects.get(client_id=client_id)
except Application.DoesNotExist:
return Response(
data={'invalid_client': 'Invalid client_id.'},
status=HTTP_400_BAD_REQUEST,
)
return None
return application


Expand Down Expand Up @@ -143,7 +136,12 @@ def post(self, request: Request, *args, **kwargs):
serializer = ConvertTokenSerializer(data=request.data)
serializer.is_valid(raise_exception=True)

application = get_application_or_400(serializer.validated_data)
application = get_application(serializer.validated_data)
if not application:
return Response(
{"detail": "The application for this client_id does not exist."},
status=HTTP_400_BAD_REQUEST,
)
# Use the rest framework `.data` to fake the post body of the django request.
request._request.POST = request._request.POST.copy()
request._request.POST['client_secret'] = application.client_secret
Expand Down Expand Up @@ -218,7 +216,13 @@ def post(self, request: Request, *args, **kwargs):
auth_header = auth_header.replace('Bearer ', '', 1)
serializer = RevokeTokenSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
application = get_application_or_400(serializer.validated_data)

application = get_application(serializer.validated_data)
if not application:
return Response(
{"detail": "The application for this client_id does not exist."},
status=HTTP_400_BAD_REQUEST,
)

# Use the rest framework `.data` to fake the post body of the django request.
request._request.POST = request._request.POST.copy()
Expand Down
14 changes: 6 additions & 8 deletions tests/drf_social_oauth2/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from oauth2_provider.models import RefreshToken, AccessToken, Application
from model_bakery.recipe import Recipe

from drf_social_oauth2.views import get_application_or_400
from drf_social_oauth2.views import get_application
from tests.drf_social_oauth2.drf_fixtures import application, user, save


Expand Down Expand Up @@ -46,20 +46,18 @@ def test_get_application(application):
"""
# Test get_application with the correct client_id
valid_data = {'client_id': 'id'}
result = get_application_or_400(valid_data)
result = get_application(valid_data)
assert result == application

# Test get_application with an incorrect client_id
invalid_data = {'client_id': 'wrong_client'}
result = get_application_or_400(invalid_data)
assert result.status_code == HTTP_400_BAD_REQUEST
assert result.data == {'invalid_client': 'Invalid client_id.'}
result = get_application(invalid_data)
assert not result

# Test get_application with no client_id
empty_data = {}
result = get_application_or_400(empty_data)
assert result.status_code == HTTP_400_BAD_REQUEST
assert result.data == {'invalid_client': 'Missing client_id.'}
result = get_application(empty_data)
assert not result


def test_convert_token_endpoint_with_no_post_params(client_api):
Expand Down

0 comments on commit e0b8352

Please sign in to comment.