diff --git a/contentcuration/contentcuration/tests/viewsets/test_invitation.py b/contentcuration/contentcuration/tests/viewsets/test_invitation.py index f044f50a99..554a4f4942 100644 --- a/contentcuration/contentcuration/tests/viewsets/test_invitation.py +++ b/contentcuration/contentcuration/tests/viewsets/test_invitation.py @@ -446,3 +446,27 @@ def test_update_invitation_decline(self): ).exists() ) self.assertTrue(models.Change.objects.filter(channel=self.channel).exists()) + + def test_accept_invitation_by_non_invitee_is_forbidden(self): + invitation = models.Invitation.objects.create(**self.invitation_db_metadata) + + # self.user is a channel editor, not the invited user + self.client.force_authenticate(user=self.user) + response = self.client.post( + reverse("invitation-accept", kwargs={"pk": invitation.id}) + ) + self.assertEqual(response.status_code, 403, response.content) + invitation.refresh_from_db() + self.assertFalse(invitation.accepted) + + def test_decline_invitation_by_non_invitee_is_forbidden(self): + invitation = models.Invitation.objects.create(**self.invitation_db_metadata) + + # self.user is a channel editor, not the invited user + self.client.force_authenticate(user=self.user) + response = self.client.post( + reverse("invitation-decline", kwargs={"pk": invitation.id}) + ) + self.assertEqual(response.status_code, 403, response.content) + invitation.refresh_from_db() + self.assertFalse(invitation.declined) diff --git a/contentcuration/contentcuration/viewsets/invitation.py b/contentcuration/contentcuration/viewsets/invitation.py index 7d8ff577f6..a9aa1f499c 100644 --- a/contentcuration/contentcuration/viewsets/invitation.py +++ b/contentcuration/contentcuration/viewsets/invitation.py @@ -2,6 +2,7 @@ from django_filters.rest_framework import FilterSet from rest_framework import serializers from rest_framework.decorators import action +from rest_framework.exceptions import PermissionDenied from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -137,9 +138,14 @@ def perform_update(self, serializer): instance = serializer.save() instance.save() + def _ensure_invitee(self, request, invitation): + if (request.user.email or "").lower() != (invitation.email or "").lower(): + raise PermissionDenied("Only the invited user may perform this action.") + @action(detail=True, methods=["post"]) def accept(self, request, pk=None): invitation = self.get_object() + self._ensure_invitee(request, invitation) invitation.accept() invitation.accepted = True invitation.save() @@ -158,6 +164,7 @@ def accept(self, request, pk=None): @action(detail=True, methods=["post"]) def decline(self, request, pk=None): invitation = self.get_object() + self._ensure_invitee(request, invitation) invitation.declined = True invitation.save() Change.create_change(