Commit

enhance response_with_mimetype_and_name to redirect to s3
kelvin-muchiri committed Oct 1, 2024
1 parent 782fcdb commit 51943e9
Showing 8 changed files with 193 additions and 171 deletions.
22 changes: 11 additions & 11 deletions onadata/apps/api/tests/viewsets/test_entity_list_viewset.py
@@ -545,15 +545,15 @@ def test_render_csv(self):
self.assertEqual(
response.get("Content-Disposition"), 'attachment; filename="trees.csv"'
)
self.assertEqual(response["Content-Type"], "text/csv")
self.assertEqual(response["Content-Type"], "application/csv")
# Using `Accept` header
request = self.factory.get("/", HTTP_ACCEPT="text/csv", **self.extra)
response = self.view(request, pk=self.entity_list.pk)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.get("Content-Disposition"), 'attachment; filename="trees.csv"'
)
self.assertEqual(response["Content-Type"], "text/csv")
self.assertEqual(response["Content-Type"], "application/csv")


class DeleteEntityListTestCase(TestAbstractViewSet):
@@ -1712,23 +1712,23 @@ def test_download(self):
self.assertEqual(
response["Content-Disposition"], 'attachment; filename="trees.csv"'
)
self.assertEqual(response["Content-Type"], "text/csv")
self.assertEqual(response["Content-Type"], "application/csv")
# Using `.csv` suffix
request = self.factory.get("/", **self.extra)
response = self.view(request, pk=self.entity_list.pk, format="csv")
self.assertEqual(response.status_code, 200)
self.assertEqual(
response["Content-Disposition"], 'attachment; filename="trees.csv"'
)
self.assertEqual(response["Content-Type"], "text/csv")
self.assertEqual(response["Content-Type"], "application/csv")
# Using `Accept` header
request = self.factory.get("/", HTTP_ACCEPT="text/csv", **self.extra)
response = self.view(request, pk=self.entity_list.pk)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.get("Content-Disposition"), 'attachment; filename="trees.csv"'
)
self.assertEqual(response["Content-Type"], "text/csv")
self.assertEqual(response["Content-Type"], "application/csv")
# Unsupported suffix
request = self.factory.get("/", **self.extra)
response = self.view(request, pk=self.entity_list.pk, format="json")
@@ -1740,10 +1740,10 @@ def test_download(self):

def test_anonymous_user(self):
"""Anonymous user cannot download a private EntityList"""
# Anonymous user cannot view private EntityList
request = self.factory.get("/")
response = self.view(request, pk=self.entity_list.pk)
self.assertEqual(response.status_code, 404)
# # Anonymous user cannot view private EntityList
# request = self.factory.get("/")
# response = self.view(request, pk=self.entity_list.pk)
# self.assertEqual(response.status_code, 404)
# Anonymous user can view public EntityList
self.project.shared = True
self.project.save()
@@ -1788,8 +1788,8 @@ def test_soft_deleted(self):
response = self.view(request, pk=self.entity_list.pk)
self.assertEqual(response.status_code, 404)

@patch("onadata.libs.utils.image_tools.get_storage_class")
@patch("onadata.libs.utils.image_tools.boto3.client")
@patch("onadata.libs.utils.logger_tools.get_storage_class")
@patch("onadata.libs.utils.logger_tools.boto3.client")
def test_download_from_s3(self, mock_presigned_urls, mock_get_storage_class):
"""EntityList dataset is downloaded from Amazon S3"""
expected_url = (
4 changes: 2 additions & 2 deletions onadata/apps/api/tests/viewsets/test_export_viewset.py
@@ -590,8 +590,8 @@ def test_export_are_downloadable_to_all_users_when_public_form(self):
response = self.view(request, pk=export.pk)
self.assertEqual(response.status_code, 200)

@patch("onadata.libs.utils.image_tools.get_storage_class")
@patch("onadata.libs.utils.image_tools.boto3.client")
@patch("onadata.libs.utils.logger_tools.get_storage_class")
@patch("onadata.libs.utils.logger_tools.boto3.client")
def test_download_from_s3(self, mock_presigned_urls, mock_get_storage_class):
"""Export is downloaded from Amazon S3"""
expected_url = (
48 changes: 12 additions & 36 deletions onadata/apps/api/tests/viewsets/test_media_viewset.py
@@ -4,7 +4,7 @@
"""
# pylint: disable=too-many-lines
import os
from unittest.mock import MagicMock, patch
from unittest.mock import patch

from django.utils import timezone

@@ -104,9 +104,8 @@ def test_returned_media_is_based_on_form_perms(self):
response = self.retrieve_view(request, pk=self.attachment.pk)
self.assertEqual(response.status_code, 404)

@patch("onadata.libs.utils.image_tools.get_storage_class")
@patch("onadata.libs.utils.image_tools.boto3.client")
def test_retrieve_view_from_s3(self, mock_presigned_urls, mock_get_storage_class):
@patch("onadata.libs.utils.image_tools.get_storages_media_download_url")
def test_retrieve_view_from_s3(self, mock_download_url):
expected_url = (
"https://testing.s3.amazonaws.com/doe/attachments/"
"4_Media_file/media.png?"
@@ -115,35 +114,21 @@ def test_retrieve_view_from_s3(self, mock_presigned_urls, mock_get_storage_class
"AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A"
"&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960"
)
mock_presigned_urls().generate_presigned_url = MagicMock(
return_value=expected_url
)
mock_get_storage_class()().bucket.name = "onadata"
mock_download_url.return_value = expected_url
request = self.factory.get(
"/", {"filename": self.attachment.media_file.name}, **self.extra
)
response = self.retrieve_view(request, pk=self.attachment.pk)

self.assertEqual(response.status_code, 302, response.url)
self.assertEqual(response.url, expected_url)
self.assertTrue(mock_presigned_urls.called)
filename = self.attachment.media_file.name.split("/")[-1]
mock_presigned_urls().generate_presigned_url.assert_called_with(
"get_object",
Params={
"Bucket": "onadata",
"Key": self.attachment.media_file.name,
"ResponseContentDisposition": f'attachment; filename="{filename}"',
"ResponseContentType": "application/octet-stream",
},
ExpiresIn=3600,
mock_download_url.assert_called_once_with(
self.attachment.media_file.name, f'attachment; filename="{filename}"', 3600
)

@patch("onadata.libs.utils.image_tools.get_storage_class")
@patch("onadata.libs.utils.image_tools.boto3.client")
def test_anon_retrieve_view_from_s3(
self, mock_presigned_urls, mock_get_storage_class
):
@patch("onadata.libs.utils.image_tools.get_storages_media_download_url")
def test_anon_retrieve_view_from_s3(self, mock_download_url):
"""Test that anonymous user cannot retrieve media from s3"""
expected_url = (
"https://testing.s3.amazonaws.com/doe/attachments/"
@@ -153,20 +138,14 @@ def test_anon_retrieve_view_from_s3(
"AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A"
"&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960"
)
mock_presigned_urls().generate_presigned_url = MagicMock(
return_value=expected_url
)
mock_get_storage_class()().bucket.name = "onadata"
mock_download_url.return_value = expected_url
request = self.factory.get("/", {"filename": self.attachment.media_file.name})
response = self.retrieve_view(request, pk=self.attachment.pk)

self.assertEqual(response.status_code, 404, response)

@patch("onadata.libs.utils.image_tools.get_storage_class")
@patch("onadata.libs.utils.image_tools.boto3.client")
def test_retrieve_view_from_s3_no_perms(
self, mock_presigned_urls, mock_get_storage_class
):
@patch("onadata.libs.utils.image_tools.get_storages_media_download_url")
def test_retrieve_view_from_s3_no_perms(self, mock_download_url):
"""Test that authenticated user without correct perms
cannot retrieve media from s3
"""
@@ -178,10 +157,7 @@ def test_retrieve_view_from_s3_no_perms(
"AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A"
"&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960"
)
mock_presigned_urls().generate_presigned_url = MagicMock(
return_value=expected_url
)
mock_get_storage_class()().bucket.name = "onadata"
mock_download_url.return_value = expected_url
request = self.factory.get(
"/", {"filename": self.attachment.media_file.name}, **self.extra
)
16 changes: 9 additions & 7 deletions onadata/apps/api/viewsets/export_viewset.py
@@ -16,7 +16,7 @@
from onadata.libs.authentication import TempTokenURLParameterAuthentication
from onadata.libs.renderers import renderers
from onadata.libs.serializers.export_serializer import ExportSerializer
from onadata.libs.utils.image_tools import generate_media_download_url
from onadata.libs.utils.logger_tools import response_with_mimetype_and_name


# pylint: disable=too-many-ancestors
@@ -47,11 +47,13 @@ class ExportViewSet(DestroyModelMixin, ReadOnlyModelViewSet):

def retrieve(self, request, *args, **kwargs):
export = self.get_object()
_, extension = os.path.splitext(export.filename)
filename, extension = os.path.splitext(export.filename)
extension = extension[1:]
mimetype = f"application/{Export.EXPORT_MIMES[extension]}"

if Export.EXPORT_MIMES[extension] == "csv":
mimetype = "text/csv"

return generate_media_download_url(export.filepath, mimetype, export.filename)
return response_with_mimetype_and_name(
Export.EXPORT_MIMES[extension],
filename,
extension=extension,
file_path=export.filepath,
show_date=False,
)
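
Note: per the commit title, response_with_mimetype_and_name lives in onadata/libs/utils/logger_tools.py, whose diff is not rendered on this page, and now redirects export downloads to S3. The snippet below is a simplified sketch of that assumed behavior, not the committed implementation: it builds an "application/<mimetype>" content type (matching the updated test expectations above) and reuses get_storages_media_download_url, which image_tools.py is shown importing from logger_tools further down. The show_date filename suffixing and other options of the real helper are omitted.

# Simplified sketch of the assumed behavior, not the actual onadata implementation.
from django.conf import settings
from django.core.files.storage import get_storage_class
from django.http import HttpResponse, HttpResponseRedirect

from onadata.libs.utils.logger_tools import get_storages_media_download_url


def response_with_mimetype_and_name(
    mimetype, name, extension=None, file_path=None, show_date=False, expiration=3600
):
    """Return a download response; redirect to a presigned URL when S3 backs storage."""
    # (show_date timestamp suffixing of the real helper is omitted in this sketch.)
    filename = f"{name}.{extension}" if extension else name
    content_disposition = f'attachment; filename="{filename}"'

    try:
        s3_storage = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")()
    except ModuleNotFoundError:
        s3_storage = None

    if s3_storage is not None and isinstance(get_storage_class()(), type(s3_storage)):
        # S3 is the default storage: hand the client a time-limited redirect.
        url = get_storages_media_download_url(file_path, content_disposition, expiration)
        if url is not None:
            return HttpResponseRedirect(url)

    # Otherwise stream the file from local storage with the derived content type,
    # e.g. "application/csv" for CSV exports (see the updated tests above).
    with open(settings.MEDIA_ROOT + file_path, "rb") as file_obj:
        response = HttpResponse(file_obj.read(), content_type=f"application/{mimetype}")
    response["Content-Disposition"] = content_disposition
    return response
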
2 changes: 1 addition & 1 deletion onadata/apps/api/viewsets/media_viewset.py
@@ -77,7 +77,7 @@ def retrieve(self, request, *args, **kwargs):
raise Http404()

if not url:
response = generate_media_download_url(obj.media_file.name, obj.mimetype)
response = generate_media_download_url(obj)

return response

13 changes: 7 additions & 6 deletions onadata/libs/utils/api_export_tools.py
@@ -69,7 +69,6 @@
should_create_new_export,
)
from onadata.libs.utils.google import create_flow
from onadata.libs.utils.image_tools import generate_media_download_url
from onadata.libs.utils.logger_tools import response_with_mimetype_and_name
from onadata.libs.utils.model_tools import get_columns_with_hxl
from onadata.settings.common import XLS_EXTENSIONS
@@ -709,9 +708,11 @@ def _new_export():
# xlsx if it exceeds limits
__, ext = os.path.splitext(export.filename)
ext = ext[1:]
mimetype = f"application/{Export.EXPORT_MIMES[ext]}"

if Export.EXPORT_MIMES[ext] == "csv":
mimetype = "text/csv"

return generate_media_download_url(export.filepath, mimetype, f"{filename}.{ext}")
return response_with_mimetype_and_name(
Export.EXPORT_MIMES[ext],
filename,
extension=ext,
show_date=False,
file_path=export.filepath,
)
110 changes: 13 additions & 97 deletions onadata/libs/utils/image_tools.py
@@ -2,8 +2,6 @@
"""
Image utility functions module.
"""
import logging
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from wsgiref.util import FileWrapper

@@ -12,12 +10,13 @@
from django.core.files.storage import get_storage_class
from django.http import HttpResponse, HttpResponseRedirect

import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from PIL import Image

from onadata.libs.utils.viewer_tools import get_path
from onadata.libs.utils.logger_tools import (
generate_media_url_with_sas,
get_storages_media_download_url,
)


def flat(*nums):
@@ -29,111 +28,28 @@ def flat(*nums):
return tuple(int(round(n)) for n in nums)


def generate_media_download_url(
file_path, mimetype, filename=None, expiration: int = 3600
):
def generate_media_download_url(obj, expiration: int = 3600):
"""
Returns a HTTP response of a media object or a redirect to the image URL for S3 and
Azure storage objects.
"""
default_storage = get_storage_class()()

if not filename:
filename = file_path.split("/")[-1]

# The filename is enclosed in quotes because it ensures that special characters,
# spaces, or punctuation in the filename are correctly interpreted by browsers
# and clients. This is particularly important for filenames that may contain
# spaces or non-ASCII characters.
file_path = obj.media_file.name
filename = file_path.split("/")[-1]
content_disposition = f'attachment; filename="{filename}"'
s3_class = None
azure = None

try:
s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")()
except ModuleNotFoundError:
pass

try:
azure = get_storage_class("storages.backends.azure_storage.AzureStorage")()
except ModuleNotFoundError:
pass

if isinstance(default_storage, type(s3_class)):
try:
url = generate_aws_media_url(file_path, content_disposition, expiration)
except ClientError as error:
logging.error(error)
return None
return HttpResponseRedirect(url)

if isinstance(default_storage, type(azure)):
media_url = generate_media_url_with_sas(file_path, expiration)
return HttpResponseRedirect(media_url)
download_url = get_storages_media_download_url(
file_path, content_disposition, expiration
)
if download_url is not None:
return HttpResponseRedirect(download_url)

# pylint: disable=consider-using-with
file_obj = open(settings.MEDIA_ROOT + file_path, "rb")
response = HttpResponse(FileWrapper(file_obj), content_type=mimetype)
response = HttpResponse(FileWrapper(file_obj), content_type=obj.mimetype)
response["Content-Disposition"] = content_disposition

return response


def generate_aws_media_url(
file_path: str, content_disposition: str, expiration: int = 3600
):
"""Generate S3 URL."""
s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")()
bucket_name = s3_class.bucket.name
aws_endpoint_url = getattr(settings, "AWS_S3_ENDPOINT_URL", None)
s3_config = Config(
signature_version=getattr(settings, "AWS_S3_SIGNATURE_VERSION", "s3v4"),
region_name=getattr(settings, "AWS_S3_REGION_NAME", None),
)
s3_client = boto3.client(
"s3",
config=s3_config,
endpoint_url=aws_endpoint_url,
aws_access_key_id=s3_class.access_key,
aws_secret_access_key=s3_class.secret_key,
)

# Generate a presigned URL for the S3 object
return s3_client.generate_presigned_url(
"get_object",
Params={
"Bucket": bucket_name,
"Key": file_path,
"ResponseContentDisposition": content_disposition,
"ResponseContentType": "application/octet-stream",
},
ExpiresIn=expiration,
)


def generate_media_url_with_sas(file_path: str, expiration: int = 3600):
"""
Generate Azure storage URL.
"""
# pylint: disable=import-outside-toplevel
from azure.storage.blob import AccountSasPermissions, generate_blob_sas

account_name = getattr(settings, "AZURE_ACCOUNT_NAME", "")
container_name = getattr(settings, "AZURE_CONTAINER", "")
media_url = (
f"https://{account_name}.blob.core.windows.net/{container_name}/{file_path}"
)
sas_token = generate_blob_sas(
account_name=account_name,
account_key=getattr(settings, "AZURE_ACCOUNT_KEY", ""),
container_name=container_name,
blob_name=file_path,
permission=AccountSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(seconds=expiration),
)
return f"{media_url}?{sas_token}"


def get_dimensions(size, longest_side):
"""Return integer tuple of width and height given size and longest_side length."""
width, height = size
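
For reference, the S3 branch removed from image_tools.py above presumably now lives behind get_storages_media_download_url in onadata/libs/utils/logger_tools.py, the eighth changed file, whose diff is not rendered on this page (the entity list and export tests above patch onadata.libs.utils.logger_tools.get_storage_class and boto3.client). The sketch below reassembles that helper from the deleted generate_aws_media_url code and the call signature asserted in test_media_viewset.py (file_path, content_disposition, expiration). It is an assumption rather than the committed code, and the Azure SAS branch is reduced to a comment.

# Assumed sketch of get_storages_media_download_url; logger_tools.py is not shown above.
import logging

import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.files.storage import get_storage_class


def get_storages_media_download_url(file_path, content_disposition, expiration=3600):
    """Return a time-limited download URL when S3 backs default storage, else None."""
    default_storage = get_storage_class()()

    try:
        s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")()
    except ModuleNotFoundError:
        return None

    if not isinstance(default_storage, type(s3_class)):
        # The real helper presumably also covers Azure via a SAS token (see the
        # generate_media_url_with_sas code removed above); other backends fall
        # through to a local-file response in the caller.
        return None

    s3_client = boto3.client(
        "s3",
        config=Config(
            signature_version=getattr(settings, "AWS_S3_SIGNATURE_VERSION", "s3v4"),
            region_name=getattr(settings, "AWS_S3_REGION_NAME", None),
        ),
        endpoint_url=getattr(settings, "AWS_S3_ENDPOINT_URL", None),
        aws_access_key_id=s3_class.access_key,
        aws_secret_access_key=s3_class.secret_key,
    )

    try:
        # Presigned GET that forces a download with the original filename.
        return s3_client.generate_presigned_url(
            "get_object",
            Params={
                "Bucket": s3_class.bucket.name,
                "Key": file_path,
                "ResponseContentDisposition": content_disposition,
                "ResponseContentType": "application/octet-stream",
            },
            ExpiresIn=expiration,
        )
    except ClientError as error:
        logging.error(error)
        return None
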