from logging import getLogger from typing import Any, Dict, Optional, Type, TypeVar, cast from django.core.validators import RegexValidator from django.utils.translation import gettext_lazy as _ from django_otp.plugins.otp_totp.models import TOTPDevice from rest_framework import serializers from rest_framework_simplejwt.serializers import TokenObtainPairSerializer from accounts.models.custom_user import CustomUser from log_manager.trace_log import trace_log logger = getLogger(__name__) class CustomTokenObtainPairSerializer(TokenObtainPairSerializer): """MFA 認証に必要な、"otp" (One-Time Password) フィールドを追加したトークン取得用のシリアライザ。""" @trace_log def __init__(self, *args, **kwargs) -> None: """TokenObtainPairSerializer に otp フィールドを追加する。""" super().__init__(*args, **kwargs) # otp (One-Time Password) フィールドを追加 self.fields["otp"] = OtpField(required=False) @trace_log def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]: """ユーザーの MFA 有効であり、OTP が指定されている場合、OTP の検証を実施する。 Arguments: attrs: 属性 Return: アクセストークン、リフレッシュトーク が含まれる辞書 """ data = super().validate(attrs) user = cast(CustomUser, self.user) if user.is_mfa_enabled: # MFA が有効な場合、OTP を検証する。 otp = attrs.get("otp") if not self.is_valid_otp(user, otp): raise serializers.ValidationError({"otp": _("OTP required or invalid OTP.")}) return data @trace_log def is_valid_otp(self, user: CustomUser, otp) -> bool: """OTP が有効かどうかを返す。 Arguments: otp: OTP Return: 有効な場合、True """ if otp: devices = TOTPDevice.objects.filter(user=user) for device in devices: if device and device.verify_token(otp): return True return False class OtpField(serializers.CharField): """6桁の数値を扱う、OTP 用フィールド。""" @trace_log def __init__(self, *args, **kwargs) -> None: kwargs.setdefault("style", {}) kwargs["style"]["input_type"] = "number" kwargs["write_only"] = True kwargs["validators"] = [ RegexValidator(regex=r"^\d{6}$", message=_("OTP must be a 6-digit number.")), ] super().__init__(*args, **kwargs)