Newer
Older
pydwiki / jwt_auth / serializers.py
from logging import getLogger
from typing import Any, Dict, Optional, Type, TypeVar, cast

from django.core.validators import RegexValidator
from django.utils.translation import gettext_lazy as _
from django_otp.plugins.otp_totp.models import TOTPDevice
from rest_framework import serializers
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer

from accounts.models.custom_user import CustomUser
from log_manager.trace_log import trace_log

logger = getLogger(__name__)


class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    """MFA 認証に必要な、"otp" (One-Time Password) フィールドを追加したトークン取得用のシリアライザ。"""

    @trace_log
    def __init__(self, *args, **kwargs) -> None:
        """TokenObtainPairSerializer に otp フィールドを追加する。"""
        super().__init__(*args, **kwargs)

        # otp (One-Time Password) フィールドを追加
        self.fields["otp"] = OtpField(required=False)

    @trace_log
    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """ユーザーの MFA 有効であり、OTP が指定されている場合、OTP の検証を実施する。

        Arguments:
            attrs: 属性

        Return:
            アクセストークン、リフレッシュトーク が含まれる辞書
        """
        data = super().validate(attrs)

        user = cast(CustomUser, self.user)
        if user.is_mfa_enabled:
            # MFA が有効な場合、OTP を検証する。
            otp = attrs.get("otp")
            if not self.is_valid_otp(user, otp):
                raise serializers.ValidationError({"otp": _("OTP required or invalid OTP.")})
        return data

    @trace_log
    def is_valid_otp(self, user: CustomUser, otp) -> bool:
        """OTP が有効かどうかを返す。

        Arguments:
            otp: OTP

        Return:
            有効な場合、True
        """
        if otp:
            devices = TOTPDevice.objects.filter(user=user)
            for device in devices:
                if device and device.verify_token(otp):
                    return True
        return False


class OtpField(serializers.CharField):
    """6桁の数値を扱う、OTP 用フィールド。"""

    @trace_log
    def __init__(self, *args, **kwargs) -> None:
        kwargs.setdefault("style", {})
        kwargs["style"]["input_type"] = "number"
        kwargs["write_only"] = True
        kwargs["validators"] = [
            RegexValidator(regex=r"^\d{6}$", message=_("OTP must be a 6-digit number.")),
        ]
        super().__init__(*args, **kwargs)