from typing import cast

import pyotp
from django.http import HttpResponse
from django.urls import reverse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APIClient, APITestCase
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken

from accounts.models import CustomUser, EmailAddress
from jwt_auth.serializers import CustomTokenObtainPairSerializer


class SerializerTestCase(APITestCase):
    """Tests for CustomTokenObtainPairSerializer and the token_obtain_pair endpoint."""

    def setUp(self):
        """Create the test user and the email addresses linked to it.

        User:
            login_id: testuser
            username: Test User
            is_staff: True
            is_active: True
            is_superuser: False
            is_mfa_enabled: False
            password_changed: False
            password: password
        Email addresses:
            test1@example.com (is_primary=True)
            test2@example.com (is_primary=False)
        """
        self.user = CustomUser.objects.create(
            login_id="testuser",
            username="Test User",
            is_staff=True,
            is_active=True,
            is_superuser=False,
            is_mfa_enabled=False,
            password_changed=False,
        )
        self.user.set_password("password")
        self.user.save()
        self.email1 = EmailAddress.objects.create(
            user=self.user,
            email="test1@example.com",
            is_primary=True,
        )
        self.email2 = EmailAddress.objects.create(
            user=self.user,
            email="test2@example.com",
            is_primary=False,
        )

    def _post_token(self, payload: dict) -> Response:
        """POST *payload* as JSON to the ``token_obtain_pair`` endpoint."""
        client = cast(APIClient, self.client)
        return cast(
            Response,
            client.post(reverse("token_obtain_pair"), payload, format="json"),
        )

    def _obtain_access_token(self, payload: dict) -> str:
        """Request a token pair, assert HTTP 200, and return the access token string.

        Fails the test if the response status is not 200 or the body is empty.
        """
        response = self._post_token(payload)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        if not response.data:
            self.fail("token_obtain_pair returned an empty body")
        return response.data["access"]

    def test_custom_token_obtain_pair_serializer(self):
        """Check that the serializer issues tokens for the authenticating user."""
        serializer = CustomTokenObtainPairSerializer(
            data={"login_id": "testuser", "password": "password"}
        )
        self.assertTrue(serializer.is_valid())
        tokens = cast(dict, serializer.validated_data)

        # Decode the issued tokens to inspect their payloads.
        access = AccessToken(tokens["access"])
        refresh = RefreshToken(tokens["refresh"])

        # Both tokens must identify the test user.
        self.assertEqual(access["user_id"], self.user.pk)
        self.assertEqual(refresh["user_id"], self.user.pk)

    def test_custom_token_obtain_pair_serializer_otp(self):
        """Check token issuance when MFA is enabled and a valid OTP is supplied.

        Flow:
        1. Log in without an OTP while MFA is still disabled, and authenticate
           the client with the returned access token.
        2. Enable MFA on the user and call ``mfa-setup`` to obtain the OTP
           provisioning URI; derive the current TOTP code from it.
        3. Request a token pair again including the OTP, and check that the
           issued access token belongs to the test user.
        """
        # Step 1: MFA is disabled, so login_id + password is sufficient.
        access_token = self._obtain_access_token(
            {"login_id": "testuser", "password": "password"}
        )
        client = cast(APIClient, self.client)
        # mfa-setup is called as the logged-in user, so attach the bearer token.
        client.credentials(HTTP_AUTHORIZATION="Bearer " + access_token)

        # Step 2: enable MFA, then fetch the OTP provisioning URI.
        self.user.is_mfa_enabled = True
        self.user.save()

        response = cast(Response, client.get(reverse("mfa-setup")))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        if not response.data:
            self.fail("MFA setup failed")

        totp = pyotp.parse_uri(response.data["otp_url"])
        # Ensure the URI describes a time-based OTP before calling now().
        if not isinstance(totp, pyotp.TOTP):
            self.fail("Invalid TOTP URI")
        otp = totp.now()

        # Step 3: with MFA enabled, supplying the current OTP must yield tokens
        # for this user.
        access_token = self._obtain_access_token(
            {"login_id": "testuser", "password": "password", "otp": otp}
        )
        self.assertEqual(AccessToken(access_token)["user_id"], self.user.pk)

    def test_custom_token_obtain_pair_serializer_no_otp(self):
        """With MFA enabled, a token request without an OTP must be rejected (400)."""
        self.user.is_mfa_enabled = True
        self.user.save()

        response = self._post_token({"login_id": "testuser", "password": "password"})
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_custom_token_obtain_pair_serializer_invalid_otp(self):
        """With MFA enabled, a token request with a wrong OTP must be rejected (400)."""
        # NOTE(review): mfa-setup is never called here, so the user may have no
        # OTP secret yet; the 400 could come from that rather than from the OTP
        # value itself. Consider running mfa-setup first to make the test sharper.
        self.user.is_mfa_enabled = True
        self.user.save()

        response = self._post_token(
            {"login_id": "testuser", "password": "password", "otp": "123456"}
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)