import base64
import json

from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.test import TestCase, Client, override_settings
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient

from . import data
from .. import models
from ..keys import PublicKey, Ed25519PrivateKey
from ..tokens import Token
from ..utils import get_setting

# Resolve the project's configured user model once, at module import time.
User = get_user_model()


class PublicKeyViewSetTest(TestCase):
    """API tests for the public-key viewset (``api-publickey-list`` / ``-detail``).

    One superuser with a single registered Ed25519 public key is created per
    class. Authenticated requests carry an auth header signed with the matching
    private key.
    """

    @classmethod
    def setUpTestData(cls) -> None:
        super().setUpTestData()
        cls.user = User.objects.create_user(username="foo", is_superuser=True)
        # Reload from the database; presumably so DB-populated fields such as
        # ``cid`` are present on the instance — confirm against the model.
        cls.user.refresh_from_db()
        cls.private_key = Ed25519PrivateKey.generate()
        # PEM bytes of the matching public key, surrounding whitespace stripped.
        cls.key_bytes = cls.private_key.public_key.as_pem.strip()
        cls.key_text = cls.key_bytes.decode()
        cls.public_key_obj = models.PublicKey.objects.create(
            name='ed25519-drf',
            user=cls.user,
            key=cls.key_text,
        )

    def setUp(self):
        super().setUp()
        cache.clear()

    def _authenticated_client(self) -> APIClient:
        """Return an ``APIClient`` whose requests carry a header signed for ``self.user``."""
        api_client = APIClient()
        token = Token(cid=str(self.user.cid)).create_auth_header(self.private_key)
        api_client.credentials(HTTP_AUTHORIZATION=token)
        return api_client

    def test_public_key_list(self):
        """Listing is rejected without credentials; with them the first item is the user's key."""
        url = reverse("api-publickey-list")

        response = APIClient().get(url)
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self._authenticated_client().get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK, response.data)

        key_data = response.data[0]
        self.assertEqual(key_data['key'], self.key_text)

    def test_create_public_key(self):
        """An authenticated POST creates a new public key and echoes it back."""
        api_client = self._authenticated_client()

        new_private_key = Ed25519PrivateKey.generate()
        new_key_text = new_private_key.public_key.as_pem.strip().decode()

        key_data = {
            'name': f'new key for {self.user}',
            'user': str(self.user.cid),
            'key': new_key_text,
        }

        response = api_client.post(reverse("api-publickey-list"), data=key_data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data)
        self.assertEqual(response.data['key'], key_data['key'])

    def test_request_with_key_but_no_token(self):
        """Sending a key via the auto-auth parameter without a token yields 401, not a traceback."""
        api_client = APIClient()
        param = get_setting('AUTO_AUTH_METHOD_PARAM')

        # GET with the public key base64url-encoded in the query string.
        # Decode to ``str`` and let the client URL-encode it. Interpolating the
        # raw bytes into an f-string would send their repr (``b'...'``) instead.
        encoded_key = base64.urlsafe_b64encode(self.key_bytes).decode()
        response = api_client.get(reverse("api-publickey-list"), {param: encoded_key})
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, response.data)

        # POST with the PEM key in the request body against a detail URL.
        url = reverse("api-publickey-detail", kwargs={'cid': self.public_key_obj.cid})
        response = api_client.post(url, {param: self.key_text})
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, response.data)


class JWKSViewTest(TestCase):
    """Tests for the ``signed_jwt_auth:jwks`` endpoint.

    Each test configures ``SIGNED_JWT_AUTH['SIGNING_PUBLIC_KEYS']`` (or leaves
    it at its default) and compares the decoded JSON body with the expected
    JWK set.
    """

    # JWK the endpoint is expected to emit for ``data.PEM_PUBLIC_RSA``, whether
    # the key is configured as PEM text or as a loaded ``PublicKey``.
    RSA_JWK = {
        "alg": "RS512",
        "e": "AQAB",
        "kid": "53c5b68c5ecba3e25df3f8326de6c0b0befb67e9217651a2f40e388f6567f056",
        "kty": "RSA",
        "n": "odxbRh5LOtoB3Shf6K3mRn7ME7Doo5Qm5h72ITt-E6U0l6qXGdVBTj0XhQVNnGjnZTGzU7IacIw1a_03qVHJfcc0Ki7ig7YSPMMl0WSp0m080YlsCZ-9g-WG6DrgjpGQU7yaBhNsKtR5DP20bm8411S9VLqV2GEOzBKpB10_lwhRZuv_Qj7obwSqdVCzMNb7t5LHqG0MxOF7BeYELXIqTEKFfWkZytXCAnmC9hk9RtzUZ_lryD1UgCHZ16gPtmPdFV7fuN8FBNrbaQCldz6V6HVDjsPVxPmVYswV8qInG8kJUpm48s9PAWfgi4HCGmJgn-Irbed2tlRf73jxyCgX0Q",  # NOQA
        "use": "sig",
    }

    def setUp(self):
        super().setUp()
        # Start each test with an empty cache. Presumably the JWKS response or
        # the key loading is cached across requests — confirm in the view.
        cache.clear()

    def _get_jwks(self) -> dict:
        """GET the JWKS endpoint, assert HTTP 200 and return the decoded JSON body."""
        response = Client().get(reverse("signed_jwt_auth:jwks"))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # ``json.loads`` rather than ``response.json()``: the latter requires an
        # ``application/json`` content type, which this test does not pin down.
        return json.loads(response.content)

    def test_no_keys(self):
        """With no signing keys configured the key set is empty."""
        self.assertEqual(self._get_jwks(), {"keys": []})

    @override_settings(
        SIGNED_JWT_AUTH=dict(
            SIGNING_PUBLIC_KEYS=[data.PEM_PUBLIC_RSA, data.PEM_PUBLIC_RSA],
        )
    )
    def test_pem_keys(self):
        """PEM strings are converted; the key configured twice is emitted twice."""
        self.assertEqual(
            self._get_jwks(),
            {"keys": [self.RSA_JWK, self.RSA_JWK]},
        )

    @override_settings(
        SIGNED_JWT_AUTH=dict(
            SIGNING_PUBLIC_KEYS=[
                PublicKey.load_pem(data.PEM_PUBLIC_RSA),
            ],
        )
    )
    def test_loaded_keys(self):
        """An already-loaded ``PublicKey`` yields the same JWK as its PEM source."""
        self.assertEqual(self._get_jwks(), {"keys": [self.RSA_JWK]})

    @override_settings(
        SIGNED_JWT_AUTH=dict(
            SIGNING_PUBLIC_KEYS=[
                data.PEM_PUBLIC_RSA_INVALID,
            ],
        )
    )
    def test_invalid_pem_keys(self):
        """An unparseable PEM is omitted from the set rather than failing the request."""
        self.assertEqual(self._get_jwks(), {"keys": []})
