from aspen_crypto.keys import (
    PublicKey,
    Ed25519PrivateKey,
    persist_umbrella_api_key,
    load_local_umbrella_api_key,
)
from aspen_crypto.settings import LOCAL_VERIFICATION_KEY_FILE
from datetime import datetime, timedelta, timezone
from freezegun import freeze_time
from lcrequests.tokens import V4PasetoToken, RegistrationPasetoToken
from lcrequests.typehints import Subject
import json
import os
from pathlib import Path
import pyseto
from unittest.mock import patch
from uuid import uuid1

from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import TestCase, TransactionTestCase, override_settings

from ..loaders import DatabaseLoader, LocalKeyLoader
from ..models import EncryptionKey
from ..tokens import AccessToken, UnverifiedToken, UnverifiedRegistrationToken
from . import data

# Resolve the project's configured user model once at import time. The tests
# below create it with a ``cid`` field, so this is a custom user model rather
# than a stock django.contrib.auth User.
User = get_user_model()


class TokenTest(TestCase):
    """Check the claims embedded in client-generated PASETO auth headers."""

    def setUp(self):
        self.username = 'rusty_token'
        self.user_cid = uuid1()
        self.private_key = Ed25519PrivateKey.generate()

    def _decoded_claims(self, header):
        """Decode the token part of ``header`` and return its payload.

        The time-based claims ('exp', 'iat', 'nbf') are dropped because they
        differ on every run and would make exact comparisons impossible.
        """
        paseto_key = self.private_key.as_paseto(version=4)
        # NOTE(review): '==' is appended to the token text before decoding,
        # presumably to restore stripped base64 padding -- confirm.
        raw_token = header.split(" ")[1]
        claims = pyseto.decode(
            keys=paseto_key,
            token=f'{raw_token}==',
            deserializer=json,
        ).payload
        for volatile_claim in ('exp', 'iat', 'nbf'):
            claims.pop(volatile_claim, None)
        return claims

    @patch("secrets.token_urlsafe")
    def test_create_auth_header(self, mock_get_nonce):
        """A V4 token header uses the 'Paseto' scheme and carries nonce + user_cid."""
        mock_get_nonce.return_value = "yVJ0MVWhqPQ"
        auth_header = V4PasetoToken(
            private_key=self.private_key,
            user_cid=self.user_cid,
        ).auth_header()
        self.assertTrue(auth_header.startswith("Paseto "))

        expected_claims = {
            "nonce": "yVJ0MVWhqPQ",
            "user_cid": str(self.user_cid),
        }
        self.assertDictEqual(self._decoded_claims(auth_header), expected_claims)

    @patch("secrets.token_urlsafe")
    def test_registration_auth_header(self, mock_get_nonce):
        """A registration token header uses the 'Pinned_Paseto' scheme."""
        mock_get_nonce.return_value = "yVJ0MVWhqPQ"
        auth_header = RegistrationPasetoToken(
            private_key=self.private_key,
            user_cid=self.user_cid,
        ).auth_header()
        self.assertTrue(auth_header.startswith("Pinned_Paseto "))

        expected_claims = {
            "nonce": "yVJ0MVWhqPQ",
            "user_cid": str(self.user_cid),
        }
        self.assertEqual(self._decoded_claims(auth_header), expected_claims)


class UnverifiedTokenTest(TestCase):
    """Tests for the unverified-token wrappers fed a token string signed by a
    freshly generated Ed25519 key."""

    def setUp(self):
        self.username = 'rusty_bucket'
        self.tls_dir = f'/tmp/{self.username}'
        Path(self.tls_dir).mkdir(exist_ok=True)
        # Scope CONF_DIR to this test only: patch.dict snapshots os.environ and
        # the cleanup restores it, so the scratch directory does not leak into
        # tests that run later.
        env_patcher = patch.dict(os.environ, {'CONF_DIR': self.tls_dir})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        self.user = User.objects.create(username=self.username, cid=uuid1())
        self.user_cid = self.user.cid
        self.user.refresh_from_db()
        self.private_key = Ed25519PrivateKey.generate()
        self.pub_key = self.private_key.public_key
        # PASERK id of the signing public key; test_kid_footer expects it as 'kid'.
        self.kid = self.pub_key.as_paseto(version=4).to_paserk_id()
        self.token = V4PasetoToken(private_key=self.private_key, user_cid=self.user_cid)
        self.token_string = self.token.sign()
        self.request_data = {
            'cid': uuid1(),
            'pem': self.pub_key.as_pem.decode(),
        }

    def test_kid_footer(self):
        """The token footer exposes the signing key's PASERK id under 'kid'."""
        unverified_token = UnverifiedToken(self.token_string, self.request_data)
        footer = unverified_token.footer_payload()
        self.assertEqual(footer['kid'], self.kid)

    def test_verify_key_mismatch(self):
        """Verification fails when the request supplies a key other than the signer's.

        An RSA public key is expected to raise ValueError; a different Ed25519
        public key is expected to fail signature verification.
        """
        pubkey = PublicKey.load_pem(data.PEM_PUBLIC_RSA)
        # NOTE(review): 'pem' is the raw ``as_pem`` value here (presumably
        # bytes, since setUp calls ``.decode()`` on it). Confirm that the
        # ValueError comes from the key-type mismatch and not from bytes vs str.
        request_data = {
            'cid': uuid1(),
            'pem': pubkey.as_pem,
        }

        with self.assertRaises(ValueError):
            UnverifiedRegistrationToken(self.token_string, request_data).verify()

        private_key = Ed25519PrivateKey.generate()
        request_data = {
            'cid': uuid1(),
            'pem': private_key.public_key.as_pem,
        }
        with self.assertRaises(pyseto.VerifyError):
            UnverifiedRegistrationToken(self.token_string, request_data).verify()


class AccessTokenTest(TransactionTestCase):
    """Verification of signed V4 tokens into ``AccessToken`` objects.

    Covers the allowed time window, nonce replay and which verification-key
    loader is selected.
    """

    @classmethod
    def setUpClass(cls):
        # Django's SimpleTestCase.setUpClass installs class-level settings
        # overrides and database-access guards; it must run first.
        super().setUpClass()
        cls.username = 'rusty_latch'
        cls.tls_dir = f'/tmp/{cls.username}'
        Path(cls.tls_dir).mkdir(exist_ok=True)
        # Remember any pre-existing CONF_DIR so tearDownClass can restore it.
        cls._previous_conf_dir = os.environ.get('CONF_DIR')
        os.environ['CONF_DIR'] = cls.tls_dir

        cls.private_key = Ed25519PrivateKey.generate()
        cls.key_cid = uuid1()
        # Persist the local verification key; tearDownClass removes the file
        # it is expected to create under tls_dir.
        cls.key_details = persist_umbrella_api_key(cls.key_cid, cls.private_key)
        cls.pub_key = cls.private_key.public_key
        cls.request_data = {
            'cid': cls.key_cid,
            'pem': cls.pub_key.as_pem.decode(),
        }

    @classmethod
    def tearDownClass(cls):
        try:
            Path(f'{cls.tls_dir}/{LOCAL_VERIFICATION_KEY_FILE}').unlink(missing_ok=True)
            # The loader is cached; clear it so later classes don't see this key.
            load_local_umbrella_api_key.cache_clear()
            if cls._previous_conf_dir is None:
                os.environ.pop('CONF_DIR', None)
            else:
                os.environ['CONF_DIR'] = cls._previous_conf_dir
        finally:
            super().tearDownClass()

    def setUp(self):
        # TransactionTestCase truncates all tables after every test, so rows
        # created once in setUpClass would be gone for every test after the
        # first. Create the user per test instead.
        self.user = User.objects.create(username=self.username, cid=uuid1())
        self.user_cid = self.user.cid
        self.user.refresh_from_db()

    def get_token(self):
        """Sign a Beacon token with the class key and pass it through
        ``UnverifiedToken.verify()``.

        The returned object has not yet had its own ``verify()`` called. The
        tests call it themselves, and test_nonce_already_used shows that this
        second step consumes the nonce.
        """
        token = V4PasetoToken(
            private_key=self.private_key,
            user_cid=self.user_cid,
            subject=Subject.Beacon,
            verification_cid=self.key_cid,
        )
        token_string = token.sign()
        unverified_token = UnverifiedToken(token_string)
        return unverified_token.verify()

    def test_get_claimed_user_cid(self):
        token = self.get_token().verify()
        creds = token.get_claimed_creds()
        self.assertEqual(creds.get('user_cid'), str(self.user_cid))

    def test_verify_valid(self):
        token = self.get_token().verify()
        self.assertIsInstance(token, AccessToken)
        creds = token.get_claimed_creds()
        self.assertEqual(creds.get('user_cid'), str(self.user_cid))

    def test_time_out_of_allowed_range_before(self):
        """Verifying 300s before the token was issued must fail."""
        token = V4PasetoToken(private_key=self.private_key, user_cid=self.user_cid)
        token_string = token.sign()
        request_data = {
            'cid': uuid1(),
            'pem': self.pub_key.as_pem.decode(),
        }

        dt = datetime.now(tz=timezone.utc) - timedelta(seconds=300)
        with freeze_time(dt), self.assertRaises(pyseto.VerifyError):
            UnverifiedToken(token_string, request_data).verify()

    def test_time_out_of_allowed_range_after(self):
        """Verifying 900s after the token was issued must fail."""
        token = V4PasetoToken(private_key=self.private_key, user_cid=self.user_cid)
        token_string = token.sign()
        request_data = {
            'cid': uuid1(),
            'pem': self.pub_key.as_pem.decode(),
        }

        dt = datetime.now(tz=timezone.utc) + timedelta(seconds=900)
        with freeze_time(dt), self.assertRaises(pyseto.VerifyError):
            UnverifiedToken(token_string, request_data).verify()

    def test_nonce_already_used(self):
        token = self.get_token().verify()
        self.assertIsInstance(token, AccessToken)
        # Second attempt fails because nonce was already used
        with self.assertRaises(pyseto.VerifyError):
            token.verify()

    @override_settings(ENCIPHER=dict(NONCE_BACKEND="encipher.nonce.null.NullNonceBackend"))
    def test_nonce_already_used_null_backend(self):
        token1 = self.get_token().verify()
        self.assertIsInstance(token1, AccessToken)
        # Second attempt works because null nonce backend doesn't do anything
        token2 = self.get_token().verify()
        self.assertIsInstance(token2, AccessToken)

    def test_correct_loader(self):
        """
        The LocalKeyLoader is always checked first. On systems
        that _have_ a local key, it should not be selected
        if the cid doesn't match.
        """
        key = EncryptionKey.objects.create_ssh_keys(
            name='Beacon-Device',
            user=self.user,
            subject=Subject.Beacon,
            password_prefix=settings.PRIVATE_KEY_PASSWORD_PREFIX,
        )
        token = V4PasetoToken(
            private_key=key.get_private_key(),
            user_cid=self.user_cid,
            subject=Subject.Beacon,
            verification_cid=key.cid,
        )
        token_string = token.sign()
        unverified_token = UnverifiedToken(token_string)
        access_token = unverified_token.verify()
        self.assertIsInstance(access_token.verification_key, DatabaseLoader)

        # LocalKeyLoader used to sign the token.
        token = V4PasetoToken(
            private_key=self.private_key,
            user_cid=self.user_cid,
            subject=Subject.Beacon,
            verification_cid=self.key_cid,
        )
        token_string = token.sign()
        unverified_token = UnverifiedToken(token_string)
        access_token = unverified_token.verify()
        self.assertIsInstance(access_token.verification_key, LocalKeyLoader)


class RegistrationTokenTest(TransactionTestCase):
    """Verification of registration tokens signed with a database-stored key."""

    @classmethod
    def setUpClass(cls):
        # Run Django's class-level setup (settings overrides, DB access guards).
        super().setUpClass()
        cls.username = 'rusty_nail'

    @classmethod
    def tearDownClass(cls):
        try:
            # The local key loader is cached; don't let state leak to later classes.
            load_local_umbrella_api_key.cache_clear()
        finally:
            super().tearDownClass()

    def setUp(self):
        # TransactionTestCase flushes all tables after each test, so database
        # fixtures are created per test rather than once in setUpClass.
        self.user = User.objects.create(username=self.username, cid=uuid1())
        self.key = EncryptionKey.objects.create_ssh_keys(
            name='Beacon-Device',
            user=self.user,
            subject=Subject.Beacon,
            password_prefix=settings.PRIVATE_KEY_PASSWORD_PREFIX,
        )

    def test_subject_keys_in_header(self):
        """A verified registration token exposes 'sub' and 'sub_cid' in its payload."""
        token = RegistrationPasetoToken(
            private_key=self.key.get_private_key(),
            user_cid=self.key.user.cid,
            subject=Subject.Beacon,
            subject_cid=uuid1(),
            verification_cid=self.key.cid,
        )
        token_string = token.sign()
        untrusted_token = UnverifiedRegistrationToken(token_string)
        access_token = untrusted_token.verify()
        self.assertIn('sub', access_token.payload)
        self.assertIn('sub_cid', access_token.payload)
