Newer
Older
pydwiki / jwt_auth / views.py
import base64
from io import BytesIO
from logging import getLogger

import qrcode
from django.contrib.auth import get_user_model
from django_otp.plugins.otp_totp.models import TOTPDevice
from rest_framework import status, viewsets
from rest_framework.authentication import SessionAuthentication
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.views import TokenObtainPairView

from log_manager.trace_log import trace_log

from .serializers import CustomTokenObtainPairSerializer

logger = getLogger(__name__)


# Create your views here.

User = get_user_model()


class CustomTokenObtainPairView(TokenObtainPairView):
    """カスタムトークン取得用 View。
    ID、パスワードにより認証されたユーザーの Access Toekn, Refresh Token を返します。
    対象ユーザーの MFA が有効な場合、パスワードに加え、otp (One-Time Password) を要求します。
    """

    serializer_class = CustomTokenObtainPairSerializer
    """ シリアライザ。 """


class MfaAuthViewSet(viewsets.ViewSet):
    authentication_classes = [JWTAuthentication, SessionAuthentication]
    permission_classes = [IsAuthenticated]
    queryset = TOTPDevice.objects.none()

    @trace_log
    @action(detail=False, methods=["get"])
    def setup(self, request):
        """MFA の設定。ログインしているユーザーの MFA QR コードを返します。

        Arguments:
            request: 要求情報

        Return:
            OTP の URL、 QR コードの情報 (base64 形式の png イメージ)
        """
        user = request.user
        print(f"user: {user}")
        device, __ = TOTPDevice.objects.get_or_create(user=user, name="defualt")

        # QR コード生成
        otp_url = device.config_url
        qr = qrcode.make(otp_url)
        buffer = BytesIO()
        qr.save(buffer, format="PNG")
        qr_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return Response({"otp_url": otp_url, "qr_code": f"data:image/png; base64,{qr_base64}"})

    @trace_log
    @action(detail=False, methods=["post"])
    def verify(self, request):
        """MFA の検証をします。"""
        user = request.user
        device = TOTPDevice.objects.filter(user=user).first()

        otp = request.data.get("otp")
        if device and device.verify_token(otp):
            user.is_mfa_enabled = True
            user.save()
            return Response({"message": "MFA enabled successfully"})
        return Response({"error": "Invalid OTP"}, status=status.HTTP_400_BAD_REQUEST)