Skip to content

Commit

Permalink
Merge branch 'ziotom78:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
ziotom78 authored Mar 15, 2024
2 parents d1306bb + 68fc0ba commit 1a7b2fd
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 19 deletions.
12 changes: 12 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Read the Docs configuration file for Sphinx projects
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

build:
os: ubuntu-22.04
tools:
python: "3.11"

sphinx:
configuration: docs/conf.py
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# HEAD

# Version 2.0.3

- Add RESTful API endpoints for downloading files [#131](https://github.com/ziotom78/instrumentdb/pull/131)

# Version 2.0.2

- Make sure that unauthorized users cannot browse the contents of the database [#128](https://github.com/ziotom78/instrumentdb/pull/128)
Expand Down
25 changes: 14 additions & 11 deletions browse/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def to_representation(self, instance):
instance
)
representation["download_link"] = reverse(
"format-spec-download-view",
"formatspecification-download",
kwargs={"pk": instance.uuid},
request=self.context["request"],
)
Expand Down Expand Up @@ -177,16 +177,19 @@ class Meta:
def to_representation(self, instance):
representation = super(DataFileSerializer, self).to_representation(instance)

representation["download_link"] = reverse(
"data-file-download-view",
kwargs={"pk": instance.uuid},
request=self.context["request"],
)
representation["plot_download_link"] = reverse(
"data-file-plot-view",
kwargs={"pk": instance.uuid},
request=self.context["request"],
)
if instance.file_data:
representation["download_link"] = reverse(
"datafile-download",
kwargs={"pk": instance.uuid},
request=self.context["request"],
)

if instance.plot_file:
representation["plot_download_link"] = reverse(
"datafile-plot",
kwargs={"pk": instance.uuid},
request=self.context["request"],
)

return representation

Expand Down
59 changes: 56 additions & 3 deletions browse/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
from django.contrib.auth import authenticate
from django.contrib.auth.decorators import login_required
from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import HttpResponse, Http404
from django.http import HttpResponse, Http404, FileResponse
from django.views.generic.base import View
from django.views.generic.detail import DetailView
from django.views.generic.list import ListView
from django.contrib.auth.models import User, Group
from django.shortcuts import render, redirect, get_object_or_404

from rest_framework import viewsets, permissions, status
from rest_framework import viewsets, permissions, status, renderers
from rest_framework.authentication import SessionAuthentication
from rest_framework.decorators import api_view, permission_classes
from rest_framework.decorators import api_view, permission_classes, action
from rest_framework.pagination import PageNumberPagination

import instrumentdb
Expand Down Expand Up @@ -255,6 +255,19 @@ def get(self, request, pk):
# REST API


# This is used to create an API endpoint to download files
class PassthroughRenderer(renderers.BaseRenderer):
"""
Return data as-is. View should supply a Response.
"""

media_type = ""
format = ""

def render(self, data, accepted_media_type=None, renderer_context=None):
return data


class UserViewSet(viewsets.ModelViewSet):
queryset = User.objects.all()
serializer_class = UserSerializer
Expand Down Expand Up @@ -293,6 +306,20 @@ def get_permissions(self):
queryset = FormatSpecification.objects.all()
serializer_class = FormatSpecificationSerializer

@action(methods=["get"], detail=True, renderer_classes=(PassthroughRenderer,))
def download(self, request, *args, **kwargs):
instance = self.get_object()
if not instance.doc_file:
raise Http404()

file_handle = instance.doc_file.open()
response = FileResponse(file_handle, content_type=instance.doc_mime_type)
response["Content-Length"] = instance.doc_file.size
response[
"Content-Disposition"
] = f'attachment; filename="{instance.doc_file_name}"'
return response


class EntityViewSet(viewsets.ModelViewSet):
queryset = Entity.objects.all()
Expand Down Expand Up @@ -343,6 +370,32 @@ def get_permissions(self):
return [permissions.IsAdminUser()]
return [permissions.IsAuthenticated()]

@action(methods=["get"], detail=True, renderer_classes=(PassthroughRenderer,))
def download(self, request, *args, **kwargs):
instance = self.get_object()
if not instance.file_data:
raise Http404()

file_handle = instance.file_data.open()
response = FileResponse(
file_handle, content_type=instance.quantity.format_spec.file_mime_type
)
response["Content-Length"] = instance.file_data.size
response["Content-Disposition"] = f'attachment; filename="{instance.name}"'
return response

@action(methods=["get"], detail=True, renderer_classes=(PassthroughRenderer,))
def plot(self, request, *args, **kwargs):
instance = self.get_object()
if not instance.plot_file_data:
raise Http404()

file_handle = instance.file_data.open()
response = FileResponse(file_handle, content_type=instance.plot_mime_type)
response["Content-Length"] = instance.file_data.size
response["Content-Disposition"] = f'attachment; filename="{instance.name}"'
return response

queryset = DataFile.objects.all()
serializer_class = DataFileSerializer

Expand Down
2 changes: 1 addition & 1 deletion instrumentdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.0.2"
__version__ = "2.0.3"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instrumentdb"
version = "2.0.2"
version = "2.0.3"
description = ""
authors = ["Maurizio Tomasi <[email protected]>"]

Expand Down
53 changes: 50 additions & 3 deletions tests/test_webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase
from rest_framework.test import APITestCase, APIRequestFactory
from browse.models import (
FormatSpecification,
Entity,
Expand All @@ -15,6 +15,7 @@
)
from django.contrib.auth.models import User

from browse.views import DataFileViewSet

TEST_ACCOUNT_EMAIL = "test@localhost"
TEST_ACCOUNT_USER = "test_user"
Expand Down Expand Up @@ -50,8 +51,6 @@ def create_format_spec(client, document_ref):
"document_ref": document_ref,
"title": "My dummy document",
"file_mime_type": "application/text",
},
files={
"doc_file": format_spec_file,
},
)
Expand Down Expand Up @@ -157,6 +156,23 @@ def test_create_format_spec(self):
FormatSpecification.objects.get().document_ref, "DUMMY_REF_001"
)

def test_download_format_specification(self):
"""
Ensure we can create a new DataFile object.
"""
response = create_format_spec(self.client, "DUMMY_REF_001")

# Now ask for a JSON representation of the object
json_dict = self.client.get(response.data["url"]).json()

response = self.client.get(json_dict["download_link"])
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/plain")

expected_content = b"Test file"
actual_content = b"".join(chunk for chunk in response.streaming_content)
self.assertEqual(actual_content, expected_content)


class EntityTests(APITestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -344,6 +360,28 @@ def test_create_datafile_with_wrong_metadata(self):
# Check the result of the POST call
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_download_datafile(self):
"""
Ensure we can create a new DataFile object.
"""
response = create_data_file_spec(
self.client,
name="test_datafile",
metadata={"a": 10, "b": "hello"},
quantity=self.quantity_response.data["url"],
)

# Now ask for a JSON representation of the object
json = self.client.get(response.data["url"]).json()

response = self.client.get(json["download_link"])
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/text")

expected_content = b"1,2,3,4,5"
actual_content = b"".join(chunk for chunk in response.streaming_content)
self.assertEqual(actual_content, expected_content)


class ReleaseTests(APITestCase):
def setUp(self):
Expand Down Expand Up @@ -465,3 +503,12 @@ def testDenyCreationOfDataFile(self):
quantity=self.quantity_response.data["url"],
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)


def test_unauthenticated_access(self):
view = DataFileViewSet.as_view({"get": "list"})
factory = APIRequestFactory()
request = factory.get("/data-files/")

response = view(request)
self.assertEqual(response.status_code, 401) # Assert unauthorized access

0 comments on commit 1a7b2fd

Please sign in to comment.