# pydwiki / jwt_auth / tests / tests_views.py
from typing import cast

import pyotp
from django.http import HttpResponse
from django.urls import reverse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APIClient, APITestCase

from accounts.models import CustomUser, EmailAddress


class MfaAuthViewSetTestCase(APITestCase):
    """Tests for the ``mfa-setup`` and ``mfa-verify`` endpoints.

    Every test runs as an authenticated user: ``setUp`` obtains an access
    token from the ``token_obtain_pair`` endpoint and attaches it to the
    test client as a ``Bearer`` credential.
    """

    def setUp(self):
        """Create the test user and email addresses, then authenticate the client.

        The following user and associated email addresses are created:

        User:
            login_id: testuser
            username: Test User
            is_staff: True
            is_active: True
            is_superuser: False
            is_mfa_enabled: False (switched to True at the end of setUp)
            password_changed: False
            mail: [
                test1@example.com, is_primary: True
                test2@example.com, is_primary: False
            ]
        """
        self.user = CustomUser.objects.create(
            login_id="testuser",
            username="Test User",
            is_staff=True,
            is_active=True,
            is_superuser=False,
            is_mfa_enabled=False,
            password_changed=False,
        )
        self.user.set_password("password")
        self.user.save()
        self.email1 = EmailAddress.objects.create(
            user=self.user,
            email="test1@example.com",
            is_primary=True,
        )
        self.email2 = EmailAddress.objects.create(
            user=self.user,
            email="test2@example.com",
            is_primary=False,
        )

        # Obtain an access token for the test user.
        token_uri = reverse("token_obtain_pair")
        response = self.client.post(
            token_uri,
            {"login_id": "testuser", "password": "password"},
            format="json",
        )
        response = cast(HttpResponse, response)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        response = cast(Response, response)
        if not response.data:
            # Report as a test failure, consistent with the test methods below.
            self.fail("Failed to get access token")

        self.accessToken = response.data["access"]
        self.client: APIClient = cast(APIClient, self.client)
        self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.accessToken)

        # Enable MFA only after the token has been issued.
        # NOTE(review): presumably logging in with MFA already enabled would
        # require an OTP -- confirm against the token view.
        self.user.is_mfa_enabled = True
        self.user.save()

    def _fetch_validated_otp_url(self):
        """Request ``mfa-setup`` and return the validated ``otp_url``.

        Asserts that the endpoint answers with HTTP 200 and a non-empty
        payload, that ``otp_url`` starts with the expected ``otpauth://``
        prefix for the test user, and that ``qr_code`` is not ``None``.

        Returns:
            The ``otp_url`` string taken from the response payload.
        """
        response = self.client.get(reverse("mfa-setup"))
        response = cast(HttpResponse, response)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        response = cast(Response, response)
        if not response.data:
            self.fail("MFA setup failed")

        otp_url = response.data["otp_url"]
        qr_code = response.data["qr_code"]
        self.assertTrue(otp_url.startswith(f"otpauth://totp/{self.user.login_id}?secret="))
        self.assertIsNotNone(qr_code)
        return otp_url

    def test_mfa_setup(self):
        """The MFA setup endpoint returns an OTP URL and a QR code."""
        self._fetch_validated_otp_url()

    def test_mfa_verify_success(self):
        """MFA verification succeeds with a currently valid OTP."""

        # Fetch the MFA setup data.
        otp_url = self._fetch_validated_otp_url()

        # Derive the current TOTP value from the provisioning URI.
        totp = pyotp.parse_uri(otp_url)
        if not isinstance(totp, pyotp.TOTP):
            self.fail("Invalid TOTP URI")
        otp = totp.now()

        # Verify the OTP.
        verify_uri = reverse("mfa-verify")
        response = self.client.post(verify_uri, {"otp": otp})

        response = cast(HttpResponse, response)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        response = cast(Response, response)
        if not response.data:
            self.fail("MFA failed")
        self.assertEqual(response.data["message"], "MFA enabled successfully")

    def test_mfa_verify_failure(self):
        """MFA verification is rejected for an OTP that was not generated."""

        # Use a fixed OTP instead of one derived from the MFA setup data.
        otp = "123456"

        # Verify the OTP.
        verify_uri = reverse("mfa-verify")
        response = self.client.post(verify_uri, {"otp": otp})

        response = cast(HttpResponse, response)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)