翼度科技»论坛 编程开发 python 查看内容

基于rest_framework的ModelViewSet类编写登录视图和认证视图

4

主题

4

帖子

12

积分

新手上路

Rank: 1

积分
12
背景:看了博主一抹浅笑的rest_framework认证模板,发现登录视图函数是基于APIView类封装。
优化:使用ModelViewSet类通过重写create方法编写登录函数。
环境:既然接触到rest_framework的使用,相信已经搭建好相关环境了。
1 建立模型

编写模型类
  1. # models.py
  2. from django.db import models
  3. class User(models.Model):
  4.     username = models.CharField(verbose_name='用户名称',unique=True,max_length=16)
  5.     password = models.CharField(verbose_name='登陆密码',max_length=16)
  6. class Token(models.Model):
  7.     username = models.CharField(verbose_name='用户名称',unique=True,max_length=16)
  8.     token = models.CharField(verbose_name='验证密钥',max_length=32)
复制代码
生成迁移文件
  1. python manage.py makemigrations
复制代码
迁移数据模型
  1. python manage.py migrate
复制代码
2 确定需要重写的方法

查看ModelViewSet类源码
  1. '''
  2. class ModelViewSet(mixins.CreateModelMixin,
  3.                    mixins.RetrieveModelMixin,
  4.                    mixins.UpdateModelMixin,
  5.                    mixins.DestroyModelMixin,
  6.                    mixins.ListModelMixin,
  7.                    GenericViewSet):
  8.     """
  9.     A viewset that provides default `create()`, `retrieve()`, `update()`,
  10.     `partial_update()`, `destroy()` and `list()` actions.
  11.     """
  12.     pass
  13. '''
复制代码
最终目的是往Token模型对应的表添加数据,所以得选择CreateModelMixin模型的源码查看。
  1. '''
  2. class CreateModelMixin:
  3.     """
  4.     Create a model instance.
  5.     """
  6.     def create(self, request, *args, **kwargs):
  7.         serializer = self.get_serializer(data=request.data)
  8.         serializer.is_valid(raise_exception=True)
  9.         self.perform_create(serializer)
  10.         headers = self.get_success_headers(serializer.data)
  11.         return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
  12.     def perform_create(self, serializer):
  13.         serializer.save()
  14.     def get_success_headers(self, data):
  15.         try:
  16.             return {'Location': str(data[api_settings.URL_FIELD_NAME])}
  17.         except (TypeError, KeyError):
  18.             return {}
  19. '''
复制代码
查看得知,CreateModelMixin类下的create方法调用了serializer类的save方法创建数据。继续查看save方法。
通过serializers.ModelSerializer定位到serializers.py文件,搜索'def save('定位到以下内容。
  1. '''
  2.     def save(self, **kwargs):
  3.         assert hasattr(self, '_errors'), (
  4.             'You must call `.is_valid()` before calling `.save()`.'
  5.         )
  6.         assert not self.errors, (
  7.             'You cannot call `.save()` on a serializer with invalid data.'
  8.         )
  9.         # Guard against incorrect use of `serializer.save(commit=False)`
  10.         assert 'commit' not in kwargs, (
  11.             "'commit' is not a valid keyword argument to the 'save()' method. "
  12.             "If you need to access data before committing to the database then "
  13.             "inspect 'serializer.validated_data' instead. "
  14.             "You can also pass additional keyword arguments to 'save()' if you "
  15.             "need to set extra attributes on the saved model instance. "
  16.             "For example: 'serializer.save(owner=request.user)'.'"
  17.         )
  18.         assert not hasattr(self, '_data'), (
  19.             "You cannot call `.save()` after accessing `serializer.data`."
  20.             "If you need to access data before committing to the database then "
  21.             "inspect 'serializer.validated_data' instead. "
  22.         )
  23.         validated_data = {**self.validated_data, **kwargs}
  24.         if self.instance is not None:
  25.             self.instance = self.update(self.instance, validated_data)
  26.             assert self.instance is not None, (
  27.                 '`update()` did not return an object instance.'
  28.             )
  29.         else:
  30.             self.instance = self.create(validated_data)
  31.             assert self.instance is not None, (
  32.                 '`create()` did not return an object instance.'
  33.             )
  34. '''
复制代码
看最后这个if……else……语句中的self.instance = self.create(validated_data)。
说明这里调用了create方法,返回一个模型对象。于是查看ModelSerializer类的create方法。
  1. '''
  2.     def create(self, validated_data):
  3.         """
  4.         We have a bit of extra checking around this in order to provide
  5.         descriptive messages when something goes wrong, but this method is
  6.         essentially just:
  7.             return ExampleModel.objects.create(**validated_data)
  8.         If there are many to many fields present on the instance then they
  9.         cannot be set until the model is instantiated, in which case the
  10.         implementation is like so:
  11.             example_relationship = validated_data.pop('example_relationship')
  12.             instance = ExampleModel.objects.create(**validated_data)
  13.             instance.example_relationship = example_relationship
  14.             return instance
  15.         The default implementation also does not handle nested relationships.
  16.         If you want to support writable nested relationships you'll need
  17.         to write an explicit `.create()` method.
  18.         """
  19.         raise_errors_on_nested_writes('create', self, validated_data)
  20.         ModelClass = self.Meta.model
  21.         # Remove many-to-many relationships from validated_data.
  22.         # They are not valid arguments to the default `.create()` method,
  23.         # as they require that the instance has already been saved.
  24.         info = model_meta.get_field_info(ModelClass)
  25.         many_to_many = {}
  26.         for field_name, relation_info in info.relations.items():
  27.             if relation_info.to_many and (field_name in validated_data):
  28.                 many_to_many[field_name] = validated_data.pop(field_name)
  29.         try:
  30.             instance = ModelClass._default_manager.create(**validated_data)
  31.         except TypeError:
  32.             tb = traceback.format_exc()
  33.             msg = (
  34.                 'Got a `TypeError` when calling `%s.%s.create()`. '
  35.                 'This may be because you have a writable field on the '
  36.                 'serializer class that is not a valid argument to '
  37.                 '`%s.%s.create()`. You may need to make the field '
  38.                 'read-only, or override the %s.create() method to handle '
  39.                 'this correctly.\nOriginal exception was:\n %s' %
  40.                 (
  41.                     ModelClass.__name__,
  42.                     ModelClass._default_manager.name,
  43.                     ModelClass.__name__,
  44.                     ModelClass._default_manager.name,
  45.                     self.__class__.__name__,
  46.                     tb
  47.                 )
  48.             )
  49.             raise TypeError(msg)
  50.         # Save many-to-many relationships after the instance is created.
  51.         if many_to_many:
  52.             for field_name, value in many_to_many.items():
  53.                 field = getattr(instance, field_name)
  54.                 field.set(value)
  55.         return instance
  56. '''
复制代码
这逻辑我是没看懂,但是通过print、type、dir函数可以确定
接收对象validated_data是一个字典,
返回对象instance是一个模型对象。
于是可以把源码cv过来,简单测试是否能够通。
  1. import time
  2. import hashlib
  3. from rest_framework import status
  4. from rest_framework import serializers
  5. from rest_framework.response import Response
  6. from rest_framework.viewsets import ModelViewSet
  7. from myapp import models as myapp_models
  8. class TokenSerializer(serializers.ModelSerializer):
  9.     class Meta:
  10.         model = myapp_models.Token
  11.         fields = '__all__'
  12.     def create(self,validated_data):
  13.         ######################################
  14.         query_obj = myapp_models.Token.objects.update_or_create(
  15.             username=validated_data['username'],
  16.             defaults={"username":validated_data['username'],"token":validated_data['token']})[0]
  17.         print(query_obj)
  18.         return query_obj
  19.         #------------------------------------#
  20. class LoginView(ModelViewSet):
  21.     queryset = myapp_models.Token.objects.all()
  22.     serializer_class = TokenSerializer
  23.     def create(self, request, *args, **kwargs):
  24.         serializer = self.get_serializer(data=request.data)
  25.         serializer.is_valid(raise_exception=True)
  26.         self.perform_create(serializer)
  27.         headers = self.get_success_headers(serializer.data)
  28.         return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
复制代码
3 重写create方法

3.1 编写登录逻辑

TokenSerializer
1.获取username和password。
2.验证username、password匹配性。
3.匹配错误:更新或创建模型中username对应的token为空字符串,返回模型对象。
4.匹配正确:通过md5加密生成token,更新或创建模型中username对应的token为密钥。
ModelViewSet
1.根据username查询token值。
2.将username、token值设置到session会话。
  1. import time
  2. import hashlib
  3. from rest_framework import status
  4. from rest_framework import serializers
  5. from rest_framework.response import Response
  6. from rest_framework.viewsets import ModelViewSet
  7. from myapp import models as myapp_models
  8. class TokenSerializer(serializers.ModelSerializer):
  9.     class Meta:
  10.         model = myapp_models.Token
  11.         fields = '__all__'
  12.     def create(self,validated_data):
  13.         ######################################
  14.         user_obj = myapp_models.User.objects.filter(
  15.             username=validated_data['username'],
  16.             password=validated_data['token'])
  17.         user_dict = validated_data
  18.         user_dict['token'] = ''
  19.         if not user_obj.exists():
  20.             query_obj = myapp_models.Token.objects.update_or_create(
  21.                 username=user_dict['username'],
  22.                 defaults={"username":user_dict['username'],"token":user_dict['token']})[0]
  23.             return query_obj
  24.         validated_data['token'] = hashlib.md5(
  25.             ''.format(time.time(),''.join(validated_data.values())).encode()).hexdigest()
  26.         query_obj = myapp_models.Token.objects.update_or_create(
  27.             username=validated_data['username'],
  28.             defaults={"username":validated_data['username'],"token":validated_data['token']})[0]
  29.         print(query_obj)
  30.         return query_obj
  31.         #------------------------------------#
  32. class LoginView(ModelViewSet):
  33.     queryset = myapp_models.Token.objects.all()
  34.     serializer_class = TokenSerializer
  35.     def create(self, request, *args, **kwargs):
  36.         serializer = self.get_serializer(data=request.data)
  37.         serializer.is_valid(raise_exception=True)
  38.         self.perform_create(serializer)
  39.         headers = self.get_success_headers(serializer.data)
  40.         ######################################
  41.         token_obj = myapp_models.Token.objects.filter(
  42.             username=request.POST.get('username')).first()
  43.         if token_obj.token == '':
  44.             request.session['username'] = token_obj.username
  45.             request.session['token'] = token_obj.token
  46.             return Response('检查输入的账户和密码')
  47.         request.session['username'] = token_obj.username
  48.         request.session['token'] = token_obj.token
  49.         #------------------------------------#
  50.         return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
复制代码
3.2 编写认证逻辑

1.从session中获取username,token。
2.判断username,token是否不存在、或token是否为空字符串。
3.判断正确:抛出异常。
4.判断错误:范围username和模型对象组成的元组。
  1. from rest_framework import exceptions
  2. from rest_framework.authentication import BaseAuthentication
  3. from myapp import models as myapp_models
  4. class Authentication(BaseAuthentication):
  5.     def authenticate(self,request):
  6.         ######################################
  7.         username = request._request.session.get('username','')
  8.         token = request._request.session.get('token','')
  9.         token_obj = myapp_models.Token.objects.filter(
  10.             username=username,token=token)
  11.         if not token_obj.exists or token_obj.first().token == '':
  12.             raise exceptions.AuthenticationFailed('认证失败')
  13.         return (token_obj.first().username,token_obj.first())
  14.         #------------------------------------#
复制代码
3.3 添加路由
  1. path('login/',myapp_views.LoginView.as_view({
  2.         'post':'create'}),name='login')
复制代码
来源:https://www.cnblogs.com/mlcode/p/17969584/rest_framework
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作!
来自手机

举报 回复 使用道具