from aspen_crypto.keys import RSAPrivateKey, Ed25519PrivateKey
from lcrequests.tokens import V4PasetoToken
from unittest import mock
from uuid import uuid1

from django.test import RequestFactory, TestCase
from django.contrib.auth import get_user_model
from ..models import EncryptionKey
from ..middleware import PasetoAuthMiddleware

User = get_user_model()


class BaseMiddlewareTest(TestCase):
    """Shared authentication assertions for middleware tests.

    Subclasses must set ``self.user`` to the user expected to be authenticated
    before calling :meth:`assertLoggedIn`.
    """

    def assertNotLoggedIn(self, request):
        """Assert that *request* has no ``user`` attribute, or that it is ``None``."""
        self.assertIsNone(getattr(request, "user", None))

    def assertLoggedIn(self, request, public_key=None):
        """Assert that *request* is authenticated as ``self.user``.

        :param request: the request object after the middleware has run.
        :param public_key: optional key model instance; when given it is
            reloaded from the database and its ``last_used_on`` must be set.
        """
        # Compare against None explicitly: the check must not depend on the
        # truthiness of the key object.
        if public_key is not None:
            # Reload so any timestamp written during the request is visible.
            public_key.refresh_from_db()
            self.assertIsNotNone(public_key.last_used_on)
        self.assertEqual(getattr(request, "user", None), self.user)


class MiddlewareTest(BaseMiddlewareTest):
    """Tests for how ``PasetoAuthMiddleware`` handles the Authorization header."""

    def setUp(self):
        self.rfactory = RequestFactory()
        self.user = User.objects.create_user(username="foo")

        self.key_ed25519 = Ed25519PrivateKey.generate()
        # NOTE(review): not used by any test in this class; RSA key generation
        # is comparatively slow, so drop it if nothing else relies on it.
        self.key_rsa = RSAPrivateKey.generate()

        # Public half of key_ed25519, registered for self.user. The valid-token
        # test expects this row's last_used_on to be set after authentication.
        self.user_key_ed25519 = EncryptionKey.objects.create(
            name="ed25519",
            user=self.user,
            key=self.key_ed25519.public_key.as_pem.decode(),
        )

        # Stand-in for the downstream handler so tests can count invocations.
        self.next_middleware = mock.MagicMock()
        self.run_middleware = PasetoAuthMiddleware(self.next_middleware)

    def _run(self, authorization=None):
        """Build a GET request for "/", pass it through the middleware, return it.

        :param authorization: value for the Authorization header, or ``None``
            to send no header at all.

        Also asserts that the request is anonymous before the middleware runs.
        """
        extra = {} if authorization is None else {"HTTP_AUTHORIZATION": authorization}
        request = self.rfactory.get("/", **extra)
        self.assertNotLoggedIn(request)
        self.run_middleware(request)
        return request

    def test_no_auth_header(self):
        request = self._run()
        self.assertNotLoggedIn(request)
        self.assertEqual(self.next_middleware.call_count, 1)

    def test_auth_header_missing_type(self):
        request = self._run("Fooopbar")
        self.assertNotLoggedIn(request)
        self.assertEqual(self.next_middleware.call_count, 1)

    def test_auth_header_not_paseto_type(self):
        request = self._run("Bearer foobar")
        self.assertNotLoggedIn(request)
        self.assertEqual(self.next_middleware.call_count, 1)

    def test_header_paseto_claimed_user_doesnt_exist(self):
        # Correctly signed token, but for a random cid that matches no user.
        header = V4PasetoToken(user_cid=uuid1(), private_key=self.key_ed25519).auth_header()
        request = self._run(header)
        self.assertNotLoggedIn(request)
        self.assertEqual(self.next_middleware.call_count, 1)

    def test_authenticate_request_ed25519_valid(self):
        header = V4PasetoToken(user_cid=self.user.cid, private_key=self.key_ed25519).auth_header()
        request = self._run(header)
        self.assertLoggedIn(request, self.user_key_ed25519)
        self.assertEqual(self.next_middleware.call_count, 1)

    def test_cant_reuse_nonce(self):
        header = V4PasetoToken(user_cid=self.user.cid, private_key=self.key_ed25519).auth_header()
        # First use works
        request1 = self._run(header)
        self.assertLoggedIn(request1, self.user_key_ed25519)
        self.assertEqual(self.next_middleware.call_count, 1)
        # Replaying the identical token must not authenticate, but the request
        # is still passed downstream.
        request2 = self._run(header)
        self.assertNotLoggedIn(request2)
        self.assertEqual(self.next_middleware.call_count, 2)
