From 402a9a97f231fc52ea82bc78e5f51d3469f766b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bodor=20M=C3=A1t=C3=A9?= <bodor.mate@kszk.bme.hu>
Date: Mon, 14 Jan 2019 19:15:41 +0100
Subject: [PATCH] Add role base permissions

---
 src/account/views.py                          |  5 ++-
 src/common/permissions.py                     | 37 ++++++++++++-------
 src/document/views.py                         |  2 +-
 src/homework/views.py                         |  6 +--
 .../migrations/0006_auto_20190114_1913.py     | 18 +++++++++
 src/stats/models.py                           |  1 -
 src/stats/serializers.py                      |  4 --
 7 files changed, 48 insertions(+), 25 deletions(-)
 create mode 100644 src/stats/migrations/0006_auto_20190114_1913.py

diff --git a/src/account/views.py b/src/account/views.py
index 79ab7de..c2ad4ec 100644
--- a/src/account/views.py
+++ b/src/account/views.py
@@ -2,6 +2,7 @@ from rest_framework import viewsets
 from rest_framework import permissions
 from rest_framework.response import Response
 from rest_framework.decorators import list_route
+from common.permissions import IsSafeOrPatch
 
 from . import models
 from . import serializers
@@ -9,11 +10,11 @@ from . import serializers
 
 class ProfileViewSet(viewsets.ModelViewSet):
     serializer_class = serializers.ProfileSerializer
-    permission_classes = (permissions.IsAuthenticated, )
+    permission_classes = (permissions.IsAuthenticated, IsSafeOrPatch)
 
     def get_queryset(self):
         user = self.request.user
-        if user.has_perm(permissions.IsAdminUser):
+        if user.profile.role == 'Staff':
             role = self.request.query_params.get("role", None)
             if role is not None:
                 return models.Profile.objects.filter(role=role)
diff --git a/src/common/permissions.py b/src/common/permissions.py
index 5aac2f0..beaae6f 100644
--- a/src/common/permissions.py
+++ b/src/common/permissions.py
@@ -3,27 +3,36 @@ from rest_framework.permissions import SAFE_METHODS
 
 
 class IsStaffOrReadOnly(BasePermission):
-    """
-    The request is authenticated as a staff, or is a read-only request.
-    """
-
     def has_permission(self, request, view):
-        return request.method in SAFE_METHODS or request.user and request.user.is_staff
+        return request.method in SAFE_METHODS or\
+               (request.user.is_authenticated and request.user.profile.role == 'Staff')
 
 
 class IsStaffOrReadOnlyForAuthenticated(BasePermission):
-    """
-    The request is authenticated as a staff, or is a read-only request for authenticated.
-    """
-
     def has_permission(self, request, view):
-        return request.user.is_staff or request.method in SAFE_METHODS and request.user.is_authenticated
+        return request.user.is_authenticated and\
+               (request.method in SAFE_METHODS or request.user.profile.role == 'Staff')
 
 
 class IsStaffUser(BasePermission):
-    """
-    The request is authenticated as a staff
-    """
+    def has_permission(self, request, view):
+        return request.user.is_authenticated and request.user.profile.role == 'Staff'
+
+
+class IsSafeOrPatch(BasePermission):
+    def has_permission(self, request, view):
+        return request.method in SAFE_METHODS or request.method == 'PATCH'
+
+
+class IsStaffOrStudent(BasePermission):
+    def has_permission(self, request, view):
+        return request.user.is_authenticated and\
+               (request.user.profile.role == 'Staff' or request.user.profile.role == 'Student')
+
 
+class StudentJustCreate(BasePermission):
     def has_permission(self, request, view):
-        return request.user.is_staff
+        if request.user.is_authenticated and request.user.profile.role == 'Staff':
+            return True
+        return request.user.is_authenticated and request.user.profile.role == 'Student' and\
+               (request.method in SAFE_METHODS or request.method == 'CREATE')
diff --git a/src/document/views.py b/src/document/views.py
index 70e78a0..6a80b18 100644
--- a/src/document/views.py
+++ b/src/document/views.py
@@ -8,4 +8,4 @@ from . import serializers
 class DocumentViewSet(viewsets.ModelViewSet):
     queryset = models.Document.objects.all()
     serializer_class = serializers.DocumentSerializer
-    permission_classes = (permissions.IsStaffOrReadOnly, )
+    permission_classes = (permissions.IsStaffOrStudent, )
diff --git a/src/homework/views.py b/src/homework/views.py
index d0762cc..2570786 100755
--- a/src/homework/views.py
+++ b/src/homework/views.py
@@ -1,9 +1,9 @@
 from rest_framework import viewsets
 
-from common import permissions
 from rest_framework.permissions import IsAuthenticated
 from . import serializers
 from . import models
+from common import permissions
 
 
 class TasksViewSet(viewsets.ModelViewSet):
@@ -14,12 +14,12 @@ class TasksViewSet(viewsets.ModelViewSet):
 
 class SolutionsViewSet(viewsets.ModelViewSet):
     serializer_class = serializers.SolutionSerializer
-    permission_classes = (IsAuthenticated, )
+    permission_classes = (permissions.IsStaffOrStudent, permissions.StudentJustCreate)
 
     def get_queryset(self):
         user = self.request.user
         queryset = models.Solution.objects.filter(created_by=user.profile)
-        if user.has_perm(permissions.IsStaffUser):
+        if user.profile.role == 'Staff':
             queryset = models.Solution.objects.all()
             profile_id = self.request.query_params.get('profileID', None)
             if profile_id is not None:
diff --git a/src/stats/migrations/0006_auto_20190114_1913.py b/src/stats/migrations/0006_auto_20190114_1913.py
new file mode 100644
index 0000000..545336b
--- /dev/null
+++ b/src/stats/migrations/0006_auto_20190114_1913.py
@@ -0,0 +1,18 @@
+# Generated by Django 2.0.1 on 2019-01-14 18:13
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('stats', '0005_auto_20190114_1713'),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name='event',
+            name='visitors',
+            field=models.ManyToManyField(blank=True, related_name='events', to='account.Profile'),
+        ),
+    ]
diff --git a/src/stats/models.py b/src/stats/models.py
index c408a7b..ef56adf 100644
--- a/src/stats/models.py
+++ b/src/stats/models.py
@@ -12,7 +12,6 @@ class Event(models.Model):
         Profile,
         related_name='events',
         blank=True,
-        null=True,
     )
     created_by = models.ForeignKey(
         Profile,
diff --git a/src/stats/serializers.py b/src/stats/serializers.py
index 5f16191..d18167b 100644
--- a/src/stats/serializers.py
+++ b/src/stats/serializers.py
@@ -6,7 +6,6 @@ from . import models
 class EventSerializer(serializers.ModelSerializer):
     created_by_name = serializers.SerializerMethodField()
     visitor_number = serializers.SerializerMethodField()
-    # visitors = serializers.SerializerMethodField()
 
     class Meta:
         model = models.Event
@@ -19,9 +18,6 @@ class EventSerializer(serializers.ModelSerializer):
     def get_visitor_number(self, obj):
         return obj.visitors.all().count()
 
-    # def get_visitors(self, obj):
-    #     return obj.visitors.all()
-
 
 class NoteSerializer(serializers.ModelSerializer):
     created_by = serializers.HiddenField(default=CurrentUserProfileDefault())
-- 
GitLab