Newer
Older
pydwiki / jwt_auth / tests / tests_serializer.py
from typing import cast

import pyotp
from django.http import HttpResponse
from django.urls import reverse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APIClient, APITestCase
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken

from accounts.models import CustomUser, EmailAddress
from jwt_auth.serializers import CustomTokenObtainPairSerializer


class SerializerTestCase(APITestCase):
    """Tests for ``CustomTokenObtainPairSerializer`` and the token endpoint.

    Each test starts from the user created in :meth:`setUp`; tests that need
    MFA switch ``is_mfa_enabled`` on themselves.
    """

    # Credentials of the fixture user created in ``setUp``.
    LOGIN_ID = "testuser"
    PASSWORD = "password"

    def setUp(self):
        """Create the test user and the e-mail addresses attached to it.

        The following user and addresses are created:

        User:
            login_id: testuser
            username: Test User
            is_staff: True
            is_active: True
            is_superuser: False
            is_mfa_enabled: False
            password_changed: False
            mail: [
                test1@example.com, is_primary: True
                test2@example.com, is_primary: False
            ]
        """
        self.user = CustomUser.objects.create(
            login_id=self.LOGIN_ID,
            username="Test User",
            is_staff=True,
            is_active=True,
            is_superuser=False,
            is_mfa_enabled=False,
            password_changed=False,
        )
        # The password is set through set_password() rather than passed to
        # create(), then persisted with an explicit save().
        self.user.set_password(self.PASSWORD)
        self.user.save()
        self.email1 = EmailAddress.objects.create(
            user=self.user,
            email="test1@example.com",
            is_primary=True,
        )
        self.email2 = EmailAddress.objects.create(
            user=self.user,
            email="test2@example.com",
            is_primary=False,
        )

    def _request_token(self, otp=None) -> HttpResponse:
        """POST the fixture user's credentials to the token endpoint.

        Args:
            otp: One-time password to send along with the credentials. When
                ``None`` the ``otp`` key is omitted from the request body.

        Returns:
            The raw response of the ``token_obtain_pair`` endpoint; no
            assertion is made on it here.
        """
        payload = {"login_id": self.LOGIN_ID, "password": self.PASSWORD}
        if otp is not None:
            payload["otp"] = otp
        response = self.client.post(
            reverse("token_obtain_pair"),
            payload,
            format="json",
        )
        return cast(HttpResponse, response)

    def _authenticate_client(self, otp=None) -> None:
        """Obtain an access token and install it on the test client.

        Fails the running test when the endpoint does not answer with
        HTTP 200 or returns an empty body.

        Args:
            otp: Optional one-time password forwarded to the token endpoint.
        """
        response = self._request_token(otp=otp)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = cast(Response, response)
        if not response.data:
            self.fail("Token endpoint returned an empty response body")
        self.accessToken = response.data["access"]
        client = cast(APIClient, self.client)
        client.credentials(HTTP_AUTHORIZATION="Bearer " + self.accessToken)

    def test_custom_token_obtain_pair_serializer(self):
        """Check that the serializer issues tokens for the fixture user."""

        serializer = CustomTokenObtainPairSerializer(
            data={"login_id": self.LOGIN_ID, "password": self.PASSWORD}
        )
        # is_valid() must run before .errors may be read, so evaluate it
        # first and use the errors as the failure message.
        is_valid = serializer.is_valid()
        self.assertTrue(is_valid, serializer.errors)

        tokens = cast(dict, serializer.validated_data)
        access_token_str = tokens["access"]
        refresh_token_str = tokens["refresh"]

        # Decode the tokens to get at their payloads.
        access = AccessToken(access_token_str)
        refresh = RefreshToken(refresh_token_str)

        # Both tokens must belong to the fixture user.
        self.assertEqual(access["user_id"], self.user.pk)
        self.assertEqual(refresh["user_id"], self.user.pk)

    def test_custom_token_obtain_pair_serializer_otp(self):
        """Check that a token is issued when a valid OTP is supplied."""

        #
        # Obtain the MFA OTP
        #

        # Get an access token first (MFA is still disabled at this point) so
        # that the MFA setup request below is sent with credentials.
        self._authenticate_client()

        # Enable MFA.
        self.user.is_mfa_enabled = True
        self.user.save()

        # Fetch the OTP (TOTP) provisioning URI.
        setup_uri = reverse("mfa-setup")
        response = cast(Response, self.client.get(setup_uri))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        if not response.data:
            self.fail("MFA setup failed")
        otp_url = response.data["otp_url"]
        totp = pyotp.parse_uri(otp_url)
        if not isinstance(totp, pyotp.TOTP):
            self.fail("Invalid TOTP URI")
        otp = totp.now()

        #
        # Obtain the token
        #
        self._authenticate_client(otp=otp)

    def test_custom_token_obtain_pair_serializer_no_otp(self):
        """Check that a request without OTP is rejected when MFA is enabled."""

        # Enable MFA.
        self.user.is_mfa_enabled = True
        self.user.save()

        response = self._request_token()
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_custom_token_obtain_pair_serializer_invalid_otp(self):
        """Check that a request with an invalid OTP is rejected."""

        # Enable MFA.
        self.user.is_mfa_enabled = True
        self.user.save()

        response = self._request_token(otp="123456")
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)